dgl.DGLGraph.multi_update_all

DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None)

Send messages along all the edges, reduce them first type-wise and then across different edge types, and then update the node features of all the nodes.
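The two-stage reduction can be made concrete with a minimal plain-Python model. It assumes scalar node features, a copy-source message function and a sum reduce function for every edge type. The function `multi_update_all_model` is hypothetical and only mirrors the behavior described above; it is not how DGL implements the method.

```python
from collections import defaultdict


def multi_update_all_model(edges, h, cross_reducer):
    """Hypothetical model of multi_update_all (NOT DGL's implementation).

    edges: {(src_type, etype, dst_type): (src_ids, dst_ids)}
    h: {node_type: list of scalar features}
    cross_reducer: callable taking a tuple of per-edge-type results for one node
    """
    # Stage 1: type-wise message passing. Each edge copies its source feature,
    # and messages are summed per destination node (0.0 if none arrive).
    per_type = defaultdict(list)  # dst_type -> one result list per edge type
    for (src_type, _etype, dst_type), (src, dst) in edges.items():
        out = [0.0] * len(h[dst_type])
        for u, v in zip(src, dst):
            out[v] += h[src_type][u]
        per_type[dst_type].append(out)

    # Stage 2: cross-type reduction, one value per destination node.
    # Node types that receive no messages keep their old features.
    new_h = dict(h)
    for dst_type, results in per_type.items():
        new_h[dst_type] = [cross_reducer(vals) for vals in zip(*results)]
    return new_h


# The heterograph used in the Examples section below, with scalar features.
edges = {
    ('user', 'follows', 'user'): ([0, 1], [1, 1]),
    ('game', 'attracts', 'user'): ([0], [1]),
}
h = {'user': [1.0, 2.0], 'game': [1.0]}

result = multi_update_all_model(edges, h, sum)
print(result['user'])  # [0.0, 4.0]
```

For user 1, the follows edges contribute 1.0 + 2.0 = 3.0 and the attracts edge contributes 1.0, so the "sum" cross reducer yields 4.0; passing `max` instead would yield `[0.0, 3.0]`.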

Parameters
  • etype_dict (dict) –

    Arguments for edge-type-wise message passing. The keys are edge types while the values are message passing arguments.

    The allowed key formats are:

    • (str, str, str) for source node type, edge type and destination node type.

    • or one str edge type name, if that name uniquely identifies a single (source node type, edge type, destination node type) triplet in the graph.

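    The value must be a tuple (message_func, reduce_func, [apply_node_func]), where

    • message_func is the message function that generates messages along the edges of that type, either a DGL built-in function or a user-defined function;

    • reduce_func is the reduce function that aggregates those messages at the destination nodes, either a DGL built-in function or a user-defined function;

    • apply_node_func is an optional user-defined apply function that further updates the node features after the type-wise reduction.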

  • cross_reducer (str or callable function) – Cross-type reducer. One of "sum", "min", "max", "mean", "stack", or a callable. A callable receives a single list of tensors holding the aggregation results from each edge type and must return a single tensor.

  • apply_node_func (callable, optional) – An optional apply function that further updates the node features after the messages have been reduced both type-wise and across edge types. It must be a user-defined function (UDF).
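The key formats and cross-reducer options above can be modeled in plain Python. Both helpers, `resolve_etype` and `cross_reduce`, are hypothetical illustrations rather than DGL functions: they operate on one node's scalar per-type results, whereas DGL applies the same reductions element-wise to tensors, and the error types DGL actually raises may differ. As we understand it, DGL's "stack" option adds a new tensor dimension indexed by edge type instead of merging the results.

```python
def resolve_etype(key, canonical_etypes):
    """Hypothetical helper: map an etype_dict key to a canonical
    (src_type, etype, dst_type) triplet."""
    if isinstance(key, tuple):
        if key not in canonical_etypes:
            raise KeyError(f"unknown canonical edge type {key!r}")
        return key
    # A plain str key is only valid if exactly one triplet uses that name.
    matches = [c for c in canonical_etypes if c[1] == key]
    if len(matches) != 1:
        raise KeyError(f"edge type {key!r} matches {len(matches)} triplets")
    return matches[0]


def cross_reduce(reducer, per_type_values):
    """Hypothetical helper: combine one node's per-edge-type results."""
    values = list(per_type_values)
    if callable(reducer):
        return reducer(values)  # user-defined cross reducer gets the whole list
    if reducer == "sum":
        return sum(values)
    if reducer == "min":
        return min(values)
    if reducer == "max":
        return max(values)
    if reducer == "mean":
        return sum(values) / len(values)
    if reducer == "stack":
        # No merging: the per-type results are kept side by side.
        return values
    raise ValueError(f"unsupported cross reducer {reducer!r}")


canonical = [('user', 'follows', 'user'), ('game', 'attracts', 'user')]
print(resolve_etype('attracts', canonical))  # ('game', 'attracts', 'user')
print(cross_reduce("mean", [3.0, 1.0]))      # 2.0
print(cross_reduce("stack", [3.0, 1.0]))     # [3.0, 1.0]
```

A str key such as 'follows' is accepted only because no other triplet in this graph shares that name; if two relations were both called 'follows', the full triplet would be required.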

Notes

DGL recommends using DGL’s built-in functions for the message_func and the reduce_func in the type-wise message passing arguments, because in that case DGL invokes efficient kernels that avoid copying node features to edge features.

Examples

>>> import dgl
>>> import dgl.function as fn
>>> import torch

Instantiate a heterograph.

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): ([0, 1], [1, 1]),
...     ('game', 'attracts', 'user'): ([0], [1])
... })
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])

Update all.

>>> g.multi_update_all(
...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
...     "sum")
>>> g.nodes['user'].data['h']
tensor([[0.],
        [4.]])

User-defined cross reducer equivalent to “sum”.

>>> def cross_sum(flist):
...     return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]

Use the user-defined cross reducer.

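>>> g.multi_update_all(
...     {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
...      'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
...     cross_sum)
>>> # The previous call already updated 'h', so user 1 now receives 0 + 4 + 1.
>>> g.nodes['user'].data['h']
tensor([[0.],
        [5.]])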