# dgl.DGLGraph.filter_nodes

DGLGraph.filter_nodes(predicate, nodes='__ALL__')

Return a tensor of node IDs that satisfy the given predicate.

Parameters:
- predicate (callable) – A function with signature `func(nodes) -> tensor`, where `nodes` is a `NodeBatch` object, as in user-defined functions (UDFs). The returned tensor should be a 1-D boolean tensor in which each element indicates whether the corresponding node in the batch satisfies the predicate.
- nodes (int, iterable or tensor of ints) – The nodes to filter on. Default: all the nodes.

Returns: The filtered nodes.

Return type: tensor
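Conceptually, the predicate is evaluated once on a batch of nodes to produce a boolean mask, and the IDs where the mask is `True` are returned. When `nodes` is given, only those nodes are batched, and the result contains their original IDs rather than their positions within the batch. The plain-Python sketch below models that contract. `filter_node_ids` and the dict-of-lists "batch" are hypothetical stand-ins for illustration, not DGL API, and the sketch assumes at least one feature is stored.

```python
from typing import Callable, Dict, List, Optional, Sequence


def filter_node_ids(
    features: Dict[str, List[float]],
    predicate: Callable[[Dict[str, List[float]]], List[bool]],
    nodes: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical model of filter_nodes; features maps a name to one value per node."""
    # Assumes at least one feature is stored, so the node count can be read from it.
    num_nodes = len(next(iter(features.values())))
    # nodes=None plays the role of nodes='__ALL__': consider every node ID.
    ids = list(range(num_nodes)) if nodes is None else list(nodes)
    # Stand-in for NodeBatch.data: only the features of the batched nodes.
    batch = {name: [vals[i] for i in ids] for name, vals in features.items()}
    mask = predicate(batch)
    if len(mask) != len(ids):
        raise ValueError("predicate must return one boolean per node in the batch")
    # Keep the original node IDs (not batch positions) where the mask is True.
    return [i for i, keep in zip(ids, mask) if keep]


def has_feature_one(batch: Dict[str, List[float]]) -> List[bool]:
    # One boolean per batched node: does its 'x' feature equal 1?
    return [v == 1.0 for v in batch['x']]


feats = {'x': [1.0, -1.0, 1.0]}
print(filter_node_ids(feats, has_feature_one))                # [0, 2]
print(filter_node_ids(feats, has_feature_one, nodes=[1, 2]))  # [2]
```

Restricting to nodes `[1, 2]` yields `[2]`, not `[1]`: the mask `[False, True]` selects from the given IDs themselves.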

Examples

Construct a graph object for the demo.

Note

Here we use PyTorch syntax for the demo. The general idea applies to other frameworks with minor syntax changes (e.g., replace `torch.tensor` with `mxnet.ndarray`).

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [-1.], [1.]])


Define a function that selects the nodes whose feature `x` equals 1.

>>> def has_feature_one(nodes):
...     # (N, 1) boolean comparison, squeezed to the required 1-D mask of shape (N,)
...     return (nodes.data['x'] == 1).squeeze(1)
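The `.squeeze(1)` call matters because `x` has shape `(3, 1)`. Comparing it with 1 keeps that shape, whereas the predicate must return a 1-D mask with one entry per node. The sketch below mimics the two steps with nested Python lists. It is only an analogue of the tensor operations, not the PyTorch code path itself.

```python
# Feature column as in the demo: shape (3, 1), one single-element row per node.
x = [[1.0], [-1.0], [1.0]]

# Elementwise comparison keeps the (3, 1) shape: [[True], [False], [True]].
mask_2d = [[v == 1.0 for v in row] for row in x]

# Squeezing dimension 1 drops the singleton axis, giving a 1-D mask of length 3.
mask_1d = [row[0] for row in mask_2d]

print(mask_2d)  # [[True], [False], [True]]
print(mask_1d)  # [True, False, True]
```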


Filter the nodes whose feature `x` equals 1.

>>> g.filter_nodes(has_feature_one)
tensor([0, 2])