dgl.ops.segment_mm

dgl.ops.segment_mm(a, b, seglen_a)

Performs segment-wise matrix multiplication: the rows of a are split into consecutive segments, and each segment is multiplied by its own matrix from b.

For example, if seglen_a == [10, 5, 0, 3], the operator performs four matrix multiplications (the third segment is empty, so it contributes no output rows):

a[0:10] @ b[0], a[10:15] @ b[1],
a[15:15] @ b[2], a[15:18] @ b[3]
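To make the segment semantics concrete, here is a minimal reference sketch in NumPy. It is not DGL's implementation: the function name segment_mm_reference is hypothetical, and it assumes NumPy arrays stand in for framework tensors. It simply loops over the segments in order, exactly as in the listing above.

```python
import numpy as np

def segment_mm_reference(a, b, seglen_a):
    """Hypothetical reference: multiply each contiguous row segment of `a`
    (shape (N, D1)) by the matching matrix b[i] (b has shape (R, D1, D2))."""
    n, d1 = a.shape
    r, d1_b, d2 = b.shape
    assert d1 == d1_b, "inner dimensions of a and b must match"
    assert len(seglen_a) == r, "need one segment length per matrix in b"
    assert int(np.sum(seglen_a)) == n, "segment lengths must sum to N"

    out = np.empty((n, d2), dtype=np.result_type(a, b))
    start = 0
    for i, length in enumerate(seglen_a):
        end = start + int(length)
        # A zero-length segment yields a (0, D2) product and writes no rows.
        out[start:end] = a[start:end] @ b[i]
        start = end
    return out

# The example from the text: N = 18 rows split as [10, 5, 0, 3], R = 4.
rng = np.random.default_rng(0)
a = rng.standard_normal((18, 4))        # (N, D1)
b = rng.standard_normal((4, 4, 3))      # (R, D1, D2)
seglen_a = np.array([10, 5, 0, 3])
out = segment_mm_reference(a, b, seglen_a)
print(out.shape)  # (18, 3)
```

Because the segments are contiguous and ordered, output row j always comes from input row j; only the weight matrix changes at segment boundaries.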

Parameters:
  • a (Tensor) – The left operand, a 2-D tensor of shape (N, D1).

  • b (Tensor) – The right operand, a 3-D tensor of shape (R, D1, D2).

  • seglen_a (Tensor) – An integer tensor of shape (R,). Element i is the number of consecutive rows of a that form segment i and are multiplied by b[i]. The elements must sum to N.
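Because each segment must be a contiguous block of rows, data whose rows carry an arbitrary per-row segment id (for example, a relation type) has to be grouped before it can be passed as a. The sketch below shows one way to do that with NumPy; the names seg_ids, num_segments, and a_sorted are illustrative assumptions, not part of the DGL API.

```python
import numpy as np

# Hypothetical input: a segment id in [0, R) for each row of `a`.
seg_ids = np.array([2, 0, 3, 0, 2, 3])
num_segments = 4  # R, the number of matrices in b

# Stable sort so rows with the same id become contiguous, in id order.
order = np.argsort(seg_ids, kind="stable")
sorted_ids = seg_ids[order]

# Rows per segment; minlength keeps trailing/empty segments as zeros.
seglen_a = np.bincount(seg_ids, minlength=num_segments)
assert seglen_a.sum() == len(seg_ids)  # the sum must equal N

a = np.arange(12, dtype=float).reshape(6, 2)  # toy (N, D1) data
a_sorted = a[order]  # rows grouped by segment, suitable as the `a` operand
print(seglen_a)  # [2 0 2 2]
```

Segment 1 has no rows, so its length is 0, which the operator accepts. To map results back to the original row order, scatter them with the same permutation (out_original[order] = out_sorted).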

Returns:

The output dense matrix of shape (N, D2)

Return type:

Tensor