dgl.ops.gather_mm

dgl.ops.gather_mm(a, b, *, idx_b)

Gather data according to the given indices and perform matrix multiplication.
Let the result tensor be c. The operator computes c[i] = a[i] @ b[idx_b[i]] for every i, where len(c) == len(idx_b). In other words, row i of a is multiplied by the matrix of b selected by idx_b[i].
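To make the indexing concrete, here is a minimal NumPy sketch of the same computation. It only illustrates the semantics stated above and is not DGL's kernel. The names `gather_mm_reference` and `gather_mm_vectorized` are hypothetical, and NumPy arrays stand in for backend tensors. The loop follows the formula literally. The vectorized form first gathers one (D1, D2) matrix per row with `b[idx_b]` and then contracts D1 row by row.

```python
import numpy as np


def gather_mm_reference(a: np.ndarray, b: np.ndarray, idx_b: np.ndarray) -> np.ndarray:
    """Hypothetical loop-based reference for c[i] = a[i] @ b[idx_b[i]] (not DGL's kernel)."""
    n, d1 = a.shape
    r, d1_b, d2 = b.shape
    assert d1 == d1_b and idx_b.shape == (n,)
    # Every index must select one of the R matrices stored in b.
    assert idx_b.min() >= 0 and idx_b.max() < r
    c = np.empty((n, d2), dtype=np.result_type(a, b))
    for i in range(n):
        # Row i of a (shape (D1,)) times the selected matrix (shape (D1, D2)) gives shape (D2,).
        c[i] = a[i] @ b[idx_b[i]]
    return c


def gather_mm_vectorized(a: np.ndarray, b: np.ndarray, idx_b: np.ndarray) -> np.ndarray:
    """Hypothetical vectorized equivalent of the loop above."""
    # b[idx_b] gathers one (D1, D2) matrix per row, giving shape (N, D1, D2);
    # einsum then contracts D1 for each row n, giving shape (N, D2).
    return np.einsum("nd,nde->ne", a, b[idx_b])


rng = np.random.default_rng(0)
N, D1, D2, R = 5, 4, 3, 2
a = rng.standard_normal((N, D1))
b = rng.standard_normal((R, D1, D2))
idx_b = np.array([0, 1, 1, 0, 1])

c = gather_mm_reference(a, b, idx_b)
print(c.shape)  # (5, 3)
print(np.allclose(c, gather_mm_vectorized(a, b, idx_b)))  # True
```

Gathering `b[idx_b]` materializes N copies of (D1, D2) matrices, so the sketch trades memory for simplicity. Avoiding that copy is one reason a fused gather-and-multiply operator is useful.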
- Parameters
a (Tensor) – A 2-D tensor of shape (N, D1).
b (Tensor) – A 3-D tensor of shape (R, D1, D2), holding R matrices of shape (D1, D2).
idx_b (Tensor) – A keyword-only 1-D integer tensor of shape (N,). Each entry must lie in [0, R) and selects the matrix of b used for the corresponding row of a.
- Returns
The output dense matrix of shape (N, D2).
- Return type
Tensor
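The shape contract from Parameters and Returns can be summarized as a small checker. This is a hypothetical helper, `gather_mm_output_shape`, written for illustration. It is not part of DGL and works on plain shape tuples. It validates only shapes. The range requirement on the values of idx_b cannot be checked from shapes alone.

```python
from typing import Tuple


def gather_mm_output_shape(
    a_shape: Tuple[int, ...],
    b_shape: Tuple[int, ...],
    idx_b_shape: Tuple[int, ...],
) -> Tuple[int, int]:
    """Hypothetical helper: validate gather_mm input shapes and return the output shape."""
    if len(a_shape) != 2:
        raise ValueError(f"a must be 2-D (N, D1), got {a_shape}")
    if len(b_shape) != 3:
        raise ValueError(f"b must be 3-D (R, D1, D2), got {b_shape}")
    n, d1 = a_shape
    _, d1_b, d2 = b_shape  # R is free: any number of candidate matrices is allowed
    if d1 != d1_b:
        raise ValueError(f"inner dimensions differ: a has D1={d1}, b has D1={d1_b}")
    if tuple(idx_b_shape) != (n,):
        raise ValueError(f"idx_b must have shape ({n},), got {idx_b_shape}")
    # One output row per row of a; the row width comes from the last dimension of b.
    return (n, d2)


print(gather_mm_output_shape((1000, 16), (4, 16, 32), (1000,)))  # (1000, 32)
```

Here N = 1000 rows of width D1 = 16 are each multiplied by one of R = 4 matrices of shape (16, 32), so the result has shape (1000, 32) regardless of R.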