dgl.DGLGraph.update_all
DGLGraph.update_all(message_func, reduce_func, apply_node_func=None, etype=None)
Send messages along all the edges of the specified type and update all the nodes of the corresponding destination type.
For heterogeneous graphs with more than one relation type, send messages along all the edges and reduce them per type and across different types at the same time. Then update the node features of all the destination nodes.
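To make the cross-type behavior concrete, here is a dependency-free sketch that emulates what `update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))` computes on scalar node features. It is an illustration under stated assumptions, not DGL's implementation: the function name `update_all_sum` and the dict-of-edge-lists input format are hypothetical. With `sum` as the reducer, messages arriving from every relation that targets a destination type are pooled into one aggregate, and nodes without in-edges keep a zero.

```python
def update_all_sum(relations, feats):
    """Hypothetical emulation of update_all(copy_u('h', 'm'), sum('m', 'h')).

    relations: {(src_type, edge_type, dst_type): (src_ids, dst_ids)}
    feats:     {node_type: list of scalar features}
    Returns {dst_type: list of aggregated features}; only destination
    types appear in the result.
    """
    out = {}
    for (src_type, _, dst_type), (src, dst) in relations.items():
        # Each destination node starts at zero, mirroring the zero fill
        # for nodes that receive no messages.
        agg = out.setdefault(dst_type, [0.0] * len(feats[dst_type]))
        for u, v in zip(src, dst):
            # copy_u: the message is the source feature.
            # sum: add it to the destination, pooling across relations.
            agg[v] += feats[src_type][u]
    return out


hetero = {
    ('user', 'follows', 'user'): ([0, 1], [1, 1]),
    ('game', 'attracts', 'user'): ([0], [1]),
}
print(update_all_sum(hetero, {'user': [1.0, 2.0], 'game': [1.0]}))
```

This prints `{'user': [0.0, 4.0]}`, the same values as `tensor([[0.], [4.]])` in the multi-relation example: user 1 receives 1 and 2 through `follows` and 1 through `attracts`.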
Parameters
message_func (dgl.function.BuiltinFunction or callable) – The message function that generates messages along the edges. It must be either a DGL built-in function or a user-defined function.
reduce_func (dgl.function.BuiltinFunction or callable) – The reduce function that aggregates the messages. It must be either a DGL built-in function or a user-defined function.
apply_node_func (callable, optional) – An optional apply function that further updates the node features after message reduction. It must be a user-defined function.
etype (str or (str, str, str), optional) – The type name of the edges. The allowed formats are a (str, str, str) triplet giving the source node type, edge type and destination node type, or a single str edge type name if that name uniquely identifies a triplet in the graph. Can be omitted if the graph has only one type of edges.
Notes
If some of the nodes in the graph have no in-edges, DGL does not invoke the message and reduce functions for these nodes and fills their aggregated messages with zero. Users can control the filled values via set_n_initializer(). DGL still invokes apply_node_func if provided.
DGL recommends using DGL's built-in functions for the message_func and reduce_func arguments, because DGL can then invoke efficient kernels that avoid copying node features to edge features.
Examples
>>> import dgl
>>> import dgl.function as fn
>>> import torch
Homogeneous graph
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['x'] = torch.ones(5, 2)
>>> g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))
>>> g.ndata['h']
tensor([[0., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
Heterogeneous graph
>>> g = dgl.heterograph({('user', 'follows', 'user'): ([0, 1, 2], [1, 2, 2])})
Update all.
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g['follows'].update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows')
>>> g.nodes['user'].data['h']
tensor([[0.],
        [0.],
        [3.]])
Heterogeneous graph (number of relation types > 1)
>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): ([0, 1], [1, 1]),
...     ('game', 'attracts', 'user'): ([0], [1])
... })
Update all.
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
>>> g.nodes['user'].data['h']
tensor([[0.],
        [4.]])