# dgl.DGLGraph.filter_edges

DGLGraph.filter_edges(predicate, edges='__ALL__')

Return a tensor of edge IDs that satisfy the given predicate.

Parameters:

- predicate (callable) – A function of signature func(edges) -> tensor, where edges is an EdgeBatch object, as in a user-defined function (UDF). It must return a 1-D boolean tensor in which each element indicates whether the corresponding edge in the batch satisfies the predicate.
- edges (valid edges type) – The edges on which to apply predicate. See send() for the valid edges types. Defaults to all edges.

Returns: The IDs of the edges that satisfy the predicate.

Return type: tensor
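Informally, the result can be restated as follows. This is only a restatement of the description above, not DGL's implementation: write $$E'$$ for the edges being examined (all edge IDs $$0, \dots, |E|-1$$ when edges='__ALL__', otherwise the given edge IDs) and $$m = p(E')$$ for the boolean mask returned by the predicate $$p$$.

$$
m = p(E') \in \{\text{True}, \text{False}\}^{|E'|}, \qquad \text{result} = \{\, e_i \in E' : m_i = \text{True} \,\}
$$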

Examples

Construct a graph object for the demo.

Note

Here we use PyTorch syntax for the demo. The general idea applies to other frameworks with minor syntax changes (e.g. replace torch.tensor with mxnet.ndarray).

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [-1.], [1.]])
>>> g.add_edges([0, 1, 2], [2, 2, 1])


Define a predicate that selects edges whose destination node has feature value 1.

>>> def has_dst_one(edges):
...     # edges.dst['x'] has shape (E, 1); squeeze it to a 1-D boolean mask of length E
...     return (edges.dst['x'] == 1).squeeze(1)


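Filter the edges whose destination nodes have feature value 1. Edges 0 (0 -> 2) and 1 (1 -> 2) qualify; edge 2 (2 -> 1) does not, because node 1 has feature -1.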

>>> g.filter_edges(has_dst_one)
tensor([0, 1])
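To make the semantics concrete without installing DGL, here is a minimal plain-Python sketch of what filter_edges computes on the same toy graph. ToyEdgeBatch and filter_edges_sketch are hypothetical names invented for illustration; node features are stored as scalars rather than (N, 1) tensors, and the sketch assumes that, when edges is an explicit list of IDs, the result contains those original edge IDs rather than positions within the batch.

```python
# Plain-Python model of filter_edges semantics (no DGL required).
# ToyEdgeBatch and filter_edges_sketch are HYPOTHETICAL illustration names,
# not part of DGL; they mimic the documented behaviour, not DGL's internals.
from typing import Callable, Dict, List, Optional


class ToyEdgeBatch:
    """Minimal stand-in for an EdgeBatch: exposes src/dst node features."""

    def __init__(self, src_feats: Dict[str, list], dst_feats: Dict[str, list]):
        self.src = src_feats
        self.dst = dst_feats


def filter_edges_sketch(
    src: List[int],
    dst: List[int],
    ndata: Dict[str, list],
    predicate: Callable[[ToyEdgeBatch], List[bool]],
    edges: Optional[List[int]] = None,
) -> List[int]:
    # edges=None plays the role of '__ALL__': examine every edge ID.
    eids = list(range(len(src))) if edges is None else list(edges)
    # Gather source and destination node features for each edge in the batch.
    src_feats = {k: [v[src[e]] for e in eids] for k, v in ndata.items()}
    dst_feats = {k: [v[dst[e]] for e in eids] for k, v in ndata.items()}
    mask = predicate(ToyEdgeBatch(src_feats, dst_feats))
    if len(mask) != len(eids):
        raise ValueError("predicate must return one boolean per edge in the batch")
    # Keep the original edge IDs (not batch positions) where the mask is True.
    return [e for e, keep in zip(eids, mask) if keep]


# Same toy graph as the doctest: edges 0->2, 1->2, 2->1.
src, dst = [0, 1, 2], [2, 2, 1]
ndata = {"x": [1.0, -1.0, 1.0]}  # one scalar feature per node


def has_dst_one(edges: ToyEdgeBatch) -> List[bool]:
    return [x == 1 for x in edges.dst["x"]]


print(filter_edges_sketch(src, dst, ndata, has_dst_one))                # [0, 1]
print(filter_edges_sketch(src, dst, ndata, has_dst_one, edges=[1, 2]))  # [1]
```

In the second call only edges 1 and 2 are examined; edge 1 points at node 2 (feature 1) and edge 2 points at node 1 (feature -1), so only ID 1 is returned.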