dgl.udf.EdgeBatch.batch_size
EdgeBatch.batch_size()
Return the number of edges in the batch.
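To make the meaning of "the number of edges in the batch" concrete without requiring DGL, here is a minimal pure-Python sketch. `ToyEdgeBatch` and `toy_apply_edges` are hypothetical stand-ins, not DGL APIs: the toy batch only stores edge endpoints, and `batch_size()` returns how many edges it holds, which is what a UDF uses to size one output row per edge.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToyEdgeBatch:
    """Hypothetical stand-in for dgl.udf.EdgeBatch: only the edge endpoints."""
    src: List[int]
    dst: List[int]

    def batch_size(self) -> int:
        # Number of edges in this batch, mirroring EdgeBatch.batch_size().
        return len(self.src)


def toy_apply_edges(
    src: List[int],
    dst: List[int],
    udf: Callable[[ToyEdgeBatch], Dict[str, List[List[float]]]],
) -> Dict[str, List[List[float]]]:
    # Hypothetical helper: wrap all edges in a single batch and call the UDF once.
    return udf(ToyEdgeBatch(src, dst))


def edge_udf(edges: ToyEdgeBatch) -> Dict[str, List[List[float]]]:
    # One row with a single 1.0 per edge, analogous to torch.ones(edges.batch_size(), 1).
    return {'h': [[1.0] for _ in range(edges.batch_size())]}


edata = toy_apply_edges([0, 1, 1], [1, 1, 0], edge_udf)
print(edata['h'])  # [[1.0], [1.0], [1.0]]
```

Because the UDF sizes its output from `batch_size()` rather than a hard-coded count, the same function works for a batch of any size.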
- Returns: the number of edges in the batch.
- Return type: int
Examples
The following example uses the PyTorch backend.
>>> import dgl
>>> import torch
>>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> # Define a UDF that returns one for each edge.
>>> def edge_udf(edges):
...     return {'h': torch.ones(edges.batch_size(), 1)}
>>> # Create an edge feature 'h'.
>>> g.apply_edges(edge_udf)
>>> g.edata['h']
tensor([[1.],
        [1.],
        [1.]])
>>> # Use the edge UDF as the message function in message passing.
>>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('h', 'h'))
>>> g.ndata['h']
tensor([[1.],
        [2.]])
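Since every message produced by `edge_udf` is 1, summing the messages that arrive at each node yields that node's in-degree: node 0 has one incoming edge (1 → 0) and node 1 has two (0 → 1 and 1 → 1), hence `[[1.], [2.]]`. The sketch below reproduces that arithmetic in plain Python; `sum_of_ones_per_dst` is a hypothetical helper that models the message and sum-reduce steps, not DGL's implementation.

```python
from typing import List


def sum_of_ones_per_dst(src: List[int], dst: List[int], num_nodes: int) -> List[List[float]]:
    # Message step: each edge emits h = [1.0], as edge_udf does.
    messages = [[1.0] for _ in src]
    # Reduce step: sum the messages arriving at each destination node,
    # analogous to fn.sum('h', 'h'). In this sketch, nodes with no
    # incoming edges simply keep 0.0.
    agg = [[0.0] for _ in range(num_nodes)]
    for v, m in zip(dst, messages):
        agg[v][0] += m[0]
    return agg


print(sum_of_ones_per_dst([0, 1, 1], [1, 1, 0], 2))  # [[1.0], [2.0]]
```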