dgl.batch

dgl.batch(graph_list, node_attrs='__ALL__', edge_attrs='__ALL__')

Batch a collection of DGLGraph objects and return a BatchedDGLGraph object that is independent of graph_list, so that one can perform message passing and readout over a batch of graphs simultaneously. The batch size of the returned graph is the length of graph_list.

The nodes and edges are reindexed with new ids in the batched graph according to the rule below:
item      Graph 1     Graph 2                   …    Graph k
raw id    0, …, N1    0, …, N2                  …    0, …, Nk
new id    0, …, N1    N1 + 1, …, N1 + N2 + 1    …    …, N1 + … + Nk + k - 1

Modifying the features of the batched graph has no effect on the original graphs. See the examples below for how to work around this.
Parameters:
    graph_list (iterable) – A collection of DGLGraph objects to be batched.
    node_attrs (None, str or iterable, optional) – The node attributes to be batched. If None, the returned DGLGraph object will not have any node attributes. By default, all node attributes will be batched. If str or iterable, only the specified node attributes are batched.
    edge_attrs (None, str or iterable, optional) – Same as for node_attrs.

Returns: One single batched graph.

Return type: BatchedDGLGraph

Examples
Create two DGLGraph objects. Instantiation:

>>> import dgl
>>> import torch as th
>>> g1 = dgl.DGLGraph()
>>> g1.add_nodes(2)                                # Add 2 nodes
>>> g1.add_edge(0, 1)                              # Add edge 0 -> 1
>>> g1.ndata['hv'] = th.tensor([[0.], [1.]])       # Initialize node features
>>> g1.edata['he'] = th.tensor([[0.]])             # Initialize edge features
>>> g2 = dgl.DGLGraph()
>>> g2.add_nodes(3)                                # Add 3 nodes
>>> g2.add_edges([0, 2], [1, 1])                   # Add edges 0 -> 1 and 2 -> 1
>>> g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
>>> g2.edata['he'] = th.tensor([[1.], [2.]])       # Initialize edge features
Merge the two DGLGraph objects into one DGLGraph object. When merging a list of graphs, we can choose to include only a subset of the attributes.

>>> bg = dgl.batch([g1, g2], edge_attrs=None)
>>> bg.edata
{}
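Conceptually, attribute batching concatenates the per-graph feature arrays for the selected keys, and passing None selects nothing. The following pure-Python sketch illustrates this behavior; the helper name and plain-list features are assumptions for illustration only (real DGL stores tensors):

```python
def batch_attrs(per_graph_attrs, keys=None):
    """Hypothetical helper: concatenate per-graph attributes for the
    chosen keys. keys=None drops all attributes, mirroring the effect
    of edge_attrs=None in the example above."""
    if keys is None:
        return {}
    return {k: [row for g in per_graph_attrs for row in g[k]] for k in keys}

# Node features from the example: g1 has 2 nodes, g2 has 3.
g1_ndata = {'hv': [[0.], [1.]]}
g2_ndata = {'hv': [[2.], [3.], [4.]]}
print(batch_attrs([g1_ndata, g2_ndata], keys=['hv']))
# {'hv': [[0.0], [1.0], [2.0], [3.0], [4.0]]}
print(batch_attrs([g1_ndata, g2_ndata]))  # {}
```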
Below one can see that the nodes are reindexed. The edges are reindexed in the same way.
>>> bg.nodes()
tensor([0, 1, 2, 3, 4])
>>> bg.ndata['hv']
tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.]])
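The reindexing rule can be sketched in plain Python: each graph's node ids are shifted by the total node count of all preceding graphs. This is an illustrative helper, not part of the DGL API:

```python
def batched_node_ids(num_nodes_per_graph):
    """Hypothetical helper: compute the new node ids each graph's nodes
    receive in the batched graph, following the reindexing rule above."""
    new_ids, offset = [], 0
    for n in num_nodes_per_graph:
        new_ids.append(list(range(offset, offset + n)))
        offset += n
    return new_ids

# g1 has 2 nodes and g2 has 3, matching the example above.
print(batched_node_ids([2, 3]))  # [[0, 1], [2, 3, 4]]
```

Edges are shifted the same way, using cumulative edge counts instead of node counts.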
Property: We can still get a brief summary of the graphs that constitute the batched graph.
>>> bg.batch_size
2
>>> bg.batch_num_nodes
[2, 3]
>>> bg.batch_num_edges
[1, 2]
Readout: Another common demand for graph neural networks is graph readout, which is a function that takes in the node attributes and/or edge attributes of a graph and outputs a vector summarizing the information in the graph. DGL also supports performing readout for a batch of graphs at once. Below we take the built-in readout function sum_nodes() as an example, which sums over a particular node attribute for each graph.

>>> dgl.sum_nodes(bg, 'hv')  # Sum the node attribute 'hv' for each graph.
tensor([[1.],    # 0 + 1
        [9.]])   # 2 + 3 + 4
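A batched readout like this is conceptually a segment-sum over the concatenated node features, with segment boundaries given by batch_num_nodes. A minimal pure-Python sketch (illustrative only; the actual implementation operates on tensors):

```python
def segment_sum(node_feats, batch_num_nodes):
    """Hypothetical helper: sum consecutive chunks of node_feats, one
    chunk per graph, mirroring dgl.sum_nodes on a batched graph."""
    sums, i = [], 0
    for n in batch_num_nodes:
        sums.append(sum(node_feats[i:i + n]))
        i += n
    return sums

# Concatenated 'hv' values from the example: [0., 1.] from g1, [2., 3., 4.] from g2.
print(segment_sum([0., 1., 2., 3., 4.], [2, 3]))  # [1.0, 9.0]
```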
Message passing: For message passing and related operations, a batched DGLGraph acts exactly the same as a single DGLGraph with batch size 1.

Update Attributes: Updating the attributes of the batched graph has no effect on the original graphs.
>>> bg.edata['he'] = th.zeros(3, 2)
>>> g2.edata['he']
tensor([[1.],
        [2.]])
Instead, we can decompose the batched graph back into a list of graphs and use them to replace the original graphs.
>>> g1, g2 = dgl.unbatch(bg)  # returns a list of DGLGraph objects
>>> g2.edata['he']
tensor([[0., 0.],
        [0., 0.]])
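Conceptually, unbatching is the inverse of batching: the concatenated features are split back into per-graph chunks using batch_num_nodes (for node attributes) or batch_num_edges (for edge attributes). A sketch under the same illustrative assumptions as before (hypothetical helper, plain lists instead of tensors):

```python
def split_feats(batched_feats, sizes):
    """Hypothetical helper: split a concatenated feature list back into
    per-graph chunks, as dgl.unbatch does for each attribute."""
    chunks, i = [], 0
    for n in sizes:
        chunks.append(batched_feats[i:i + n])
        i += n
    return chunks

# Split the 3 batched edge features by batch_num_edges [1, 2]:
print(split_feats([[0., 0.], [0., 0.], [0., 0.]], [1, 2]))
# [[[0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]
```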
See also

unbatch