dgl.DGLHeteroGraph.filter_nodes

DGLHeteroGraph.filter_nodes(predicate, nodes='__ALL__', ntype=None)

Return the IDs of the nodes with the given node type that satisfy the given predicate.

Parameters
  • predicate (callable) – A function with signature func(nodes) -> Tensor, where nodes is a dgl.NodeBatch object. It should return a 1D boolean tensor with one element per node in the batch, indicating whether that node satisfies the predicate.

  • nodes (node ID(s), optional) –

    The node(s) for query. The allowed formats are:

    • Tensor: A 1D tensor containing the node ID(s) to query. Its data type and device should match the idtype and device of the graph.

    • iterable[int]: Like the tensor format, but the node IDs are stored in a sequence (e.g. list, tuple, numpy.ndarray).

    By default, it considers all nodes.

  • ntype (str, optional) – The node type for query. If the graph has multiple node types, one must specify the argument. Otherwise, it can be omitted.
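The way nodes and ntype select which nodes get tested can be modeled in plain Python. This is a conceptual sketch, not DGL's implementation: the helper resolve_query_nodes and the num_nodes mapping (node type to node count) are hypothetical, and lists stand in for ID tensors.

```python
from typing import Dict, Iterable, List, Optional, Union

ALL = "__ALL__"  # the documented default value of `nodes`

def resolve_query_nodes(
    num_nodes: Dict[str, int],
    nodes: Union[str, Iterable[int]] = ALL,
    ntype: Optional[str] = None,
) -> List[int]:
    """Hypothetical helper: the node IDs a filter_nodes call would test."""
    if ntype is None:
        # ntype may only be omitted when the graph has a single node type.
        if len(num_nodes) != 1:
            raise ValueError("ntype must be specified when there are multiple node types")
        (ntype,) = num_nodes.keys()
    if isinstance(nodes, str) and nodes == ALL:
        # Default: every node of the chosen type, i.e. IDs 0 .. n-1.
        return list(range(num_nodes[ntype]))
    # Otherwise: any iterable of integer IDs (list, tuple, ...), kept in order.
    return [int(v) for v in nodes]

print(resolve_query_nodes({"node": 4}))                           # [0, 1, 2, 3]
print(resolve_query_nodes({"node": 4}, nodes=(0, 1)))             # [0, 1]
print(resolve_query_nodes({"user": 3, "game": 2}, ntype="user"))  # [0, 1, 2]
```

The sketch does not check that explicit IDs are in range or match the graph's idtype; it only illustrates how the defaults resolve.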

Returns

A 1D tensor that contains the ID(s) of the node(s) that satisfy the predicate.

Return type

Tensor
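The returned IDs come from keeping the queried node IDs whose mask entry is True, so the result holds graph node IDs rather than positions within the batch. A minimal plain-Python sketch of that step, using a hypothetical helper mask_to_ids with lists standing in for tensors:

```python
from typing import List, Sequence

def mask_to_ids(queried: Sequence[int], mask: Sequence[bool]) -> List[int]:
    """Hypothetical helper: keep the queried node IDs whose mask entry is True."""
    # The predicate must produce exactly one boolean per queried node.
    if len(mask) != len(queried):
        raise ValueError("mask length must equal the number of queried nodes")
    return [nid for nid, keep in zip(queried, mask) if keep]

# Querying nodes 3, 0 and 2: the result holds their IDs, not batch positions 0 and 2.
print(mask_to_ids([3, 0, 2], [True, False, True]))  # [3, 2]
```

With PyTorch tensors, the same step is boolean indexing, queried[mask].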

Examples

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch

Define a predicate function.

>>> def nodes_with_feature_one(nodes):
...     # 'h' has shape (N, 1); compare with 1 and squeeze to a 1D mask of shape (N,)
...     return (nodes.data['h'] == 1.).squeeze(1)

Filter nodes for a homogeneous graph.

>>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
>>> g.ndata['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> print(g.filter_nodes(nodes_with_feature_one))
tensor([1, 2])

Filter only among the nodes with IDs 0 and 1.

>>> print(g.filter_nodes(nodes_with_feature_one, nodes=torch.tensor([0, 1])))
tensor([1])
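Passing nodes restricts which nodes the predicate is evaluated on. For an element-wise predicate such as nodes_with_feature_one, this gives the same IDs as filtering all nodes and keeping only those in the query. A predicate that looks at the batch as a whole (for example, comparing against the batch mean) can give a different answer. The pure-Python sketch below checks both cases; run_filter, equals_one, above_batch_mean and the second set of feature values are hypothetical, and it assumes the predicate receives all queried nodes as a single batch.

```python
from typing import Callable, List, Sequence

Predicate = Callable[[List[float]], List[bool]]

def run_filter(h: Sequence[float], predicate: Predicate, nodes: Sequence[int]) -> List[int]:
    """Hypothetical stand-in for filter_nodes: one predicate call over the whole batch."""
    mask = predicate([h[i] for i in nodes])
    return [nid for nid, keep in zip(nodes, mask) if keep]

def equals_one(batch: List[float]) -> List[bool]:
    # Element-wise: each node's answer depends only on its own feature.
    return [x == 1.0 for x in batch]

def above_batch_mean(batch: List[float]) -> List[bool]:
    # Batch-dependent: the threshold changes with the set of queried nodes.
    if not batch:
        return []
    mean = sum(batch) / len(batch)
    return [x > mean for x in batch]

all_ids = [0, 1, 2, 3]
subset = [0, 1]

h = [0.0, 1.0, 1.0, 0.0]  # same values as g.ndata['h'] above
full = run_filter(h, equals_one, all_ids)          # [1, 2]
print(run_filter(h, equals_one, subset))           # [1]
print([v for v in subset if v in full])            # [1]

h2 = [1.0, 2.0, 3.0, 4.0]
full2 = run_filter(h2, above_batch_mean, all_ids)  # [2, 3]
print(run_filter(h2, above_batch_mean, subset))    # [1]
print([v for v in subset if v in full2])           # []
```

Keeping predicates element-wise makes the result independent of which other nodes are queried alongside a given node.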

Filter nodes for a heterogeneous graph.

>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),
...                                 torch.tensor([0, 0, 1, 1]))})
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[0.], [1.]])
>>> # Filter for 'user' nodes
>>> print(g.filter_nodes(nodes_with_feature_one, ntype='user'))
tensor([1, 2])