dgl.ops.segment_mm

dgl.ops.segment_mm(a, b, seglen_a)

Performs segmented matrix multiplication: consecutive row segments of a are each multiplied by the corresponding matrix in b.

For example, if seglen_a == [10, 5, 0, 3], the operator performs four matrix multiplications:

a[0:10] @ b[0], a[10:15] @ b[1],
a[15:15] @ b[2], a[15:18] @ b[3]
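To make the semantics concrete, here is a minimal NumPy sketch of what the operator computes, assuming the segments are laid out contiguously in the row order of a, as the example above describes. `segment_mm_reference` is a hypothetical helper written for illustration, not part of DGL; it loops over segments in Python and serves only as a readable specification, not as a substitute for the library kernel.

```python
import numpy as np

def segment_mm_reference(a, b, seglen_a):
    """Hypothetical, slow reference for segmented matrix multiplication.

    Rows a[start:start + seglen_a[i]] are multiplied by b[i]; results are
    written to the same rows of the output, preserving the row order of a.
    """
    n, d1 = a.shape
    r, d1_b, d2 = b.shape
    assert d1 == d1_b and len(seglen_a) == r and sum(seglen_a) == n
    out = np.empty((n, d2), dtype=np.result_type(a, b))
    start = 0
    for i, length in enumerate(seglen_a):
        end = start + int(length)
        # A zero-length segment gives a (0, D1) @ (D1, D2) -> (0, D2) product,
        # so nothing is written for it.
        out[start:end] = a[start:end] @ b[i]
        start = end
    return out

rng = np.random.default_rng(0)
seglen_a = np.array([10, 5, 0, 3])
a = rng.standard_normal((18, 4))     # N = 18 = 10 + 5 + 0 + 3, D1 = 4
b = rng.standard_normal((4, 4, 2))   # R = 4, D1 = 4, D2 = 2
out = segment_mm_reference(a, b, seglen_a)
print(out.shape)  # (18, 2)
```

The third segment (a[15:15]) is empty, so b[2] is never used and contributes no rows to the output.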
Parameters
  • a (Tensor) – The left operand, a 2-D tensor of shape (N, D1).

  • b (Tensor) – The right operand, a 3-D tensor of shape (R, D1, D2).

  • seglen_a (Tensor) – An integer tensor of shape (R,). Each element is the length of the corresponding segment of input a. The elements must sum to N.
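The parameter constraints above amount to a small shape contract. The following pure-Python sketch checks it and returns the expected output shape. `segment_mm_output_shape` is a hypothetical helper, not a DGL function, and DGL's own error checking may differ; the non-negativity check is an assumption that follows from the entries being lengths.

```python
def segment_mm_output_shape(a_shape, b_shape, seglen_a):
    """Hypothetical check of the documented segment_mm shape contract.

    Returns the output shape (N, D2) or raises ValueError.
    """
    if len(a_shape) != 2:
        raise ValueError(f"a must be 2-D (N, D1), got {a_shape}")
    if len(b_shape) != 3:
        raise ValueError(f"b must be 3-D (R, D1, D2), got {b_shape}")
    n, d1 = a_shape
    r, d1_b, d2 = b_shape
    if d1 != d1_b:
        raise ValueError(f"inner dimensions differ: a has D1={d1}, b has D1={d1_b}")
    if len(seglen_a) != r:
        raise ValueError(f"seglen_a has {len(seglen_a)} entries, expected R={r}")
    # Assumption: segment lengths are non-negative counts of rows.
    if any(length < 0 for length in seglen_a):
        raise ValueError("segment lengths must be non-negative")
    if sum(seglen_a) != n:
        raise ValueError(f"segment lengths sum to {sum(seglen_a)}, expected N={n}")
    return (n, d2)

print(segment_mm_output_shape((18, 4), (4, 4, 2), [10, 5, 0, 3]))  # (18, 2)
```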

Returns

The output dense matrix of shape (N, D2). Row i of the output is computed from row i of a.

Return type

Tensor