DGLGraph.prop_nodes(nodes_generator, message_func='default', reduce_func='default', apply_node_func='default')

Propagate messages using graph traversal by triggering pull() on nodes.

The traversal order is specified by nodes_generator, which generates node frontiers; each frontier is a list or a tensor of node IDs. Nodes in the same frontier are triggered together, while different frontiers are triggered one after another in the order they are generated.
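To make the frontier semantics concrete, here is a minimal pure-Python sketch of what this traversal computes, assuming scalar node features, an edge list of (src, dst) pairs, a copy-source message and a sum reduce. It does not use DGL; `prop_nodes_sketch` is a hypothetical name. Each frontier reads from one snapshot of the features, which is what makes updates inside a frontier simultaneous. Nodes without in-edges are simply left unchanged here; that is an assumption of the sketch, not a statement about the library.

```python
from typing import Callable, Dict, Iterable, List, Sequence, Tuple

Edge = Tuple[int, int]  # (source node, destination node)


def prop_nodes_sketch(
    x: List[float],
    edges: Sequence[Edge],
    frontiers: Iterable[Sequence[int]],
    message: Callable[[float], float] = lambda src_feat: src_feat,
    reduce: Callable[[List[float]], float] = sum,
) -> List[float]:
    """Pull-based propagation, one frontier at a time (hypothetical sketch)."""
    x = list(x)  # work on a copy of the node features
    for frontier in frontiers:
        # All nodes in this frontier read from the same snapshot,
        # so their updates happen "simultaneously".
        snapshot = list(x)
        updates: Dict[int, float] = {}
        for v in frontier:
            # Collect one message per incoming edge (u -> v).
            mailbox = [message(snapshot[u]) for (u, w) in edges if w == v]
            if mailbox:  # assumption: nodes with no messages stay unchanged
                updates[v] = reduce(mailbox)
        # Commit the updates; later frontiers see the new values.
        for v, val in updates.items():
            x[v] = val
    return x
```

With the graph used in the examples further down, `prop_nodes_sketch([1., 2., 3., 4.], [(0, 1), (1, 2), (1, 3), (2, 3)], [[1, 2], [3]])` returns `[1.0, 1.0, 2.0, 3.0]`.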

Parameters:
  • nodes_generator (iterable, each element is a list or a tensor of node ids) – The generator of node frontiers. It specifies which nodes perform pull() at each timestep.
  • message_func (callable, optional) – Message function on the edges. The function should be an Edge UDF.
  • reduce_func (callable, optional) – Reduce function on the nodes. The function should be a Node UDF.
  • apply_node_func (callable, optional) – Apply function on the nodes. The function should be a Node UDF.
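Any iterable of node-ID lists can serve as nodes_generator. As an illustration, the hypothetical helper below (not part of DGL) yields frontiers in topological order, level by level, using Kahn's algorithm on a plain (src, dst) edge list; every node is yielded only after all of its predecessors. Note that the first frontier consists of source nodes, which have no in-edges to pull from.

```python
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence, Tuple


def topo_frontiers(num_nodes: int, edges: Sequence[Tuple[int, int]]) -> Iterator[List[int]]:
    """Yield node frontiers in topological order (hypothetical helper)."""
    in_deg = [0] * num_nodes
    succ: Dict[int, List[int]] = defaultdict(list)
    for u, v in edges:
        succ[u].append(v)
        in_deg[v] += 1
    # Start from the nodes that have no incoming edges.
    frontier = [v for v in range(num_nodes) if in_deg[v] == 0]
    visited = 0
    while frontier:
        yield frontier
        visited += len(frontier)
        nxt: List[int] = []
        for u in frontier:
            for v in succ[u]:
                in_deg[v] -= 1
                if in_deg[v] == 0:  # all predecessors have been yielded
                    nxt.append(v)
        frontier = nxt
    if visited != num_nodes:
        raise ValueError("graph has a cycle; no topological order exists")
```

For the demo graph with edges 0 -> 1, 1 -> 2, 1 -> 3 and 2 -> 3, `list(topo_frontiers(4, [(0, 1), (1, 2), (1, 3), (2, 3)]))` gives `[[0], [1], [2], [3]]`.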


Create a graph for the demo.

The examples below use PyTorch syntax. The same idea applies to other frameworks with minor syntax changes (e.g., replace torch.tensor with mxnet.ndarray).

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(4)
>>> g.ndata['x'] = th.tensor([[1.], [2.], [3.], [4.]])
>>> g.add_edges([0, 1, 1, 2], [1, 2, 3, 3])

Prepare a message function and a reduce function for the demo.

>>> def send_source(edges): return {'m': edges.src['x']}
>>> g.register_message_func(send_source)
>>> def simple_reduce(nodes): return {'x': nodes.mailbox['m'].sum(1)}
>>> g.register_reduce_func(simple_reduce)

First, pull messages for nodes \(1\) and \(2\) along edges 0 -> 1 and 1 -> 2; then pull messages for node \(3\) along edges 1 -> 3 and 2 -> 3.

>>> g.prop_nodes([[1, 2], [3]])
>>> g.ndata['x']
tensor([[1.],
        [1.],
        [2.],
        [3.]])

In the first stage, we pull messages for nodes \(1\) and \(2\). The feature of node \(1\) is replaced by that of node \(0\), i.e. 1. The feature of node \(2\) is replaced by that of node \(1\), i.e. 2. Both replacements happen simultaneously, so node \(2\) receives node \(1\)'s original feature rather than its updated one.

In the second stage, we pull messages for node \(3\). The feature of node \(3\) becomes the sum of the (already updated) features of nodes \(1\) and \(2\), i.e. 1 + 2 = 3.
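To see why the split into frontiers matters, the plain-Python re-computation below replays this demo without DGL, once with the two frontiers used above and once with nodes 1, 2 and 3 in a single frontier. `EDGES` and `run` are hypothetical names; the copy-source message and sum reduce mirror send_source and simple_reduce, and nodes without messages are assumed to keep their feature.

```python
from typing import List, Sequence, Tuple

# Same edges and initial features as the demo graph.
EDGES: List[Tuple[int, int]] = [(0, 1), (1, 2), (1, 3), (2, 3)]


def run(frontiers: Sequence[Sequence[int]]) -> List[float]:
    x = [1.0, 2.0, 3.0, 4.0]
    for frontier in frontiers:
        old = list(x)  # snapshot: nodes in one frontier update together
        for v in frontier:
            msgs = [old[u] for (u, w) in EDGES if w == v]  # like send_source
            if msgs:
                x[v] = sum(msgs)  # like simple_reduce
    return x


staged = run([[1, 2], [3]])  # [1.0, 1.0, 2.0, 3.0]
single = run([[1, 2, 3]])    # [1.0, 1.0, 2.0, 5.0]
```

With a single frontier, node 3 reads the old features of nodes 1 and 2 (2 + 3 = 5) instead of the values produced in the first stage (1 + 2 = 3).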

See also