dgl.slice_batch

dgl.slice_batch(g, gid, store_ids=False)

Get a particular graph from a batch of graphs.

Parameters:
  • g (DGLGraph) – Input batched graph.

  • gid (int) – The ID of the graph to retrieve, i.e. its zero-based position in the batch.

  • store_ids (bool, optional) – If True, store the raw IDs of the extracted nodes and edges (their IDs in g) in the ndata and edata of the resulting graph under the names dgl.NID and dgl.EID, respectively. Default: False.

Returns:

Retrieved graph.

Return type:

DGLGraph
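With store_ids=True, the stored IDs are the IDs the extracted nodes and edges had in the batched graph g. dgl.batch lays out component graphs one after another, so each component owns a contiguous block of node IDs and a contiguous block of edge IDs. The plain-Python sketch below shows that offset bookkeeping. It assumes the per-graph counts are those reported by g.batch_num_nodes() and g.batch_num_edges(). component_id_range is a hypothetical helper written for illustration, not part of DGL.

```python
from typing import List


def component_id_range(batch_counts: List[int], gid: int) -> range:
    """Return the contiguous block of batched-graph IDs owned by component `gid`.

    `batch_counts[i]` is the number of nodes (or edges) of the i-th graph,
    in the order the graphs were batched. Hypothetical helper, not a DGL API.
    """
    if not 0 <= gid < len(batch_counts):
        raise IndexError(f"gid {gid} out of range for a batch of {len(batch_counts)} graphs")
    # Component `gid` starts after all nodes (or edges) of the earlier components.
    start = sum(batch_counts[:gid])
    return range(start, start + batch_counts[gid])


# Counts for a batch of a 4-node/2-edge graph followed by a 3-node/1-edge graph.
print(list(component_id_range([4, 3], 1)))  # [4, 5, 6]
print(list(component_id_range([2, 1], 1)))  # [2]
```

Under these assumptions, slicing the second graph with store_ids=True would label its nodes with the raw IDs 4, 5, 6 and its single edge with the raw ID 2.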

Examples

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch

Create a batched graph.

>>> g1 = dgl.graph(([0, 1], [2, 3]))
>>> g2 = dgl.graph(([1], [2]))
>>> bg = dgl.batch([g1, g2])

Get the second component graph.

>>> g = dgl.slice_batch(bg, 1)
>>> print(g)
Graph(num_nodes=3, num_edges=1,
      ndata_schemes={}
      edata_schemes={})
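To see where the result above comes from, here is a plain-Python sketch of recovering a component's structure from a batched edge list. Batching shifts g2's edge 1 -> 2 by g1's 4 nodes, so it becomes 5 -> 6. Slicing keeps only the component's block of edges and shifts the endpoints back by the same offset. slice_edges is a hypothetical helper for illustration. It assumes the batched edges are ordered component by component, as dgl.batch produces them. It is not DGL's implementation.

```python
from typing import List, Tuple


def slice_edges(src: List[int], dst: List[int],
                batch_num_nodes: List[int], batch_num_edges: List[int],
                gid: int) -> Tuple[List[int], List[int], int]:
    """Return (local_src, local_dst, num_nodes) for component `gid`.

    Hypothetical illustration, not a DGL API.
    """
    node_start = sum(batch_num_nodes[:gid])
    edge_start = sum(batch_num_edges[:gid])
    edge_end = edge_start + batch_num_edges[gid]
    # Keep only this component's edges, then map endpoints back to local node IDs.
    local_src = [u - node_start for u in src[edge_start:edge_end]]
    local_dst = [v - node_start for v in dst[edge_start:edge_end]]
    return local_src, local_dst, batch_num_nodes[gid]


# Batched edges of g1 (0->2, 1->3) followed by g2 (1->2 shifted by 4 nodes).
src, dst = [0, 1, 5], [2, 3, 6]
print(slice_edges(src, dst, [4, 3], [2, 1], 1))  # ([1], [2], 3)
```

The recovered component has 3 nodes and the single edge 1 -> 2, which matches num_nodes=3, num_edges=1 in the printed graph.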