dgl.DGLHeteroGraph.filter_nodes

DGLHeteroGraph.filter_nodes(predicate, nodes='__ALL__', ntype=None)

Return a tensor of node IDs with the given node type that satisfy the given predicate.

Parameters:
  • predicate (callable) – A function with signature func(nodes) -> tensor, where nodes is a NodeBatch object, as in user-defined functions (UDFs). It should return a 1-D boolean tensor in which each element indicates whether the corresponding node in the batch satisfies the predicate.
  • nodes (int, iterable or tensor of ints) – The nodes to filter on. (Default: all nodes of the given type)
  • ntype (str, optional) – The node type. Can be omitted if there is only one node type in the graph. (Default: None)
Returns:

IDs of the nodes that satisfy the predicate.

Return type:

tensor

Examples

>>> import torch
>>> import dgl
>>> g = dgl.graph([], 'user', 'follows', num_nodes=4)
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
tensor([1, 2])
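To make the roles of predicate and nodes concrete, here is a framework-free sketch of the behaviour the parameters describe. It is not DGL's implementation. It assumes plain Python lists in place of tensors, and NodeBatchSketch and filter_nodes_sketch are hypothetical names. The predicate sees the features of only the requested nodes and returns one boolean per node. The surviving positions are then mapped back to node IDs, so filtering a subset returns the original IDs rather than positions within the subset.

```python
from typing import Callable, Dict, List, Optional, Sequence


class NodeBatchSketch:
    """Hypothetical stand-in for DGL's NodeBatch: just node IDs and their features."""

    def __init__(self, ids: List[int], data: Dict[str, list]):
        self.ids = ids    # node IDs in this batch
        self.data = data  # feature name -> per-node values, aligned with ids


def filter_nodes_sketch(
    node_data: Dict[str, list],
    num_nodes: int,
    predicate: Callable[[NodeBatchSketch], Sequence[bool]],
    nodes: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical sketch: return the IDs of nodes for which predicate is True."""
    # Default to every node, mirroring nodes='__ALL__'.
    ids = list(range(num_nodes)) if nodes is None else list(nodes)
    # Gather the features of only the requested nodes into one batch.
    batch = NodeBatchSketch(ids, {k: [v[i] for i in ids] for k, v in node_data.items()})
    mask = list(predicate(batch))
    if len(mask) != len(ids):
        raise ValueError("predicate must return one boolean per node in the batch")
    # Map mask positions back to node IDs (not positions in the batch).
    return [nid for nid, keep in zip(ids, mask) if keep]


def is_one(nb: NodeBatchSketch) -> List[bool]:
    # Same test as the example above: feature 'h' equals 1.
    return [x == 1.0 for x in nb.data["h"]]


# Same data as the example above: feature 'h' for 4 'user' nodes.
h = {"h": [0.0, 1.0, 1.0, 0.0]}
print(filter_nodes_sketch(h, 4, is_one))                # [1, 2]
print(filter_nodes_sketch(h, 4, is_one, nodes=[2, 3]))  # [2]
```

The second call shows the subset case: only nodes 2 and 3 are tested, and the result is [2], the original ID of the one node that passes.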