dgl.udf.NodeBatch.batch_size

NodeBatch.batch_size()

Return the number of nodes in the batch.

Return type:

int

Examples

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch
>>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received for
>>> # each node and increments the result by 1.
>>> def node_udf(nodes):
...     return {'h': torch.ones(nodes.batch_size(), 1)
...         + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h']
tensor([[2.],
        [3.]])
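
In the example above, node 0 receives one message and node 1 receives two. This is why the UDF builds its tensor from nodes.batch_size() rather than from the total number of nodes. The pure-Python sketch below illustrates the idea under the assumption that a reduce UDF is invoked once per group of nodes sharing the same in-degree, so each call sees only part of the graph. ToyNodeBatch and bucket_by_in_degree are hypothetical names made up for this illustration; they are not part of DGL and do not reproduce its implementation.

```python
from collections import defaultdict


class ToyNodeBatch:
    """Hypothetical stand-in for a node batch; not DGL's NodeBatch class."""

    def __init__(self, node_ids):
        self._node_ids = list(node_ids)

    def batch_size(self):
        # Number of nodes in this batch.
        return len(self._node_ids)


def bucket_by_in_degree(num_nodes, dst):
    """Group node IDs by in-degree, returning one toy batch per degree."""
    in_degree = [0] * num_nodes
    for v in dst:
        in_degree[v] += 1
    buckets = defaultdict(list)
    for node, deg in enumerate(in_degree):
        if deg > 0:  # assumption: nodes without messages get no batch
            buckets[deg].append(node)
    return {deg: ToyNodeBatch(ids) for deg, ids in sorted(buckets.items())}


# Destination IDs of the three edges in the example graph: 0->1, 1->1, 1->0.
batches = bucket_by_in_degree(2, [1, 1, 0])
sizes = {deg: batch.batch_size() for deg, batch in batches.items()}
print(sizes)  # {1: 1, 2: 1}
```

Under this assumption each batch holds a single node, so a tensor sized with the graph's total node count (2) would not match the per-batch mailbox, while one sized with batch_size() would.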