dgl.DGLGraph.multi_update_all
- DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None)
Send messages along all the edges, reduce them first type-wise and then across different types, and finally update the features of all the nodes.
- Parameters:
etype_dict (dict) – Arguments for edge-type-wise message passing. The keys are edge types and the values are message passing arguments.
The allowed key formats are:
- (str, str, str) for source node type, edge type and destination node type, or
- str, the edge type name, if the name uniquely identifies a triplet format in the graph.
The value must be a tuple (message_func, reduce_func, [apply_node_func]), where
- message_func (dgl.function.BuiltinFunction or callable) – The message function to generate messages along the edges. It must be either a DGL built-in function or a user-defined function.
- reduce_func (dgl.function.BuiltinFunction or callable) – The reduce function to aggregate the messages. It must be either a DGL built-in function or a user-defined function.
- apply_node_func (callable, optional) – An optional apply function to further update the node features after the message reduction. It must be a user-defined function.
cross_reducer (str or callable) – Cross-type reducer, one of "sum", "min", "max", "mean", "stack", or a callable function. If a callable is provided, its input argument must be a single list of tensors containing the aggregation results from each edge type, and its output must be a single tensor.
apply_node_func (callable, optional) – An optional apply function invoked after the messages are reduced both type-wise and across different types. It must be a user-defined function.
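The two-stage reduction described above can be sketched in plain Python. This is a simplified analogue, not DGL's implementation: it uses plain floats in place of tensors, and the names `type_wise_reduce` and `cross_type_reduce` are illustrative only:

```python
# Sketch of multi_update_all's two-stage reduction (illustrative only).
# Stage 1: reduce the messages per edge type; stage 2: cross-type reduce.

def type_wise_reduce(messages, reduce_func):
    """Aggregate the messages arriving at one node for one edge type."""
    return reduce_func(messages)

def cross_type_reduce(per_type_results, cross_reducer):
    """Combine the per-edge-type aggregation results for one node."""
    if cross_reducer == "sum":
        return sum(per_type_results)
    if cross_reducer == "max":
        return max(per_type_results)
    if callable(cross_reducer):
        # A callable gets the full list of per-type results.
        return cross_reducer(per_type_results)
    raise ValueError(f"unsupported cross reducer: {cross_reducer}")

# Messages arriving at user node 1 in the heterograph example below:
# 'follows' delivers h from user 0 and user 1 (1.0 + 2.0 = 3.0),
# 'attracts' delivers h from game 0 (1.0).
follows_result = type_wise_reduce([1.0, 2.0], sum)   # type-wise 'sum'
attracts_result = type_wise_reduce([1.0], sum)
h = cross_type_reduce([follows_result, attracts_result], "sum")
print(h)  # 4.0 — matches user node 1 in the example below
```

The key point the sketch captures is that the per-type reducer and the cross-type reducer are independent choices: each edge type is aggregated on its own first, and only the resulting per-type values are combined.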
Notes
DGL recommends using DGL's built-in functions for the message_func and the reduce_func in the type-wise message passing arguments, because DGL then invokes efficient kernels that avoid copying node features to edge features.
Examples
>>> import dgl
>>> import dgl.function as fn
>>> import torch
Instantiate a heterograph.
>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): ([0, 1], [1, 1]),
...     ('game', 'attracts', 'user'): ([0], [1])
... })
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
Update all.
>>> g.multi_update_all(
...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
...     "sum")
>>> g.nodes['user'].data['h']
tensor([[0.],
        [4.]])
User-defined cross reducer equivalent to "sum".
>>> def cross_sum(flist):
...     return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]
Use the user-defined cross reducer.
>>> g.multi_update_all(
...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
...     cross_sum)
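A callable cross reducer equivalent to "mean" follows the same pattern as cross_sum above. Sketched here in plain Python over lists of floats rather than tensors (for actual DGL features, `torch.mean(torch.stack(flist, dim=0), dim=0)` would play the same role); `cross_mean` is an illustrative name:

```python
def cross_mean(flist):
    """Cross-type reducer equivalent to "mean": average the per-edge-type
    aggregation results element-wise. Illustrative sketch over lists of
    floats; DGL passes a list of tensors instead."""
    if len(flist) == 1:
        return flist[0]
    # Element-wise mean across the per-type result lists.
    return [sum(vals) / len(flist) for vals in zip(*flist)]

# Per-type results for users [user0, user1] from two edge types:
# 'follows' gave [0.0, 3.0], 'attracts' gave [0.0, 1.0].
print(cross_mean([[0.0, 3.0], [0.0, 1.0]]))  # [0.0, 2.0]
```

Note that the single-type shortcut (`len(flist) == 1`) matters in practice: when only one edge type delivers messages to a node type, the cross reducer simply passes its result through.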