dgl.out_subgraph

dgl.out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True)

Return the subgraph induced on the outbound edges of the given nodes, across all edge types.

An out subgraph is equivalent to creating a new graph using the outgoing edges of the given nodes. In addition to extracting the subgraph, DGL also copies the features of the extracted nodes and edges to the resulting graph. The copy is lazy and incurs data movement only when needed.
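The selection rule described above can be illustrated without DGL. The following is a minimal pure-Python sketch, not DGL's implementation: the helper name `out_edges_sketch` is hypothetical, and the choice to return edges grouped in the order of the given nodes is an assumption made so that the result lines up with the homogeneous example in the Examples section.

```python
def out_edges_sketch(src, dst, nodes):
    """Keep every edge whose source node is in `nodes` (hypothetical helper).

    `src` and `dst` are parallel lists; the position of an edge in them is
    treated as its edge ID. Returns (sub_src, sub_dst, original_edge_ids).
    """
    # Index the outgoing edge IDs of every source node.
    out_by_node = {}
    for eid, u in enumerate(src):
        out_by_node.setdefault(u, []).append(eid)

    # Visit the requested nodes in the given order, ignoring duplicates.
    kept = [eid for u in dict.fromkeys(nodes) for eid in out_by_node.get(u, [])]
    return [src[e] for e in kept], [dst[e] for e in kept], kept


# The 5-node cycle 0 -> 1 -> 2 -> 3 -> 4 -> 0.
src = [0, 1, 2, 3, 4]
dst = [1, 2, 3, 4, 0]
sub_src, sub_dst, eids = out_edges_sketch(src, dst, [2, 0])
print(sub_src, sub_dst, eids)  # [2, 0] [3, 1] [2, 0]
```

The third return value plays the role of the original edge IDs that DGL stores under dgl.EID when store_ids is True.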

If the graph is heterogeneous, DGL extracts a subgraph per relation and composes them as the resulting graph. Thus, the resulting graph has the same set of relations as the input one.
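As a rough illustration of the per-relation behavior, the sketch below applies the same source-node filter to each relation independently. It is plain Python under stated assumptions, not DGL code: `hetero_out_edges_sketch` is a hypothetical helper, and relations are modeled as a dictionary keyed by (source type, edge type, destination type) tuples.

```python
def hetero_out_edges_sketch(relations, nodes_by_type):
    """Filter each relation by its source node type (hypothetical helper).

    `relations` maps (src_type, etype, dst_type) -> (src_ids, dst_ids).
    `nodes_by_type` maps node type -> iterable of node IDs.
    """
    result = {}
    for key, (src, dst) in relations.items():
        src_type = key[0]
        wanted = set(nodes_by_type.get(src_type, []))
        kept = [e for e, u in enumerate(src) if u in wanted]
        # Every relation stays in the result, even if no edge survives.
        result[key] = ([src[e] for e in kept], [dst[e] for e in kept])
    return result


relations = {
    ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
    ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]),
}
sub = hetero_out_edges_sketch(relations, {'user': [1]})
for key, (s, d) in sub.items():
    print(key, s, d)
```

Because every key of the input is written to the output, the result has the same set of relations as the input, mirroring the statement above.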

Parameters
  • graph (DGLGraph) – The input graph.

  • nodes (nodes or dict[str, nodes]) –

    The nodes to form the subgraph. The allowed nodes formats are:

    • Int Tensor: Each element is a node ID. The tensor must have the same device type and ID data type as the graph’s.

    • iterable[int]: Each element is a node ID.

    If the graph is homogeneous, one can directly pass the above formats. Otherwise, the argument must be a dictionary with keys being node types and values being the node IDs in the above formats.

  • relabel_nodes (bool, optional) – If True, remove the isolated nodes and relabel the remaining nodes in the extracted subgraph.

  • store_ids (bool, optional) – If True, store the raw IDs of the extracted edges in the edata of the resulting graph under the name dgl.EID; if relabel_nodes is also True, store the raw IDs of the extracted nodes in the ndata of the resulting graph under the name dgl.NID.
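To make the effect of relabel_nodes more concrete, here is a minimal pure-Python sketch of dropping isolated nodes and compacting the remaining IDs. It is not DGL's implementation: `relabel_sketch` is a hypothetical helper, the input edges are made up for illustration, and assigning new IDs in ascending order of the original IDs is an assumption.

```python
def relabel_sketch(sub_src, sub_dst):
    """Compact node IDs so only edge endpoints remain (hypothetical helper).

    Returns (new_src, new_dst, original_ids), where original_ids[i] is the
    original ID of the node now labeled i.
    """
    # Nodes that touch no extracted edge are isolated and are dropped.
    original_ids = sorted(set(sub_src) | set(sub_dst))
    new_id = {old: new for new, old in enumerate(original_ids)}
    new_src = [new_id[u] for u in sub_src]
    new_dst = [new_id[v] for v in sub_dst]
    return new_src, new_dst, original_ids


# Illustrative extracted edges 3 -> 4 and 1 -> 2; node 0 is isolated.
new_src, new_dst, nid = relabel_sketch([3, 1], [4, 2])
print(new_src, new_dst)  # [2, 0] [3, 1]
print(nid)               # [1, 2, 3, 4]
```

The `original_ids` list corresponds to what store_ids keeps under dgl.NID, so a relabeled node can be mapped back to its ID in the input graph.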

Returns

The subgraph.

Return type

DGLGraph

Notes

This function discards the batch information. Please use dgl.DGLGraph.set_batch_num_nodes() and dgl.DGLGraph.set_batch_num_edges() on the transformed graph to maintain the information.
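One possible way to recover per-graph edge counts is sketched below in plain Python. It assumes that the input was a batched graph whose edges are numbered consecutively per component graph, and that the original edge IDs of the subgraph are available (for example from the edata stored under dgl.EID). The helper name `batch_num_edges_sketch` is hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate


def batch_num_edges_sketch(orig_batch_num_edges, kept_eids):
    """Count kept edges per component graph (hypothetical helper).

    `orig_batch_num_edges` lists the edge count of each component graph in
    the original batch; `kept_eids` are original edge IDs in the subgraph.
    """
    # Exclusive end offset of each component graph's edge ID range.
    ends = list(accumulate(orig_batch_num_edges))
    counts = [0] * len(orig_batch_num_edges)
    for eid in kept_eids:
        # The first range whose end offset exceeds eid owns the edge.
        counts[bisect_right(ends, eid)] += 1
    return counts


# Two component graphs with 3 and 2 edges; edges 2, 0 and 4 were kept.
counts = batch_num_edges_sketch([3, 2], [2, 0, 4])
print(counts)  # [2, 1]
```

A list like `counts`, converted to a tensor on the graph's device, is the kind of value one would pass to dgl.DGLGraph.set_batch_num_edges(). With relabel_nodes=False the node set is unchanged, so the original per-graph node counts can be reused for dgl.DGLGraph.set_batch_num_nodes().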

Examples

The following examples use the PyTorch backend.

>>> import dgl
>>> import torch

Extract a subgraph from a homogeneous graph.

>>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle
>>> g.edata['w'] = torch.arange(10).view(5, 2)
>>> sg = dgl.out_subgraph(g, [2, 0])
>>> sg
Graph(num_nodes=5, num_edges=2,
      ndata_schemes={}
      edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
                     '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([2, 0]), tensor([3, 1]))
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([2, 0])
>>> sg.edata['w']  # also extract the features
tensor([[4, 5],
        [0, 1]])

Extract a subgraph with node relabeling.

>>> sg = dgl.out_subgraph(g, [2, 0], relabel_nodes=True)
>>> sg
Graph(num_nodes=4, num_edges=2,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
                     '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([2, 0]), tensor([3, 1]))
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([2, 0])
>>> sg.ndata[dgl.NID]  # original node IDs
tensor([0, 1, 2, 3])

Extract a subgraph from a heterogeneous graph.

>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
...     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})
>>> sub_g = g.out_subgraph({'user': [1]})
>>> sub_g
Graph(num_nodes={'game': 3, 'user': 3},
      num_edges={('user', 'plays', 'game'): 2, ('user', 'follows', 'user'): 2},
      metagraph=[('user', 'game', 'plays'), ('user', 'user', 'follows')])

See also

in_subgraph()