dgl.ops

Frame-agnostic operators for message passing on graphs.

GSpMM functions

Generalized Sparse-Matrix Dense-Matrix Multiplication functions. Each function fuses two steps into one kernel:

  1. Compute messages by adding/subtracting/multiplying/dividing source node and edge features, or by copying source node or edge features as messages.

  2. Aggregate the messages by sum/max/min/mean to obtain the features on destination nodes.
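The two steps can be illustrated with a framework-free, unfused reference written in plain Python. This is only a sketch of the semantics, not DGL's kernel. It makes two assumptions: features are flat lists, and destination nodes with no in-edges receive zeros. It runs on the same 3-node, 6-edge graph used in the example below.

```python
import operator

# Step 1 operators and step 2 reducers, named as in the text above.
BINARY_OPS = {"add": operator.add, "sub": operator.sub,
              "mul": operator.mul, "div": operator.truediv}
REDUCERS = {"sum": sum, "max": max, "min": min,
            "mean": lambda vals: sum(vals) / len(vals)}


def two_step_u_op_e(src, dst, num_dst, x, y, op, reduce):
    """Unfused reference: materialize one message per edge, then reduce.

    src/dst: edge endpoint lists; x: one feature row per source node;
    y: one feature row per edge. All rows are flat lists of equal length.
    """
    binop = BINARY_OPS[op]
    # Step 1: one message per edge (this list alone costs O(|E| D) memory).
    messages = [[binop(a, b) for a, b in zip(x[u], y[e])]
                for e, u in enumerate(src)]
    # Step 2: group messages by destination node and reduce feature-wise.
    inbox = [[] for _ in range(num_dst)]
    for e, v in enumerate(dst):
        inbox[v].append(messages[e])
    dim = len(x[0])
    out = []
    for msgs in inbox:
        if not msgs:  # assumption of this sketch: no in-edges -> zeros
            out.append([0.0] * dim)
        else:
            out.append([REDUCERS[reduce](col) for col in zip(*msgs)])
    return out


src, dst = [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]
x = [[1.0, 1.0]] * 3
y = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
print(two_step_u_op_e(src, dst, 3, x, y, "mul", "sum"))  # [[1.0, 2.0], [10.0, 12.0], [25.0, 28.0]]
print(two_step_u_op_e(src, dst, 3, x, y, "add", "max"))  # [[2.0, 3.0], [8.0, 9.0], [12.0, 13.0]]
```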

Our implementation accepts PyTorch/MXNet/Tensorflow tensors on CPU or GPU as input. All operators support autograd (computing the input gradients given the output gradient) and broadcasting (if the feature shapes of the operands do not match, we first broadcast them to the same shape and then apply the binary operator). Our broadcasting semantics follow NumPy; please see https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html for more details.
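Only the feature dimensions take part in broadcasting. These are all dimensions after the leading node or edge dimension. The helper below is a minimal pure-Python sketch of the NumPy rule applied to two feature shapes. It is hypothetical and not part of dgl.ops. Shapes are aligned from the trailing dimension, and two sizes are compatible when they are equal or one of them is 1.

```python
def broadcast_feature_shape(lhs, rhs):
    """Broadcast two feature shapes (tuples) under NumPy's rules."""
    ndim = max(len(lhs), len(rhs))
    # Left-pad the shorter shape with 1s so both have ndim entries.
    lhs = (1,) * (ndim - len(lhs)) + tuple(lhs)
    rhs = (1,) * (ndim - len(rhs)) + tuple(rhs)
    result = []
    for a, b in zip(lhs, rhs):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"incompatible feature shapes: {lhs} vs {rhs}")
        # A size-1 dimension stretches to match the other operand.
        result.append(b if a == 1 else a)
    return tuple(result)


# Node features of shape (3, 2) and edge features of shape (6, 4, 2)
# have feature shapes (2,) and (4, 2).
print(broadcast_feature_shape((2,), (4, 2)))  # (4, 2)
```

With the leading node dimension of 3 added back, this gives the `(3, 4, 2)` output shape seen in the `y1` example.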

By fusing, we mean that the messages are not materialized on edges. Instead, we compute the result on destination nodes directly, which saves memory. The space complexity of GSpMM operators is \(O(|N|D)\), where \(|N|\) is the number of nodes in the graph and \(D\) is the feature size (\(D=\prod_{i=1}^{k}D_i\) if your feature is a multi-dimensional tensor of shape \((D_1, \dots, D_k)\)).
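As a rough illustration of fusion, the pure-Python sketch below computes `u_mul_e_sum` by adding each product straight into its destination row. No per-edge message list is ever built, so apart from the inputs the only storage is the \(|N| \times D\) output. This is only an analogy under simplified assumptions (flat feature lists, a single thread). DGL's actual CPU/GPU kernels are not shown here.

```python
def fused_u_mul_e_sum(src, dst, num_dst, x, y):
    """Fused reference for u_mul_e_sum with no per-edge message buffer."""
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_dst)]  # the only O(|N| D) allocation
    for e, (u, v) in enumerate(zip(src, dst)):
        row = out[v]
        for k in range(dim):
            # Each message value is consumed immediately and never stored.
            row[k] += x[u][k] * y[e][k]
    return out


src, dst = [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]
x = [[1.0, 1.0]] * 3
y = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
print(fused_u_mul_e_sum(src, dst, 3, x, y))  # [[1.0, 2.0], [10.0, 12.0], [25.0, 28.0]]
```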

The following example shows how GSpMM works (we use PyTorch as the backend here; usage on the other frameworks is similar):

>>> import dgl
>>> import torch as th
>>> import dgl.ops as F
>>> g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))  # 3 nodes, 6 edges
>>> x = th.ones(3, 2, requires_grad=True)
>>> x
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]], requires_grad=True)
>>> y = th.arange(1, 13).float().view(6, 2).requires_grad_()
>>> y
tensor([[ 1.,  2.],
        [ 3.,  4.],
        [ 5.,  6.],
        [ 7.,  8.],
        [ 9., 10.],
        [11., 12.]], requires_grad=True)
>>> out_1 = F.u_mul_e_sum(g, x, y)
>>> out_1  # (10, 12) = ((1, 1) * (3, 4)) + ((1, 1) * (7, 8))
tensor([[ 1.,  2.],
        [10., 12.],
        [25., 28.]], grad_fn=<GSpMMBackward>)
>>> out_1.sum().backward()
>>> x.grad
tensor([[ 9., 12.],
        [16., 18.],
        [11., 12.]])
>>> y.grad
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
>>> out_2 = F.copy_u_sum(g, x)
>>> out_2
tensor([[1., 1.],
        [2., 2.],
        [3., 3.]], grad_fn=<GSpMMBackward>)
>>> out_3 = F.u_add_e_max(g, x, y)
>>> out_3
tensor([[ 2.,  3.],
        [ 8.,  9.],
        [12., 13.]], grad_fn=<GSpMMBackward>)
>>> y1 = th.rand(6, 4, 2, requires_grad=True)  # test broadcast
>>> F.u_mul_e_sum(g, x, y1).shape  # (2,), (4, 2) -> (4, 2)
torch.Size([3, 4, 2])
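The gradients of `out_1.sum()` can be derived by hand. The loss equals \(\sum_e \sum_k x_{\mathrm{src}(e),k}\, y_{e,k}\). So the gradient for `x[u]` is the sum of `y` over the edges leaving `u`, and the gradient for `y[e]` is `x[src[e]]`. Because the reduction is a sum, each edge's destination does not affect either gradient. The plain-Python sketch below computes both under these assumptions. It is a hand derivation, not DGL's backward kernel.

```python
def grad_sum_u_mul_e_sum(src, num_src, x, y):
    """Gradients of sum(u_mul_e_sum(g, x, y)) with respect to x and y."""
    dim = len(x[0])
    grad_x = [[0.0] * dim for _ in range(num_src)]
    for e, u in enumerate(src):
        for k in range(dim):
            grad_x[u][k] += y[e][k]  # dL/dx[u] accumulates y over u's out-edges
    grad_y = [list(x[u]) for u in src]  # dL/dy[e] = x[src[e]]
    return grad_x, grad_y


src = [0, 0, 0, 1, 1, 2]
x = [[1.0, 1.0]] * 3
y = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
grad_x, grad_y = grad_sum_u_mul_e_sum(src, 3, x, y)
print(grad_x)  # [[9.0, 12.0], [16.0, 18.0], [11.0, 12.0]]
```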

For all operators, the input graph can be either a homogeneous graph or a bipartite graph.
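On a bipartite graph, the source and destination node sets are separate. Node features therefore have one row per source node, while the output has one row per destination node. The pure-Python sketch below shows this for a copy-then-mean aggregation on a small hypothetical graph with 2 source and 3 destination nodes. It assumes that destination nodes without in-edges keep zeros.

```python
def copy_u_mean_bipartite(src, dst, num_dst, x):
    """Reference copy_u_mean: x has num_src rows, the result has num_dst rows."""
    dim = len(x[0])
    totals = [[0.0] * dim for _ in range(num_dst)]
    counts = [0] * num_dst
    for u, v in zip(src, dst):
        counts[v] += 1
        for k in range(dim):
            totals[v][k] += x[u][k]
    # Assumption of this sketch: destination nodes without in-edges keep zeros.
    return [[t / c for t in row] if c else row
            for row, c in zip(totals, counts)]


# Hypothetical bipartite graph: 2 source nodes, 3 destination nodes, 3 edges.
src, dst = [0, 1, 1], [0, 0, 2]
x = [[2.0], [4.0]]
print(copy_u_mean_bipartite(src, dst, 3, x))  # [[3.0], [0.0], [4.0]]
```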

gspmm(g, op, reduce_op, lhs_data, rhs_data)

Generalized Sparse Matrix Multiplication interface.

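u_add_e_sum(g, x, y)

Generalized SpMM function: add source node features and edge features, then sum the messages on destination nodes.

u_sub_e_sum(g, x, y)

Generalized SpMM function: subtract edge features from source node features, then sum the messages on destination nodes.

u_mul_e_sum(g, x, y)

Generalized SpMM function: multiply source node features by edge features, then sum the messages on destination nodes.

u_div_e_sum(g, x, y)

Generalized SpMM function: divide source node features by edge features, then sum the messages on destination nodes.

u_add_e_max(g, x, y)

Generalized SpMM function: add source node features and edge features, then take the elementwise maximum of the messages on destination nodes.

u_sub_e_max(g, x, y)

Generalized SpMM function: subtract edge features from source node features, then take the elementwise maximum of the messages on destination nodes.

u_mul_e_max(g, x, y)

Generalized SpMM function: multiply source node features by edge features, then take the elementwise maximum of the messages on destination nodes.

u_div_e_max(g, x, y)

Generalized SpMM function: divide source node features by edge features, then take the elementwise maximum of the messages on destination nodes.

u_add_e_min(g, x, y)

Generalized SpMM function: add source node features and edge features, then take the elementwise minimum of the messages on destination nodes.

u_sub_e_min(g, x, y)

Generalized SpMM function: subtract edge features from source node features, then take the elementwise minimum of the messages on destination nodes.

u_mul_e_min(g, x, y)

Generalized SpMM function: multiply source node features by edge features, then take the elementwise minimum of the messages on destination nodes.

u_div_e_min(g, x, y)

Generalized SpMM function: divide source node features by edge features, then take the elementwise minimum of the messages on destination nodes.

u_add_e_mean(g, x, y)

Generalized SpMM function: add source node features and edge features, then average the messages on destination nodes.

u_sub_e_mean(g, x, y)

Generalized SpMM function: subtract edge features from source node features, then average the messages on destination nodes.

u_mul_e_mean(g, x, y)

Generalized SpMM function: multiply source node features by edge features, then average the messages on destination nodes.

u_div_e_mean(g, x, y)

Generalized SpMM function: divide source node features by edge features, then average the messages on destination nodes.

copy_u_sum(g, x)

Generalized SpMM function: use source node features as messages, then sum them on destination nodes.

copy_e_sum(g, x)

Generalized SpMM function: use edge features as messages, then sum them on destination nodes.

copy_u_max(g, x)

Generalized SpMM function: use source node features as messages, then take their elementwise maximum on destination nodes.

copy_e_max(g, x)

Generalized SpMM function: use edge features as messages, then take their elementwise maximum on destination nodes.

copy_u_min(g, x)

Generalized SpMM function: use source node features as messages, then take their elementwise minimum on destination nodes.

copy_e_min(g, x)

Generalized SpMM function: use edge features as messages, then take their elementwise minimum on destination nodes.

copy_u_mean(g, x)

Generalized SpMM function: use source node features as messages, then average them on destination nodes.

copy_e_mean(g, x)

Generalized SpMM function: use edge features as messages, then average them on destination nodes.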
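The 24 shortcuts above follow one naming scheme. `u_<op>_e_<reduce>` combines source node features `x` with edge features `y`, and `copy_<u|e>_<reduce>` uses a single operand. As we understand it, a binary shortcut such as `u_mul_e_sum(g, x, y)` corresponds to `gspmm(g, 'mul', 'sum', x, y)`. The helper below is hypothetical and not part of dgl.ops. It only decodes a shortcut name into its parts.

```python
import re

# Patterns for the shortcut names, e.g. "u_mul_e_sum" or "copy_e_max".
_BINARY = re.compile(r"^u_(add|sub|mul|div)_e_(sum|max|min|mean)$")
_COPY = re.compile(r"^copy_(u|e)_(sum|max|min|mean)$")


def decode_spmm_name(name):
    """Split a GSpMM shortcut name into (lhs, op, rhs, reduce_op)."""
    m = _BINARY.match(name)
    if m:
        return ("u", m.group(1), "e", m.group(2))
    m = _COPY.match(name)
    if m:
        return (m.group(1), "copy", None, m.group(2))
    raise ValueError(f"not a GSpMM shortcut name: {name}")


# 4 binary ops x 4 reducers, plus 2 copy sources x 4 reducers.
reducers = ("sum", "max", "min", "mean")
names = [f"u_{op}_e_{r}" for op in ("add", "sub", "mul", "div") for r in reducers]
names += [f"copy_{t}_{r}" for t in ("u", "e") for r in reducers]
print(len(names))                         # 24
print(decode_spmm_name("u_div_e_mean"))   # ('u', 'div', 'e', 'mean')
```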

GSDDMM functions

Generalized Sampled Dense-Dense Matrix Multiplication. It computes edge features by adding, subtracting, multiplying, dividing, or taking the dot product of features on source nodes, destination nodes, or edges.

Like GSpMM, our implementation supports tensors on CPU/GPU in PyTorch/MXNet/Tensorflow as input. All operators are equipped with autograd and broadcasting.

The memory cost of GSDDMM is \(O(|E|D)\) where \(|E|\) refers to the number of edges in the graph while \(D\) refers to the feature size.
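A minimal pure-Python reference makes the \(O(|E|D)\) cost visible: the result has one feature row per edge. The parameters `lhs_on` and `rhs_on` are hypothetical names used only in this sketch. They say where each operand lives: `"u"` (source node), `"v"` (destination node) or `"e"` (the edge itself). The sketch does not describe gsddmm's own keyword arguments.

```python
import operator

BINARY_OPS = {"add": operator.add, "sub": operator.sub,
              "mul": operator.mul, "div": operator.truediv}


def sddmm_reference(src, dst, x, y, op, lhs_on, rhs_on):
    """Reference GSDDMM: returns one feature row per edge (|E| rows)."""
    def operand(data, where, e):
        # Pick the row for edge e from source nodes, destination nodes or edges.
        row = {"u": src[e], "v": dst[e], "e": e}[where]
        return data[row]

    binop = BINARY_OPS[op]
    return [[binop(a, b) for a, b in zip(operand(x, lhs_on, e), operand(y, rhs_on, e))]
            for e in range(len(src))]


src, dst = [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]
x = [[1.0, 1.0]] * 3
y = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
u_div_v = sddmm_reference(src, dst, x, y, "div", "u", "v")
print(len(u_div_v))  # 6, one row per edge
print(u_div_v[0])    # [1.0, 0.5]
```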

Note that we support a dot operator, which is semantically equivalent to applying the mul operator and then sum-reducing the last dimension of the result. However, dot is more memory efficient because it fuses the multiplication and the sum reduction. This matters when the last feature dimension is large (e.g. multi-head attention in Transformer-like models).
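The difference between the two paths can be sketched per edge in pure Python. The unfused path builds a length-\(D\) product vector and then sums it. The fused path keeps a single running scalar. Both give the same value. This is an illustration of the idea only, not DGL's implementation. The `u_dot_v` reference reproduces the one-column output of the example below.

```python
def mul_then_sum(a, b):
    """Unfused: materialize the elementwise product, then reduce it."""
    product = [p * q for p, q in zip(a, b)]  # extra length-D buffer per edge
    return sum(product)


def fused_dot(a, b):
    """Fused: accumulate into one scalar, no intermediate product vector."""
    acc = 0.0
    for p, q in zip(a, b):
        acc += p * q
    return acc


def u_dot_v_reference(src, dst, x, y):
    """Dot product of source and destination features on each edge; |E| x 1."""
    return [[fused_dot(x[u], y[v])] for u, v in zip(src, dst)]


src, dst = [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]
x = [[1.0, 1.0]] * 3
y = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(u_dot_v_reference(src, dst, x, y))  # [[3.0], [7.0], [11.0], [7.0], [11.0], [11.0]]
```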

The following is an example showing how GSDDMM works:

>>> import dgl
>>> import torch as th
>>> import dgl.ops as F
>>> g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))  # 3 nodes, 6 edges
>>> x = th.ones(3, 2, requires_grad=True)
>>> x
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]], requires_grad=True)
>>> y = th.arange(1, 7).float().view(3, 2).requires_grad_()
>>> y
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]], requires_grad=True)
>>> e = th.ones(6, 1, 2, requires_grad=True) * 2
>>> e
tensor([[[2., 2.]],
        [[2., 2.]],
        [[2., 2.]],
        [[2., 2.]],
        [[2., 2.]],
        [[2., 2.]]], grad_fn=<MulBackward0>)
>>> out1 = F.u_div_v(g, x, y)
>>> out1
tensor([[1.0000, 0.5000],
        [0.3333, 0.2500],
        [0.2000, 0.1667],
        [0.3333, 0.2500],
        [0.2000, 0.1667],
        [0.2000, 0.1667]], grad_fn=<GSDDMMBackward>)
>>> out1.sum().backward()
>>> x.grad
tensor([[1.5333, 0.9167],
        [0.5333, 0.4167],
        [0.2000, 0.1667]])
>>> y.grad
tensor([[-1.0000, -0.2500],
        [-0.2222, -0.1250],
        [-0.1200, -0.0833]])
>>> out2 = F.e_sub_v(g, e, y)
>>> out2
tensor([[[ 1.,  0.]],
        [[-1., -2.]],
        [[-3., -4.]],
        [[-1., -2.]],
        [[-3., -4.]],
        [[-3., -4.]]], grad_fn=<GSDDMMBackward>)
>>> out3 = F.copy_v(g, y)
>>> out3
tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [3., 4.],
        [5., 6.],
        [5., 6.]], grad_fn=<GSDDMMBackward>)
>>> out4 = F.u_dot_v(g, x, y)
>>> out4  # the last dimension was reduced to size 1.
tensor([[ 3.],
        [ 7.],
        [11.],
        [ 7.],
        [11.],
        [11.]], grad_fn=<GSDDMMBackward>)
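The gradients of `out1.sum()` for `u_div_v` follow from differentiating \(x_{u,k} / y_{v,k}\) on every edge \((u, v)\). Source node `u` collects \(1 / y_{v,k}\) from each of its out-edges. Destination node `v` collects \(-x_{u,k} / y_{v,k}^2\) from each of its in-edges. The plain-Python sketch below computes both sums for the example graph. It is a hand derivation, not DGL's backward pass.

```python
def grad_sum_u_div_v(src, dst, num_src, num_dst, x, y):
    """Gradients of sum(u_div_v(g, x, y)), where edge (u, v) holds x[u] / y[v]."""
    dim = len(x[0])
    grad_x = [[0.0] * dim for _ in range(num_src)]
    grad_y = [[0.0] * dim for _ in range(num_dst)]
    for u, v in zip(src, dst):
        for k in range(dim):
            grad_x[u][k] += 1.0 / y[v][k]              # d(x/y)/dx = 1/y
            grad_y[v][k] -= x[u][k] / y[v][k] ** 2     # d(x/y)/dy = -x/y^2
    return grad_x, grad_y


src, dst = [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]
x = [[1.0, 1.0]] * 3
y = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
grad_x, grad_y = grad_sum_u_div_v(src, dst, 3, 3, x, y)
print([[round(g, 4) for g in row] for row in grad_y])
# [[-1.0, -0.25], [-0.2222, -0.125], [-0.12, -0.0833]]
```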

gsddmm(g, op, lhs_data, rhs_data[, ...])

Generalized Sampled-Dense-Dense Matrix Multiplication interface.

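u_add_v(g, x, y)

Generalized SDDMM function: add source node features and destination node features on each edge.

u_sub_v(g, x, y)

Generalized SDDMM function: subtract destination node features from source node features on each edge.

u_mul_v(g, x, y)

Generalized SDDMM function: multiply source node features by destination node features on each edge.

u_dot_v(g, x, y)

Generalized SDDMM function: take the dot product of source node features and destination node features on each edge, reducing the last dimension.

u_div_v(g, x, y)

Generalized SDDMM function: divide source node features by destination node features on each edge.

u_add_e(g, x, y)

Generalized SDDMM function: add source node features and edge features on each edge.

u_sub_e(g, x, y)

Generalized SDDMM function: subtract edge features from source node features on each edge.

u_mul_e(g, x, y)

Generalized SDDMM function: multiply source node features by edge features on each edge.

u_dot_e(g, x, y)

Generalized SDDMM function: take the dot product of source node features and edge features on each edge, reducing the last dimension.

u_div_e(g, x, y)

Generalized SDDMM function: divide source node features by edge features on each edge.

e_add_v(g, x, y)

Generalized SDDMM function: add edge features and destination node features on each edge.

e_sub_v(g, x, y)

Generalized SDDMM function: subtract destination node features from edge features on each edge.

e_mul_v(g, x, y)

Generalized SDDMM function: multiply edge features by destination node features on each edge.

e_dot_v(g, x, y)

Generalized SDDMM function: take the dot product of edge features and destination node features on each edge, reducing the last dimension.

e_div_v(g, x, y)

Generalized SDDMM function: divide edge features by destination node features on each edge.

v_add_u(g, x, y)

Generalized SDDMM function: add destination node features and source node features on each edge.

v_sub_u(g, x, y)

Generalized SDDMM function: subtract source node features from destination node features on each edge.

v_mul_u(g, x, y)

Generalized SDDMM function: multiply destination node features by source node features on each edge.

v_dot_u(g, x, y)

Generalized SDDMM function: take the dot product of destination node features and source node features on each edge, reducing the last dimension.

v_div_u(g, x, y)

Generalized SDDMM function: divide destination node features by source node features on each edge.

e_add_u(g, x, y)

Generalized SDDMM function: add edge features and source node features on each edge.

e_sub_u(g, x, y)

Generalized SDDMM function: subtract source node features from edge features on each edge.

e_mul_u(g, x, y)

Generalized SDDMM function: multiply edge features by source node features on each edge.

e_dot_u(g, x, y)

Generalized SDDMM function: take the dot product of edge features and source node features on each edge, reducing the last dimension.

e_div_u(g, x, y)

Generalized SDDMM function: divide edge features by source node features on each edge.

v_add_e(g, x, y)

Generalized SDDMM function: add destination node features and edge features on each edge.

v_sub_e(g, x, y)

Generalized SDDMM function: subtract edge features from destination node features on each edge.

v_mul_e(g, x, y)

Generalized SDDMM function: multiply destination node features by edge features on each edge.

v_dot_e(g, x, y)

Generalized SDDMM function: take the dot product of destination node features and edge features on each edge, reducing the last dimension.

v_div_e(g, x, y)

Generalized SDDMM function: divide destination node features by edge features on each edge.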

copy_u(g, x)

Generalized SDDMM function that copies source node features to edges.

copy_v(g, x)

Generalized SDDMM function that copies destination node features to edges.
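Semantically, `copy_u` and `copy_v` are gathers by edge endpoint: edge `e` receives `x[src[e]]` or `x[dst[e]]`. The stand-ins below are hypothetical plain-Python references, not the dgl.ops functions. They reproduce the `copy_v` output from the GSDDMM example.

```python
def copy_u_ref(src, x):
    """Edge e receives a copy of its source node's features, x[src[e]]."""
    return [list(x[u]) for u in src]


def copy_v_ref(dst, x):
    """Edge e receives a copy of its destination node's features, x[dst[e]]."""
    return [list(x[v]) for v in dst]


src, dst = [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]
y = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(copy_v_ref(dst, y))
# [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [3.0, 4.0], [5.0, 6.0], [5.0, 6.0]]
```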

Like GSpMM, GSDDMM operators support both homogeneous and bipartite graphs.

Segment Reduce Module

DGL provides operators that reduce a value tensor along its first dimension by segments.

segment_reduce(seglen, value[, reducer])

Segment reduction operator.
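Our reading of segment reduction: `seglen[i]` consecutive rows of `value` form segment `i`, and each segment collapses to a single row. The pure-Python sketch below is a hypothetical reference, not DGL's implementation. It defaults to `"sum"` and, as an assumption of this sketch, returns zeros for empty segments.

```python
def segment_reduce_ref(seglen, value, reducer="sum"):
    """Reduce consecutive row segments of `value` along the first dimension."""
    if sum(seglen) != len(value):
        raise ValueError("sum(seglen) must equal the number of rows in value")
    reduce_fn = {"sum": sum, "max": max, "min": min,
                 "mean": lambda col: sum(col) / len(col)}[reducer]
    dim = len(value[0]) if value else 0
    out, start = [], 0
    for n in seglen:
        rows = value[start:start + n]
        start += n
        # Assumption of this sketch: an empty segment yields a row of zeros.
        out.append([reduce_fn(col) for col in zip(*rows)] if rows else [0.0] * dim)
    return out


value = [[1.0], [2.0], [3.0], [4.0], [5.0]]
print(segment_reduce_ref([2, 0, 3], value))          # [[3.0], [0.0], [12.0]]
print(segment_reduce_ref([2, 0, 3], value, "mean"))  # [[1.5], [0.0], [4.0]]
```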

GatherMM and SegmentMM Module

SegmentMM: DGL provides operators that perform matrix multiplication according to segments.

GatherMM: DGL provides operators that gather data according to the given indices and then perform matrix multiplication.

gather_mm(a, b, *, idx_b)

Gather data according to the given indices and perform matrix multiplication.

segment_mm(a, b, seglen_a)

Performs matrix multiplication according to segments.
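The two operators can be read as follows, assuming `a` has shape `(N, D1)`, `b` has shape `(R, D1, D2)`, `idx_b` has length `N`, and `seglen_a` has length `R` and sums to `N`. gather_mm multiplies row `i` of `a` by `b[idx_b[i]]`. segment_mm splits the rows of `a` into consecutive segments and multiplies segment `r` by `b[r]`. This is our interpretation of the summaries above. The pure-Python references below are hypothetical sketches, not DGL's kernels.

```python
def matmul_row(a_row, w):
    """Multiply a length-D1 row by a D1 x D2 matrix given as nested lists."""
    return [sum(a_row[i] * w[i][j] for i in range(len(a_row)))
            for j in range(len(w[0]))]


def gather_mm_ref(a, b, idx_b):
    """out[i] = a[i] @ b[idx_b[i]]: each row picks its own weight matrix."""
    return [matmul_row(row, b[k]) for row, k in zip(a, idx_b)]


def segment_mm_ref(a, b, seglen_a):
    """Rows of a are split into consecutive segments; segment r uses b[r]."""
    out, start = [], 0
    for r, n in enumerate(seglen_a):
        out.extend(matmul_row(row, b[r]) for row in a[start:start + n])
        start += n
    return out


a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b = [[[1.0, 0.0], [0.0, 1.0]],   # b[0]: identity
     [[2.0, 0.0], [0.0, 2.0]]]   # b[1]: doubles its input
print(gather_mm_ref(a, b, [1, 0, 1]))  # [[2.0, 4.0], [3.0, 4.0], [10.0, 12.0]]
print(segment_mm_ref(a, b, [1, 2]))    # [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]
```

Under this reading, segment_mm is the special case of gather_mm where `idx_b` is sorted, e.g. `seglen_a = [1, 2]` matches `idx_b = [0, 1, 1]`.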

Supported Data types

Operators defined in dgl.ops support floating-point data types only, i.e. the operands must be half (float16), float, or double tensors. All input tensors must have the same data type: if one input tensor is float16 and another is float32, the user must convert one of them to match the other.

float16 support is disabled by default because it requires a GPU with a minimum compute capability of sm_53 (Pascal, Volta, Turing and Ampere architectures).

Users can enable float16 for mixed precision training by compiling DGL from source (see the Mixed Precision Training tutorial for details).

Relation with Message Passing APIs

Calls to dgl.update_all and dgl.apply_edges with built-in message/reduce functions are dispatched to the operators defined in dgl.ops:

>>> import dgl
>>> import torch as th
>>> import dgl.ops as F
>>> import dgl.function as fn
>>> g = dgl.rand_graph(100, 1000)   # create a DGLGraph with 100 nodes and 1000 edges.
>>> x = th.rand(100, 20)            # node features.
>>> e = th.rand(1000, 20)
>>>
>>> # dgl.update_all + builtin functions
>>> g.srcdata['x'] = x              # srcdata is the same as ndata for graphs with one node type.
>>> g.edata['e'] = e
>>> g.update_all(fn.u_mul_e('x', 'e', 'm'), fn.sum('m', 'y'))
>>> y = g.dstdata['y']              # dstdata is the same as ndata for graphs with one node type.
>>>
>>> # use GSpMM operators defined in dgl.ops directly
>>> y = F.u_mul_e_sum(g, x, e)
>>>
>>> # dgl.apply_edges + builtin functions
>>> g.srcdata['x'] = x
>>> g.dstdata['y'] = y
>>> g.apply_edges(fn.u_dot_v('x', 'y', 'z'))
>>> z = g.edata['z']
>>>
>>> # use GSDDMM operators defined in dgl.ops directly
>>> z = F.u_dot_v(g, x, y)

It is up to the user to decide whether to use the message-passing APIs or the GSpMM/GSDDMM operators; both have the same efficiency. Programs written with message-passing APIs look more DGL-style, but in some cases calling GSpMM/GSDDMM operators directly is more concise.

Note that on PyTorch all operators defined in dgl.ops support higher-order gradients. So do the message-passing APIs, because they are built entirely on these operators.