dgl.DGLHeteroGraph.filter_edges

DGLHeteroGraph.filter_edges(predicate, edges='__ALL__', etype=None)

Return a tensor of edge IDs with the given edge type that satisfy the given predicate.

Parameters:
  • predicate (callable) – A function with signature func(edges) -> tensor, where edges is an EdgeBatch object as described in the user-defined function (UDF) documentation. It must return a 1-D boolean tensor with one element per edge in the batch, each element indicating whether the corresponding edge satisfies the predicate.
  • edges (valid edges type) – The edges on which to evaluate predicate. See send() for the valid edges types. Default: all the edges.
  • etype (str, optional) – The edge type. Can be omitted if there is only one edge type in the graph. (Default: None)
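The filtering contract described by these parameters (evaluate the predicate once on a batch of edges, then keep the IDs whose mask entry is True) can be sketched in plain Python. This is a hypothetical model, not DGL's implementation: EdgeBatchSketch and filter_edges_sketch are illustrative names, Python lists stand in for tensors, and the sketch assumes that when edges is given, the returned IDs are drawn from that subset.

```python
from typing import Callable, Dict, List, Optional, Sequence


class EdgeBatchSketch:
    """Hypothetical stand-in for DGL's EdgeBatch: exposes per-edge features via .data."""

    def __init__(self, data: Dict[str, List[float]]):
        # Maps feature name -> one value per edge in the batch, in batch order.
        self.data = data


def filter_edges_sketch(
    predicate: Callable[[EdgeBatchSketch], List[bool]],
    edge_data: Dict[str, List[float]],
    edges: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical sketch: return the IDs of edges for which predicate is True."""
    num_edges = len(next(iter(edge_data.values()), []))
    # edges=None plays the role of edges='__ALL__': consider every edge.
    eids = list(range(num_edges)) if edges is None else list(edges)
    # Gather the features of the selected edges into a single batch.
    batch = EdgeBatchSketch({k: [v[e] for e in eids] for k, v in edge_data.items()})
    mask = predicate(batch)
    # The predicate must return exactly one boolean per edge in the batch.
    if len(mask) != len(eids):
        raise ValueError("predicate must return one boolean per edge")
    # Keep only the IDs whose mask entry is True, preserving batch order.
    return [e for e, keep in zip(eids, mask) if keep]


def is_one(edges: EdgeBatchSketch) -> List[bool]:
    # Element-wise test on the 'h' feature, one result per edge.
    return [x == 1.0 for x in edges.data["h"]]


# Same 'h' values as the 'follows' edges in the example below.
feats = {"h": [0.0, 1.0, 1.0, 0.0]}
print(filter_edges_sketch(is_one, feats))                     # [1, 2]
print(filter_edges_sketch(is_one, feats, edges=[0, 2, 3]))    # [2]
```

Evaluating the predicate once over the whole batch, rather than once per edge, mirrors why the real predicate receives an EdgeBatch and returns a vector of booleans.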
Returns:

The IDs of the edges that satisfy the predicate.

Return type:

tensor

Examples

>>> import torch
>>> import dgl
>>> import dgl.function as fn
>>> g = dgl.graph([(0, 0), (0, 1), (1, 2), (2, 3)], 'user', 'follows')
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows')
tensor([1, 2])