dgl.DGLGraph.pull

DGLGraph.pull(v, message_func='default', reduce_func='default', apply_node_func='default', inplace=False)

Pull messages from the node(s)’ predecessors and then update their features.

Optionally, apply a function to update the node features after the messages are reduced.

  • reduce_func will be skipped for nodes with no incoming messages.
  • If none of the nodes in v have incoming messages, this degenerates to an apply_nodes() call.
  • If some nodes in v have no incoming messages, their new feature values are computed by the column initializer (see set_n_initializer()). The feature shapes and dtypes will be inferred.
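For example, a minimal sketch of the last point, assuming PyTorch and the zero initializer from dgl.init (the output field name 'h' is illustrative):

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edge(0, 1)
>>> g.ndata['x'] = th.tensor([[1.], [2.], [3.]])
>>> g.set_n_initializer(dgl.init.zero_initializer)
>>> # nodes 0 and 2 have no incoming message, so their 'h' comes from the initializer
>>> g.pull(g.nodes(), dgl.function.copy_src('x', 'm'), dgl.function.sum('m', 'h'))
>>> g.ndata['h']
tensor([[0.],
        [1.],
        [0.]])
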
Parameters:
  • v (int, iterable of int, or tensor) – The node(s) to be updated.
  • message_func (callable, optional) – Message function on the edges. The function should be an Edge UDF.
  • reduce_func (callable, optional) – Reduce function on the nodes. The function should be a Node UDF.
  • apply_node_func (callable, optional) – Apply function on the nodes. The function should be a Node UDF.
  • inplace (bool, optional) – If True, update will be done in place, but autograd will break.

Examples

Create a graph for demo.

Note

Here we use PyTorch syntax for the demo. The general idea applies to other frameworks with minor syntax changes (e.g., replacing torch.tensor with mxnet.ndarray).

>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[0.], [1.], [2.]])

Use the built-in message function copy_src() for copying node features as the message.

>>> m_func = dgl.function.copy_src('x', 'm')
>>> g.register_message_func(m_func)

Use the built-in reduce function sum(), which sums the received messages and replaces the old node features with the result.

>>> m_reduce_func = dgl.function.sum('m', 'x')
>>> g.register_reduce_func(m_reduce_func)

As no edges exist, nothing happens.

>>> g.pull(g.nodes())
>>> g.ndata['x']
tensor([[0.],
        [1.],
        [2.]])
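
Because no node here has an incoming message, the call downgrades to apply_nodes(): the registered message and reduce functions are skipped and only the apply function, if any, runs. A minimal sketch passing one explicitly (the UDF my_apply and the output field 'y' are illustrative):

>>> def my_apply(nodes):
...     # Node UDF: write a doubled copy of 'x' into a new field 'y'
...     return {'y': nodes.data['x'] * 2}
>>> g.pull(g.nodes(), apply_node_func=my_apply)
>>> g.ndata['y']
tensor([[0.],
        [2.],
        [4.]])

Note that 'x' is left untouched, since the reduce step never ran.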

Add edges 0 -> 1 and 1 -> 2. Pull messages for node 2.

>>> g.add_edges([0, 1], [1, 2])
>>> g.pull(2)
>>> g.ndata['x']
tensor([[0.],
        [1.],
        [1.]])

The feature of node 2 changes, but the feature of node 1 remains the same, as we did not pull() (and reduce) messages for it.
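
The message, reduce, and apply functions can also be passed per call instead of being registered, as listed under Parameters. A sketch with hand-written UDFs mirroring the built-ins above (the names m_udf and r_udf are illustrative):

>>> def m_udf(edges):
...     # Edge UDF: copy the source node feature as the message
...     return {'m': edges.src['x']}
>>> def r_udf(nodes):
...     # Node UDF: sum the incoming messages along the message dimension
...     return {'x': nodes.mailbox['m'].sum(1)}
>>> g.pull(1, m_udf, r_udf)
>>> g.ndata['x']
tensor([[0.],
        [0.],
        [1.]])

Node 1's only predecessor is node 0, so its feature is overwritten with node 0's feature.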

See also

push()