dgl.ops.gather_mm
- dgl.ops.gather_mm(a, b, *, idx_b)
Gather data according to the given indices and perform matrix multiplication.
Let the result tensor be c. The operator computes, for each row i:

c[i] = a[i] @ b[idx_b[i]], where len(c) == len(idx_b)

That is, row i of a is multiplied by the (D1, D2) matrix of b selected by idx_b[i].
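- Parameters:
a (Tensor): A 2-D tensor of shape (N, D1).
b (Tensor): A 3-D tensor of shape (R, D1, D2).
idx_b (Tensor, optional): A 1-D integer tensor of shape (N,), passed as a keyword argument. Entry idx_b[i] selects which of the R matrices in b is applied to row a[i].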
- Returns:
The output dense matrix of shape (N, D2).
- Return type:
Tensor
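The computation c[i] = a[i] @ b[idx_b[i]] can be illustrated with a small NumPy reference sketch. This is not DGL's implementation: `gather_mm_reference` is a hypothetical helper that only mirrors the documented semantics on NumPy arrays, assuming every entry of `idx_b` is a valid index in [0, R).

```python
import numpy as np


def gather_mm_reference(a, b, idx_b):
    """Hypothetical reference for the documented semantics:
    c[i] = a[i] @ b[idx_b[i]].

    a: (N, D1), b: (R, D1, D2), idx_b: (N,) integers -> returns (N, D2).
    """
    a = np.asarray(a)
    b = np.asarray(b)
    idx_b = np.asarray(idx_b)
    # Shape checks matching the parameter descriptions.
    assert a.ndim == 2 and b.ndim == 3 and idx_b.ndim == 1
    assert a.shape[0] == idx_b.shape[0]  # one index per row of a
    assert a.shape[1] == b.shape[1]      # inner dimension D1 must agree
    # Gather one (D1, D2) matrix per row of a: shape (N, D1, D2).
    gathered = b[idx_b]
    # Row-wise vector-matrix product: (N, D1) x (N, D1, D2) -> (N, D2).
    return np.einsum("nd,nde->ne", a, gathered)


# Usage: three rows, two candidate matrices (identity and 2 * identity).
a = np.arange(6).reshape(3, 2)             # rows [0, 1], [2, 3], [4, 5]
b = np.stack([np.eye(2), 2 * np.eye(2)])   # shape (2, 2, 2)
idx_b = np.array([0, 1, 0])
c = gather_mm_reference(a, b, idx_b)       # c == [[0, 1], [4, 6], [4, 5]]
```

With DGL installed, the corresponding call on framework tensors would be `dgl.ops.gather_mm(a, b, idx_b=idx_b)`; because of the `*` in the signature, `idx_b` has to be passed by keyword.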