dgl.slice_batch

dgl.slice_batch(g, gid, store_ids=False)

Get a particular graph from a batch of graphs.

Parameters
  • g (DGLGraph) – Input batched graph.

  • gid (int) – The index of the graph to retrieve, counting from 0 in the order the graphs were batched.

  • store_ids (bool, optional) – If True, store the IDs that the extracted nodes and edges had in the input batched graph in the ndata and edata of the resulting graph, under the keys dgl.NID and dgl.EID, respectively. Default: False.
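dgl.batch numbers the nodes and edges of each graph consecutively after those of the preceding graphs. As a result, the IDs recorded with store_ids=True form one contiguous range per component, and that range can be computed from prefix sums of the per-graph counts (for a batched graph, the counts come from bg.batch_num_nodes() and bg.batch_num_edges()). The plain-Python sketch below illustrates this arithmetic. It uses a hypothetical helper, component_id_ranges, and assumes a homogeneous graph with one node type and one edge type. It is not DGL's implementation.

```python
from itertools import accumulate

def component_id_ranges(batch_num_nodes, batch_num_edges, gid):
    """Return the (start, stop) node and edge ID ranges of component `gid`
    inside a batched graph, given the per-graph node and edge counts.

    Hypothetical helper: illustrates the offset arithmetic of batching,
    not DGL's internal code. Assumes a homogeneous graph.
    """
    if not 0 <= gid < len(batch_num_nodes):
        raise IndexError(f"gid {gid} out of range for a batch of {len(batch_num_nodes)} graphs")
    # Prefix sums give the first ID of every component: [0, n0, n0 + n1, ...]
    node_starts = [0, *accumulate(batch_num_nodes)]
    edge_starts = [0, *accumulate(batch_num_edges)]
    node_range = (node_starts[gid], node_starts[gid + 1])
    edge_range = (edge_starts[gid], edge_starts[gid + 1])
    return node_range, edge_range

# Counts for the batch built in the Examples section:
# g1 has 4 nodes and 2 edges, g2 has 3 nodes and 1 edge.
node_range, edge_range = component_id_ranges([4, 3], [2, 1], gid=1)
print(list(range(*node_range)))  # [4, 5, 6]
print(list(range(*edge_range)))  # [2]
```

Under these assumptions, the node and edge IDs of the second component in the batched graph are 4 to 6 and 2, respectively.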

Returns

Retrieved graph.

Return type

DGLGraph

Examples

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch

Create a batched graph.

>>> g1 = dgl.graph(([0, 1], [2, 3]))
>>> g2 = dgl.graph(([1], [2]))
>>> bg = dgl.batch([g1, g2])

Get the second component graph.

>>> g = dgl.slice_batch(bg, 1)
>>> print(g)
Graph(num_nodes=3, num_edges=1,
      ndata_schemes={}
      edata_schemes={})
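The structural part of slicing can also be sketched in plain Python. The sketch assumes a homogeneous batched graph stored as parallel src/dst lists whose edges are grouped by component in batch order, which is how dgl.batch offsets them. Under that assumption, a hypothetical helper, slice_edges, keeps the component's edge range and subtracts the component's node offset so that its node IDs start again at 0. Applied to the batch from the example above, it recovers the three-node, one-edge graph that dgl.slice_batch(bg, 1) returns.

```python
from itertools import accumulate

def slice_edges(src, dst, batch_num_nodes, batch_num_edges, gid):
    """Extract component `gid` from a batched edge list.

    Hypothetical helper for illustration, not DGL's implementation.
    Assumes edges are grouped by component in batch order and that each
    component's node IDs were offset by the node counts of earlier graphs.
    """
    node_starts = [0, *accumulate(batch_num_nodes)]
    edge_starts = [0, *accumulate(batch_num_edges)]
    node_offset = node_starts[gid]
    e_start, e_stop = edge_starts[gid], edge_starts[gid + 1]
    # Keep only this component's edges and shift node IDs back to start at 0.
    sub_src = [u - node_offset for u in src[e_start:e_stop]]
    sub_dst = [v - node_offset for v in dst[e_start:e_stop]]
    return sub_src, sub_dst, batch_num_nodes[gid]

# Batched edge list of bg: g1 contributes 0->2 and 1->3; g2's edge 1->2
# is offset by g1's 4 nodes and becomes 5->6.
src, dst = [0, 1, 5], [2, 3, 6]
print(slice_edges(src, dst, [4, 3], [2, 1], gid=1))  # ([1], [2], 3)
```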