dgl.batch

dgl.batch(graph_list, node_attrs='__ALL__', edge_attrs='__ALL__')

Batch a collection of DGLGraph objects into a single batched DGLGraph so that message passing and readout can be performed over a batch of graphs simultaneously. The returned graph is independent of graph_list, and its batch size is the length of graph_list.

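The nodes and edges are re-indexed in the batched graph by the rule below, where Ni is the number of nodes in graph i:

item      Graph 1          Graph 2                 …    Graph k
raw id    0, …, N1 - 1     0, …, N2 - 1            …    0, …, Nk - 1
new id    0, …, N1 - 1     N1, …, N1 + N2 - 1      …    N1 + … + N(k-1), …, N1 + … + Nk - 1

Edges are re-indexed the same way, using the number of edges in each graph.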
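In other words, each graph's ids are shifted by the total number of nodes in the graphs that precede it. The plain-Python sketch below computes these offsets from a list of per-graph node counts; node_offsets and to_batched_id are hypothetical helpers written for illustration, not part of the DGL API.

```python
from itertools import accumulate
from typing import List


def node_offsets(num_nodes: List[int]) -> List[int]:
    """Return the batched id of node 0 of each graph (hypothetical helper)."""
    # accumulate(..., initial=0) yields [0, n0, n0 + n1, ...]; drop the final total.
    return list(accumulate(num_nodes, initial=0))[:-1]


def to_batched_id(num_nodes: List[int], graph_idx: int, raw_id: int) -> int:
    """Map (graph index, raw node id) to the node id in the batched graph."""
    if not 0 <= raw_id < num_nodes[graph_idx]:
        raise IndexError("raw_id out of range for this graph")
    return node_offsets(num_nodes)[graph_idx] + raw_id


# Two graphs with 2 and 3 nodes, as in the examples below.
print(node_offsets([2, 3]))         # [0, 2]
print(to_batched_id([2, 3], 1, 2))  # 4
```

Passing per-graph edge counts instead of node counts gives the edge id offsets in the same way.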

Modifying the features of the batched graph has no effect on the original graphs. See the examples below for how to work around this.

Parameters:
  • graph_list (iterable) – A collection of DGLGraph objects to be batched.
  • node_attrs (None, str or iterable, optional) – The node attributes to be batched. By default, all node attributes are batched. If None, the returned DGLGraph will have no node attributes. If a str or an iterable, it specifies exactly which node attributes to batch.
  • edge_attrs (None, str or iterable, optional) – The edge attributes to be batched; accepts the same values as node_attrs.
Returns:

A single batched graph.

Return type:

DGLGraph
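The node_attrs and edge_attrs parameters accept three kinds of values. As a sketch of the documented semantics only, the hypothetical helper resolve_attrs below maps each form to the list of attribute names that would be batched; DGL's own implementation may differ in details, such as which graph the available attribute names are read from.

```python
from typing import Iterable, List, Union

ALL = '__ALL__'  # the documented default value of node_attrs / edge_attrs


def resolve_attrs(attrs: Union[None, str, Iterable[str]],
                  available: Iterable[str]) -> List[str]:
    """Hypothetical helper: turn a node_attrs/edge_attrs argument into names."""
    if attrs is None:
        return []               # batch no attributes
    if attrs == ALL:
        return list(available)  # default: batch every attribute
    if isinstance(attrs, str):
        return [attrs]          # a single attribute name
    return list(attrs)          # an explicit collection of names


print(resolve_attrs('hv', ['hv', 'x']))  # ['hv']
```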

Examples

Instantiation: Create two DGLGraph objects.

>>> import dgl
>>> import torch as th
>>> g1 = dgl.DGLGraph()
>>> g1.add_nodes(2)                                # Add 2 nodes
>>> g1.add_edge(0, 1)                              # Add edge 0 -> 1
>>> g1.ndata['hv'] = th.tensor([[0.], [1.]])       # Initialize node features
>>> g1.edata['he'] = th.tensor([[0.]])             # Initialize edge features
>>> g2 = dgl.DGLGraph()
>>> g2.add_nodes(3)                                # Add 3 nodes
>>> g2.add_edges([0, 2], [1, 1])                   # Add edges 0 -> 1, 2 -> 1
>>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
>>> g2.edata['he'] = th.tensor([[1.], [2.]])       # Initialize edge features

Batch the two DGLGraph objects into one DGLGraph. When batching a list of graphs, we can choose to include only a subset of the attributes; here the edge attributes are dropped.

>>> bg = dgl.batch([g1, g2], edge_attrs=None)
>>> bg.edata
{}

Below one can see that the nodes are re-indexed. The edges are re-indexed in the same way.

>>> bg.nodes()
tensor([0, 1, 2, 3, 4])
>>> bg.ndata['hv']
tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.]])
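Edge endpoints are shifted by the same per-graph node offsets, and edge ids follow the order in which the graphs' edge lists are concatenated. The plain-Python sketch below reproduces this for the two example graphs; batch_edges is a hypothetical helper, not a DGL function.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def batch_edges(edge_lists: List[List[Edge]], num_nodes: List[int]) -> List[Edge]:
    """Concatenate per-graph edge lists, shifting endpoints by each graph's node offset."""
    batched: List[Edge] = []
    offset = 0
    for edges, n in zip(edge_lists, num_nodes):
        # Edge ids come from concatenation order; endpoints are shifted.
        batched.extend((u + offset, v + offset) for u, v in edges)
        offset += n
    return batched


g1_edges = [(0, 1)]          # 0 -> 1
g2_edges = [(0, 1), (2, 1)]  # 0 -> 1, 2 -> 1
print(batch_edges([g1_edges, g2_edges], [2, 3]))  # [(0, 1), (2, 3), (4, 3)]
```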

Property: We can still get a brief summary of the graphs that constitute the batched graph.

>>> bg.batch_size
2
>>> bg.batch_num_nodes
[2, 3]
>>> bg.batch_num_edges
[1, 2]
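Because each graph occupies a contiguous block of ids, batch_num_nodes is enough to tell which batched node ids belong to which graph. A plain-Python sketch, with node_ranges as a hypothetical helper:

```python
from itertools import accumulate
from typing import List


def node_ranges(batch_num_nodes: List[int]) -> List[range]:
    """Return the batched node ids that belong to each constituent graph."""
    ends = list(accumulate(batch_num_nodes))
    starts = [0] + ends[:-1]
    return [range(s, e) for s, e in zip(starts, ends)]


print([list(r) for r in node_ranges([2, 3])])  # [[0, 1], [2, 3, 4]]
```

The same function applied to batch_num_edges gives each graph's edge id range.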

Readout: Another common demand for graph neural networks is graph readout, which is a function that takes in the node attributes and/or edge attributes of a graph and outputs a vector summarizing the information in the graph. DGL also supports performing readout for a batch of graphs at once. Below we take the built-in readout function sum_nodes() as an example, which sums a given node attribute over the nodes of each graph.

>>> dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph.
tensor([[1.],               # 0 + 1
        [9.]])              # 2 + 3 + 4
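Conceptually, this batched readout is a segment sum: the node feature is summed over each graph's contiguous block of nodes, whose sizes are given by batch_num_nodes. The sketch below uses plain Python lists and a scalar feature; sum_nodes_sketch is hypothetical and is not how dgl.sum_nodes is implemented.

```python
from typing import List


def sum_nodes_sketch(features: List[float], batch_num_nodes: List[int]) -> List[float]:
    """Sum a scalar node feature separately over each graph's block of nodes."""
    sums: List[float] = []
    start = 0
    for n in batch_num_nodes:
        sums.append(sum(features[start:start + n]))
        start += n
    return sums


hv = [0.0, 1.0, 2.0, 3.0, 4.0]       # bg.ndata['hv'], flattened
print(sum_nodes_sketch(hv, [2, 3]))  # [1.0, 9.0]
```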

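Message passing: For message passing and related operations, a batched DGLGraph behaves exactly like an ordinary DGLGraph (whose batch size is 1). Since there are no edges between the constituent graphs, message passing on the batched graph gives the same result as running it on each graph separately.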
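To see why, the plain-Python sketch below runs one round of message passing in which each node sums the features of its in-neighbors. The sum aggregator and the helper sum_incoming are hypothetical stand-ins for DGL's message and reduce functions; the point is only that per-graph and batched results coincide.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def sum_incoming(h: List[float], edges: List[Edge]) -> List[float]:
    """One round of message passing: each node sums the features of its in-neighbors."""
    out = [0.0] * len(h)
    for u, v in edges:
        out[v] += h[u]
    return out


# Per-graph message passing on the two example graphs, then concatenated.
per_graph = (sum_incoming([0.0, 1.0], [(0, 1)])
             + sum_incoming([2.0, 3.0, 4.0], [(0, 1), (2, 1)]))

# The same round on the batched graph (graph 2's node ids shifted by 2).
batched = sum_incoming([0.0, 1.0, 2.0, 3.0, 4.0], [(0, 1), (2, 3), (4, 3)])

print(per_graph)             # [0.0, 0.0, 0.0, 6.0, 0.0]
print(batched == per_graph)  # True
```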

Update Attributes: Updating the attributes of the batched graph has no effect on the original graphs.

>>> bg.edata['he'] = th.zeros(3, 2)
>>> g2.edata['he']
tensor([[1.],
        [2.]])

Instead, we can decompose the batched graph back into a list of graphs and use them to replace the original graphs.

>>> g1, g2 = dgl.unbatch(bg)    # returns a list of DGLGraph objects
>>> g2.edata['he']
tensor([[0., 0.],
        [0., 0.]])
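Unbatching is possible because batch_num_nodes and batch_num_edges record where each graph's nodes and edges lie in the batched graph. The plain-Python sketch below splits a batched edge list back into per-graph edge lists with local node ids; unbatch_edges is a hypothetical helper and not DGL's implementation of unbatch().

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def unbatch_edges(edges: List[Edge], batch_num_nodes: List[int],
                  batch_num_edges: List[int]) -> List[List[Edge]]:
    """Split a batched edge list into per-graph edge lists with local node ids."""
    result: List[List[Edge]] = []
    node_offset = 0
    edge_start = 0
    for n_nodes, n_edges in zip(batch_num_nodes, batch_num_edges):
        chunk = edges[edge_start:edge_start + n_edges]
        # Subtract the node offset to recover each graph's original ids.
        result.append([(u - node_offset, v - node_offset) for u, v in chunk])
        node_offset += n_nodes
        edge_start += n_edges
    return result


print(unbatch_edges([(0, 1), (2, 3), (4, 3)], [2, 3], [1, 2]))
# [[(0, 1)], [(0, 1), (2, 1)]]
```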

See also

unbatch()