dgl.DGLGraph.apply_edges

DGLGraph.apply_edges(func, edges='__ALL__', etype=None)

Update the features of the specified edges using the provided function.

Parameters
  • func (dgl.function.BuiltinFunction or callable) – The function that generates new edge features. It must be either a DGL built-in function or a user-defined function (UDF).

  • edges (edges) –

    The edges to update features on. The allowed input formats are:

    • int: A single edge ID.

    • Int Tensor: Each element is an edge ID. The tensor must have the same device type and ID data type as the graph’s.

    • iterable[int]: Each element is an edge ID.

    • (Tensor, Tensor): The node-tensors format, where the i-th elements of the two tensors are the source and destination node IDs of one edge.

    • (iterable[int], iterable[int]): Similar to the node-tensors format but stores edge endpoints in python iterables.

    Default value specifies all the edges in the graph.

  • etype (str or (str, str, str), optional) –

    The type name of the edges. The allowed type name formats are:

    • (str, str, str) for source node type, edge type and destination node type.

    • or a single str edge type name, if that name uniquely identifies a triplet in the graph.

    Can be omitted if the graph has only one type of edges.
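The accepted `edges` formats listed above all reduce to a list of edge IDs. The following is a minimal plain-Python sketch of that reduction, using lists in place of tensors. `normalize_edges` is a hypothetical helper, not a DGL function. Two behaviors are choices made for this sketch: every 2-tuple is read as the endpoint-pair format, and parallel edges resolve to the first matching ID.

```python
from typing import Iterable, List, Sequence, Tuple, Union

ALL = '__ALL__'  # the default value of the `edges` argument

EdgesArg = Union[str, int, Iterable[int], Tuple[object, object]]


def normalize_edges(src: Sequence[int], dst: Sequence[int],
                    edges: EdgesArg = ALL) -> List[int]:
    """Hypothetical helper: convert an `edges` argument into edge IDs.

    Edge i runs from node src[i] to node dst[i].
    """
    num_edges = len(src)
    if isinstance(edges, str) and edges == ALL:
        ids = list(range(num_edges))              # every edge in the graph
    elif isinstance(edges, int):
        ids = [edges]                             # a single edge ID
    elif isinstance(edges, tuple) and len(edges) == 2:
        # Endpoint-pair format: the i-th (u, v) pair names one edge.
        us, vs = edges
        us = [us] if isinstance(us, int) else list(us)
        vs = [vs] if isinstance(vs, int) else list(vs)
        if len(us) != len(vs):
            raise ValueError('endpoint lists must have the same length')
        first_id = {}
        for eid, pair in enumerate(zip(src, dst)):
            first_id.setdefault(pair, eid)        # sketch choice: parallel edges keep the first ID
        missing = [p for p in zip(us, vs) if p not in first_id]
        if missing:
            raise ValueError(f'no such edges: {missing}')
        ids = [first_id[p] for p in zip(us, vs)]
    else:
        ids = [int(e) for e in edges]             # an iterable of edge IDs
    bad = [e for e in ids if not 0 <= e < num_edges]
    if bad:
        raise ValueError(f'edge IDs out of range: {bad}')
    return ids


src, dst = [0, 1, 2, 3], [1, 2, 3, 4]
print(normalize_edges(src, dst))                    # [0, 1, 2, 3]
print(normalize_edges(src, dst, 2))                 # [2]
print(normalize_edges(src, dst, [3, 1]))            # [3, 1]
print(normalize_edges(src, dst, ([1, 3], [2, 4])))  # [1, 3]
print(normalize_edges(src, dst, (2, 3)))            # [2]
```

In this sketch, a tuple of two ints such as `(2, 3)` names the single edge from node 2 to node 3. A list is needed to pass edge IDs.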
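The `etype` rules can be sketched as a small resolver over a graph's list of canonical edge types. `resolve_etype` is a hypothetical helper written for illustration. It restates only the three cases described for `etype`: a full triplet, an edge type name that matches exactly one triplet, or nothing when the graph has a single edge type.

```python
from typing import List, Optional, Tuple, Union

CanonicalEType = Tuple[str, str, str]


def resolve_etype(canonical_etypes: List[CanonicalEType],
                  etype: Optional[Union[str, CanonicalEType]] = None) -> CanonicalEType:
    """Hypothetical helper: map an `etype` argument to a canonical triplet."""
    if etype is None:
        # Omitting etype is only unambiguous when there is one edge type.
        if len(canonical_etypes) != 1:
            raise ValueError('etype is required when the graph has several edge types')
        return canonical_etypes[0]
    if isinstance(etype, tuple):
        # (source node type, edge type, destination node type)
        if etype not in canonical_etypes:
            raise ValueError(f'unknown canonical edge type: {etype}')
        return etype
    # A bare edge type name must match exactly one triplet.
    matches = [c for c in canonical_etypes if c[1] == etype]
    if len(matches) != 1:
        raise ValueError(f'{etype!r} matches {len(matches)} canonical edge types')
    return matches[0]


etypes = [('user', 'plays', 'game'), ('user', 'plays', 'toy'), ('store', 'sells', 'game')]
print(resolve_etype(etypes, 'sells'))                   # ('store', 'sells', 'game')
print(resolve_etype(etypes, ('user', 'plays', 'toy')))  # ('user', 'plays', 'toy')
print(resolve_etype([('user', 'plays', 'game')]))       # ('user', 'plays', 'game')
# resolve_etype(etypes, 'plays') raises ValueError because 'plays' is ambiguous here.
```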

Notes

DGL recommends using DGL’s built-in functions for the func argument, because in that case DGL invokes efficient kernels that avoid copying node features to edge features.
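The copying mentioned in the note can be illustrated with a plain-Python analogy. This code is not DGL's kernels, which are compiled and operate on tensors. Both functions below are hypothetical and compute the same per-edge sum. The UDF-style version first materializes one copy of the source features and one copy of the destination features for every edge, which is 2E intermediate rows for E edges. The fused version reads the node rows in place and writes only the result.

```python
from typing import List, Sequence

Vector = List[float]


def add_with_copies(h: Sequence[Vector], src: Sequence[int], dst: Sequence[int]) -> List[Vector]:
    # UDF style: first materialize one copy of the endpoint features per edge
    # (the analogue of edges.src['h'] and edges.dst['h']), then combine them.
    h_src = [list(h[u]) for u in src]
    h_dst = [list(h[v]) for v in dst]
    return [[a + b for a, b in zip(x, y)] for x, y in zip(h_src, h_dst)]


def add_fused(h: Sequence[Vector], src: Sequence[int], dst: Sequence[int]) -> List[Vector]:
    # Fused style: read each endpoint row in place and write only the result.
    return [[a + b for a, b in zip(h[u], h[v])] for u, v in zip(src, dst)]


h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
src, dst = [0, 1, 2], [1, 2, 0]
assert add_with_copies(h, src, dst) == add_fused(h, src, dst)
print(add_fused(h, src, dst))  # [[4.0, 6.0], [8.0, 10.0], [6.0, 8.0]]
```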

Examples

The following examples use the PyTorch backend.

>>> import dgl
>>> import torch

Homogeneous graph

>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['h'] = torch.ones(5, 2)
>>> g.apply_edges(lambda edges: {'x' : edges.src['h'] + edges.dst['h']})
>>> g.edata['x']
tensor([[2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])
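The semantics of the example above can be spelled out without DGL. The following plain-Python sketch calls the UDF once on a batch of selected edges and writes each returned feature back to those edges. The batch exposes `src`, `dst` and `data` as in DGL's UDF interface. `EdgeBatchSketch`, `apply_edges_sketch` and `add_rows` are hypothetical names. Filling edges outside the selection with zeros when a new feature is created is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence

Vector = List[float]
Features = Dict[str, List[Vector]]


@dataclass
class EdgeBatchSketch:
    """Hypothetical stand-in for the edge batch a UDF receives."""
    src: Features   # features of each selected edge's source node
    dst: Features   # features of each selected edge's destination node
    data: Features  # features of the selected edges themselves


def apply_edges_sketch(src_ids: Sequence[int], dst_ids: Sequence[int],
                       ndata: Features, edata: Features,
                       func: Callable[[EdgeBatchSketch], Features],
                       eids: Optional[Sequence[int]] = None) -> None:
    """Call `func` once on all selected edges and write its outputs back."""
    num_edges = len(src_ids)
    eids = list(range(num_edges)) if eids is None else list(eids)
    if not eids:
        return
    batch = EdgeBatchSketch(
        src={k: [col[src_ids[e]] for e in eids] for k, col in ndata.items()},
        dst={k: [col[dst_ids[e]] for e in eids] for k, col in ndata.items()},
        data={k: [col[e] for e in eids] for k, col in edata.items()},
    )
    for name, values in func(batch).items():
        if len(values) != len(eids):
            raise ValueError(f'{name!r}: expected {len(eids)} rows, got {len(values)}')
        if name not in edata:
            # Assumption of this sketch: edges outside `eids` start at zero.
            edata[name] = [[0.0] * len(values[0]) for _ in range(num_edges)]
        for e, value in zip(eids, values):
            edata[name][e] = value


def add_rows(xs: List[Vector], ys: List[Vector]) -> List[Vector]:
    return [[a + b for a, b in zip(x, y)] for x, y in zip(xs, ys)]


# Same graph and features as the homogeneous example above.
src_ids, dst_ids = [0, 1, 2, 3], [1, 2, 3, 4]
ndata = {'h': [[1.0, 1.0] for _ in range(5)]}
edata: Features = {}
apply_edges_sketch(src_ids, dst_ids, ndata, edata,
                   lambda edges: {'x': add_rows(edges.src['h'], edges.dst['h'])})
print(edata['x'])  # [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0], [2.0, 2.0]]

# Update only edges 0 and 2 by doubling their 'x'.
apply_edges_sketch(src_ids, dst_ids, ndata, edata,
                   lambda edges: {'x': [[2 * a for a in row] for row in edges.data['x']]},
                   eids=[0, 2])
print(edata['x'])  # [[4.0, 4.0], [2.0, 2.0], [4.0, 4.0], [2.0, 2.0]]
```

Calling the function once per batch, rather than once per edge, mirrors how a UDF sees all selected edges together.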

Use a built-in function

>>> import dgl.function as fn
>>> g.apply_edges(fn.u_add_v('h', 'h', 'x'))
>>> g.edata['x']
tensor([[2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])
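DGL's binary built-in functions, such as `fn.u_add_v` above, are named `<lhs>_<op>_<rhs>`. In these names, `u`, `v` and `e` refer to the source node, destination node and edge features. The following plain-Python sketch turns such a name into the equivalent UDF. `builtin_as_udf` is a hypothetical helper, and `SimpleNamespace` stands in for the edge batch. The sketch covers only same-shape rows and omits DGL's broadcasting rules and optimized kernels.

```python
import operator
from types import SimpleNamespace
from typing import Callable, Dict, List

Vector = List[float]

_SIDES = {'u': 'src', 'v': 'dst', 'e': 'data'}  # which features each letter reads
_OPS = {'add': operator.add, 'sub': operator.sub,
        'mul': operator.mul, 'div': operator.truediv}


def builtin_as_udf(name: str, lhs_field: str, rhs_field: str,
                   out_field: str) -> Callable[[SimpleNamespace], Dict[str, List[Vector]]]:
    """Hypothetical: build a UDF equivalent to a binary built-in like u_add_v."""
    lhs, op, rhs = name.split('_')
    lhs_side, rhs_side = _SIDES[lhs], _SIDES[rhs]

    def udf(edges: SimpleNamespace) -> Dict[str, List[Vector]]:
        left = getattr(edges, lhs_side)[lhs_field]
        right = getattr(edges, rhs_side)[rhs_field]
        if op == 'dot':
            # One value per edge, kept as a length-1 row in this sketch.
            out = [[sum(a * b for a, b in zip(x, y))] for x, y in zip(left, right)]
        else:
            f = _OPS[op]
            out = [[f(a, b) for a, b in zip(x, y)] for x, y in zip(left, right)]
        return {out_field: out}

    return udf


edges = SimpleNamespace(src={'h': [[1.0, 2.0]]}, dst={'h': [[3.0, 4.0]]}, data={})
print(builtin_as_udf('u_add_v', 'h', 'h', 'x')(edges))  # {'x': [[4.0, 6.0]]}
print(builtin_as_udf('u_dot_v', 'h', 'h', 's')(edges))  # {'s': [[11.0]]}
```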

Heterogeneous graph

>>> g = dgl.heterograph({('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1])})
>>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
>>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
>>> g.edges[('user', 'plays', 'game')].data['h']
tensor([[2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.]])

See also

apply_nodes