dgl.batch

dgl.batch(graphs, ndata='__ALL__', edata='__ALL__', *, node_attrs=None, edge_attrs=None)

Batch a collection of DGLGraphs into one graph for more efficient graph computation.

Each input graph becomes one disjoint component of the batched graph. The nodes and edges are relabeled to be disjoint segments:

                    graphs[0]    graphs[1]            ...    graphs[k]
Original node ID    0 ~ N_0      0 ~ N_1              ...    0 ~ N_k
New node ID         0 ~ N_0      N_0+1 ~ N_0+N_1+1    ...    1+sum_{i=0}^{k-1} N_i ~ 1+sum_{i=0}^k N_i

Because of this, many of the computations on a batched graph are the same as if performed on each graph individually, but become much more efficient since they can be parallelized easily. This makes dgl.batch very useful for tasks dealing with many graph samples such as graph classification tasks.
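
The relabeling can be illustrated without DGL. The following is a minimal plain-Python sketch, assuming each graph is given as a node count plus a pair of source/destination ID lists; batch_edges is a hypothetical helper written for this illustration and is not part of DGL or a description of its implementation.

```python
from itertools import accumulate


def batch_edges(num_nodes_per_graph, edges_per_graph):
    """Relabel per-graph edge lists into one disjoint edge list.

    Hypothetical helper for illustration only; it is not part of DGL.
    """
    # The offset of graph i is the total node count of graphs 0..i-1.
    offsets = [0] + list(accumulate(num_nodes_per_graph))[:-1]
    src, dst = [], []
    for offset, (u, v) in zip(offsets, edges_per_graph):
        src.extend(x + offset for x in u)
        dst.extend(x + offset for x in v)
    return src, dst


# Two small graphs: 4 nodes / 3 edges and 3 nodes / 4 edges.
g1_edges = ([0, 1, 2], [1, 2, 3])
g2_edges = ([0, 0, 0, 1], [0, 1, 2, 0])

batched_src, batched_dst = batch_edges([4, 3], [g1_edges, g2_edges])
print(batched_src)  # [0, 1, 2, 4, 4, 4, 5]
print(batched_dst)  # [1, 2, 3, 4, 5, 6, 4]
```

Because the second graph's IDs are shifted by the first graph's node count, no edge crosses between the two components.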

Heterograph inputs must all share the same set of relations (i.e., node types and edge types); the function batches each relation one by one. The result is therefore also a heterograph with the same set of relations as the inputs.
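
As a rough sketch of the per-type bookkeeping, the plain-Python snippet below collects node counts for each node type and rejects inputs whose type sets differ. The function name batch_hetero_counts and the dictionary layout are assumptions made for this illustration; they do not come from DGL.

```python
def batch_hetero_counts(counts_per_graph):
    """Collect per-type node counts across graphs that share the same types.

    Hypothetical helper for illustration only; it is not part of DGL.
    """
    types = set(counts_per_graph[0])
    for counts in counts_per_graph[1:]:
        if set(counts) != types:
            raise ValueError("all inputs must have the same set of node types")
    # One list per node type, with one entry per input graph.
    return {t: [counts[t] for counts in counts_per_graph] for t in sorted(types)}


hg1_counts = {"user": 2, "game": 1}
hg2_counts = {"user": 1, "game": 3}

per_type = batch_hetero_counts([hg1_counts, hg2_counts])
print(per_type)  # {'game': [1, 3], 'user': [2, 1]}
```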

The numbers of nodes and edges of the input graphs are accessible via the DGLGraph.batch_num_nodes() and DGLGraph.batch_num_edges() methods of the resulting graph. For homogeneous graphs, they return 1D integer tensors, with each element being the number of nodes/edges of the corresponding input graph. For heterographs, they return dictionaries of 1D integer tensors, keyed by node type or edge type.
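
These per-graph counts are what make per-graph reductions possible on a batched graph. The sketch below shows the idea in plain Python, assuming the counts and node values are ordinary lists rather than tensors; segment_sum is a hypothetical helper for this illustration and not a DGL function.

```python
from itertools import accumulate


def segment_sum(values, segment_sizes):
    """Sum `values` over consecutive segments of the given sizes.

    Hypothetical helper for illustration only; it is not part of DGL.
    """
    if sum(segment_sizes) != len(values):
        raise ValueError("segment sizes must add up to the number of values")
    ends = list(accumulate(segment_sizes))
    starts = [0] + ends[:-1]
    return [sum(values[s:e]) for s, e in zip(starts, ends)]


# One scalar per node of a batched graph with 4 + 3 nodes.
node_scores = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0]
batch_num_nodes = [4, 3]

per_graph = segment_sum(node_scores, batch_num_nodes)
print(per_graph)  # [10.0, 60.0]
```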

The function supports batching batched graphs. The batch size of the resulting graph is the sum of the batch sizes of all the input graphs.

By default, node/edge features are batched by concatenating the feature tensors of all input graphs, so features with the same name must have the same data type and feature size across graphs. Pass None to the ndata or edata argument to skip feature batching, or pass a list of strings to batch only the named features.
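
The three modes (all features, none, or a named subset) can be sketched in plain Python as follows. This assumes features are stored as lists of rows instead of tensors; batch_features and the ALL sentinel are hypothetical names chosen to mirror the signature above, and the checks are a simplified stand-in for the dtype and size requirement.

```python
ALL = "__ALL__"  # sentinel mirroring the default value in the signature


def batch_features(per_graph_feats, names=ALL):
    """Concatenate per-graph feature rows, optionally for selected names only.

    Hypothetical helper for illustration only; it is not part of DGL.
    """
    if names is None:
        return {}  # feature batching disabled
    if names == ALL:
        names = sorted(per_graph_feats[0])
    batched = {}
    for name in names:
        rows = [row for feats in per_graph_feats for row in feats[name]]
        # Same-named features must agree in row size and element type.
        sizes = {len(row) for row in rows}
        kinds = {type(value) for row in rows for value in row}
        if len(sizes) > 1 or len(kinds) > 1:
            raise ValueError(f"feature {name!r} is inconsistent across graphs")
        batched[name] = rows
    return batched


g1_feats = {"x": [[0.0] * 3 for _ in range(4)], "h": [[1.0] for _ in range(4)]}
g2_feats = {"x": [[1.0] * 3 for _ in range(3)], "h": [[2.0] for _ in range(3)]}

all_feats = batch_features([g1_feats, g2_feats])               # every feature
only_x = batch_features([g1_feats, g2_feats], names=["x"])     # named subset
no_feats = batch_features([g1_feats, g2_feats], names=None)    # nothing
print(sorted(all_feats), sorted(only_x), no_feats)  # ['h', 'x'] ['x'] {}
```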

To unbatch the graph back to a list, use the dgl.unbatch() function.
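
Conceptually, unbatching inverts the relabeling using the stored per-graph counts. The plain-Python sketch below shows that inverse on edge lists; unbatch_edges is a hypothetical helper for this illustration and says nothing about how dgl.unbatch() is implemented.

```python
from itertools import accumulate


def unbatch_edges(src, dst, batch_num_nodes, batch_num_edges):
    """Split a batched edge list back into per-graph edge lists.

    Hypothetical helper for illustration only; it is not part of DGL.
    """
    node_offsets = [0] + list(accumulate(batch_num_nodes))[:-1]
    edge_ends = list(accumulate(batch_num_edges))
    edge_starts = [0] + edge_ends[:-1]
    graphs = []
    for offset, start, end in zip(node_offsets, edge_starts, edge_ends):
        # Slice out this graph's edges and shift IDs back to start at 0.
        u = [x - offset for x in src[start:end]]
        v = [x - offset for x in dst[start:end]]
        graphs.append((u, v))
    return graphs


parts = unbatch_edges(
    [0, 1, 2, 4, 4, 4, 5], [1, 2, 3, 4, 5, 6, 4],
    batch_num_nodes=[4, 3], batch_num_edges=[3, 4],
)
print(parts[0])  # ([0, 1, 2], [1, 2, 3])
print(parts[1])  # ([0, 0, 0, 1], [0, 1, 2, 0])
```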

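Parameters

graphs (list[DGLGraph]): The input graphs.

ndata (list[str], None, optional): The node features to batch. Defaults to all features; pass None to skip node features.

edata (list[str], None, optional): The edge features to batch. Defaults to all features; pass None to skip edge features.

Returns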

Batched graph.

Return type

DGLGraph

Examples

Batch homogeneous graphs

>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> bg = dgl.batch([g1, g2])
>>> bg
Graph(num_nodes=7, num_edges=7,
      ndata_schemes={}
      edata_schemes={})
>>> bg.batch_size
2
>>> bg.batch_num_nodes()
tensor([4, 3])
>>> bg.batch_num_edges()
tensor([3, 4])
>>> bg.edges()
(tensor([0, 1, 2, 4, 4, 4, 5]), tensor([1, 2, 3, 4, 5, 6, 4]))

Batch batched graphs

>>> bbg = dgl.batch([bg, bg])
>>> bbg.batch_size
4
>>> bbg.batch_num_nodes()
tensor([4, 3, 4, 3])
>>> bbg.batch_num_edges()
tensor([3, 4, 3, 4])

Batch graphs with feature data

>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> bg.ndata['x']
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>> bg.edata['w']
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])

Batch heterographs

>>> hg1 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> bhg
Graph(num_nodes={'user': 3, 'game': 4},
      num_edges={('user', 'plays', 'game'): 5},
      metagraph=[('user', 'game', 'plays')])
>>> bhg.batch_size
2
>>> bhg.batch_num_nodes()
{'user': tensor([2, 1]), 'game': tensor([1, 3])}
>>> bhg.batch_num_edges()
{('user', 'plays', 'game'): tensor([2, 3])}

See also

unbatch()