# dgl.DGLHeteroGraph.filter_nodes

DGLHeteroGraph.filter_nodes(predicate, nodes='__ALL__', ntype=None)

Return a tensor of node IDs with the given node type that satisfy the given predicate.

Parameters:

- predicate (callable) – A function with signature `func(nodes) -> tensor`, where `nodes` is a `NodeBatch` object, as in a user-defined function (UDF). The returned tensor must be a 1-D boolean tensor in which each element indicates whether the corresponding node in the batch satisfies the predicate.
- nodes (int, iterable or tensor of ints) – The nodes to filter on. Default: all nodes.
- ntype (str, optional) – The node type. Can be omitted if the graph has only one node type. (Default: None)

Returns: The IDs of the nodes that satisfy the predicate.

Return type: tensor
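To make the predicate contract concrete, here is a minimal pure-Python sketch of the filtering semantics described above. It is not DGL's implementation: `ToyNodeBatch` and `filter_nodes_sketch` are hypothetical names, plain Python lists stand in for tensors, scalar features stand in for feature rows, and `nodes=None` stands in for DGL's `'__ALL__'` default. It assumes the feature dictionary holds at least one feature.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Sequence


@dataclass
class ToyNodeBatch:
    """Hypothetical stand-in for DGL's NodeBatch: node IDs plus their features."""
    ids: List[int]
    data: Dict[str, List[float]]


def filter_nodes_sketch(
    features: Dict[str, List[float]],
    predicate: Callable[[ToyNodeBatch], Sequence[bool]],
    nodes: Optional[Iterable[int]] = None,
) -> List[int]:
    """Return the IDs in `nodes` (default: all nodes) for which `predicate` is True."""
    # Assumption: every feature list has one entry per node in the graph.
    num_nodes = len(next(iter(features.values())))
    ids = list(range(num_nodes)) if nodes is None else list(nodes)
    # Build a batch containing only the requested nodes' features, in order.
    batch = ToyNodeBatch(
        ids=ids,
        data={name: [values[i] for i in ids] for name, values in features.items()},
    )
    mask = list(predicate(batch))
    # The predicate must yield exactly one boolean per node in the batch.
    if len(mask) != len(ids):
        raise ValueError("predicate must return one boolean per node in the batch")
    # Keep the IDs whose mask entry is True.
    return [nid for nid, keep in zip(ids, mask) if keep]


h = [0.0, 1.0, 1.0, 0.0]
is_one = lambda nb: [x == 1.0 for x in nb.data["h"]]
print(filter_nodes_sketch({"h": h}, is_one))                   # [1, 2]
print(filter_nodes_sketch({"h": h}, is_one, nodes=[0, 2, 3]))  # [2]
```

The second call shows how restricting `nodes` narrows the batch the predicate sees: only nodes 0, 2 and 3 are tested, so node 1 cannot appear in the result even though it satisfies the predicate.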

Examples

```python
>>> import torch
>>> import dgl
>>> g = dgl.graph([], 'user', 'follows', card=4)
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> g.filter_nodes(lambda nodes: (nodes.data['h'] == 1.).squeeze(1), ntype='user')
tensor([1, 2])
```