dgl.ops.segment_mm
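dgl.ops.segment_mm(a, b, seglen_a)
Performs matrix multiplication according to segments: the rows of a are split into consecutive segments whose lengths are given by seglen_a, and the i-th segment is multiplied by the i-th matrix b[i].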
Suppose seglen_a == [10, 5, 0, 3]. The operator then performs four matrix multiplications:
a[0:10] @ b[0], a[10:15] @ b[1], a[15:15] @ b[2], a[15:18] @ b[3]
The results are stacked row-wise into the output; the empty segment a[15:15] contributes no rows.
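The semantics above can be sketched as a reference loop. This is a minimal sketch, assuming NumPy arrays in place of framework tensors; `segment_mm_reference` is a hypothetical helper for illustration, not DGL's implementation (which dispatches to dedicated kernels).

```python
import numpy as np


def segment_mm_reference(a, b, seglen_a):
    """Hypothetical reference for the semantics of segment_mm (not DGL code).

    a: (N, D1) array, b: (R, D1, D2) array, seglen_a: (R,) integer array.
    Returns an (N, D2) array whose rows for segment i equal a[segment i] @ b[i].
    """
    # Segment boundaries: rows offsets[i]..offsets[i + 1] belong to segment i.
    offsets = np.concatenate(([0], np.cumsum(seglen_a)))
    out = np.empty((a.shape[0], b.shape[2]), dtype=np.result_type(a, b))
    for i in range(len(seglen_a)):
        start, end = offsets[i], offsets[i + 1]
        # A zero-length segment gives an empty slice and writes no rows.
        out[start:end] = a[start:end] @ b[i]
    return out


# The example from the text: seglen_a == [10, 5, 0, 3], so N == 18.
rng = np.random.default_rng(0)
a = rng.standard_normal((18, 4))    # N = 18, D1 = 4
b = rng.standard_normal((4, 4, 2))  # R = 4, D1 = 4, D2 = 2
seglen_a = np.array([10, 5, 0, 3])
out = segment_mm_reference(a, b, seglen_a)
print(out.shape)  # (18, 2)
```

Writing the result into a preallocated (N, D2) array makes it explicit that each segment fills its own contiguous block of output rows.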
- Parameters
a (Tensor) – The left operand, a 2-D tensor of shape (N, D1).
b (Tensor) – The right operand, a 3-D tensor of shape (R, D1, D2).
seglen_a (Tensor) – An integer tensor of shape (R,). Each element is the length of the corresponding segment of input a. The elements must sum to N.
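The parameter constraints can be checked before calling the operator. This is a minimal sketch, assuming NumPy arrays in place of tensors; `check_segment_mm_args` is a hypothetical helper, and DGL's own error handling may differ. The non-negativity check is an added assumption implied by the lengths being segment sizes.

```python
import numpy as np


def check_segment_mm_args(a, b, seglen_a):
    """Hypothetical validator for the documented segment_mm shape rules."""
    if a.ndim != 2:
        raise ValueError(f"a must be 2-D (N, D1), got shape {a.shape}")
    if b.ndim != 3:
        raise ValueError(f"b must be 3-D (R, D1, D2), got shape {b.shape}")
    n, d1 = a.shape
    r, b_d1, d2 = b.shape
    if b_d1 != d1:
        raise ValueError(f"inner dimensions differ: a has D1={d1}, b has D1={b_d1}")
    if seglen_a.shape != (r,):
        raise ValueError(f"seglen_a must have shape ({r},), got {seglen_a.shape}")
    if not np.issubdtype(seglen_a.dtype, np.integer):
        raise ValueError("seglen_a must be an integer array")
    # Assumption: segment lengths cannot be negative.
    if (seglen_a < 0).any():
        raise ValueError("seglen_a must not contain negative lengths")
    if seglen_a.sum() != n:
        raise ValueError(f"seglen_a sums to {seglen_a.sum()}, expected N={n}")
    return n, d2  # shape of the output


print(check_segment_mm_args(np.zeros((18, 4)), np.zeros((4, 4, 2)), np.array([10, 5, 0, 3])))
# (18, 2)
```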
- Returns
The output dense matrix of shape (N, D2).
- Return type
Tensor
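Because the output has one row per row of a, it can also be described row by row: output row j equals a[j] @ b[s], where s is the index of the segment containing row j. The sketch below is an illustrative NumPy equivalent, assuming arrays in place of tensors; `segment_mm_per_row` is a hypothetical name, and the approach is not how DGL computes the result. It materializes an (N, D1, D2) gather of b, so it trades memory for simplicity.

```python
import numpy as np


def segment_mm_per_row(a, b, seglen_a):
    """Per-row NumPy formulation of segment_mm semantics (illustrative only)."""
    # seg_ids[j] is the segment (and hence the matrix in b) that row j belongs to.
    seg_ids = np.repeat(np.arange(len(seglen_a)), seglen_a)
    # Gather one (D1, D2) matrix per row, then multiply each row by its matrix.
    return np.einsum("nd,nde->ne", a, b[seg_ids])


a = np.arange(6, dtype=float).reshape(3, 2)  # N = 3, D1 = 2
b = np.stack([np.eye(2), 2 * np.eye(2)])     # R = 2, D1 = D2 = 2
seglen_a = np.array([1, 2])
print(segment_mm_per_row(a, b, seglen_a))
# rows: [0, 1], [4, 6], [8, 10]
```

Here row 0 falls in segment 0 (multiplied by the identity), while rows 1 and 2 fall in segment 1 (multiplied by twice the identity).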