dgl.ops.gather_mm

dgl.ops.gather_mm(a, b, *, idx_b)

Gather data according to the given indices and perform matrix multiplication.

Let the result tensor be c. The operator performs the following computation:

c[i] = a[i] @ b[idx_b[i]], where len(c) == len(idx_b)
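The formula above can be illustrated with a small pure-Python reference sketch. The helper names `gather_mm_reference` and `matmul_row` are hypothetical and exist only for illustration; this is not DGL's implementation, which operates on framework tensors. Nested lists stand in for tensors so the semantics (one gathered matrix per row of a) are easy to follow.

```python
from typing import List

Matrix = List[List[float]]


def matmul_row(row: List[float], mat: Matrix) -> List[float]:
    """Multiply a length-D1 row vector by a D1 x D2 matrix."""
    d2 = len(mat[0])
    return [sum(row[k] * mat[k][j] for k in range(len(row))) for j in range(d2)]


def gather_mm_reference(a: Matrix, b: List[Matrix], idx_b: List[int]) -> Matrix:
    """Hypothetical reference for c[i] = a[i] @ b[idx_b[i]].

    a: N rows of length D1; b: R matrices of shape D1 x D2; idx_b: N ints in [0, R).
    """
    if len(idx_b) != len(a):
        raise ValueError("idx_b must have exactly one entry per row of a")
    c = []
    for i, j in enumerate(idx_b):
        if not 0 <= j < len(b):
            raise IndexError(f"idx_b[{i}] = {j} is out of range for R = {len(b)}")
        # Gather the j-th matrix of b and multiply row i of a by it.
        c.append(matmul_row(a[i], b[j]))
    return c


a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # N = 3, D1 = 2
b = [
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],  # b[0]: D1 = 2 x D2 = 3
    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],  # b[1]
]
idx_b = [0, 1, 0]  # rows 0 and 2 share b[0]
c = gather_mm_reference(a, b, idx_b)
print(c)  # [[1.0, 2.0, 0.0], [7.0, 7.0, 7.0], [5.0, 6.0, 0.0]]
```

With DGL and a tensor backend installed, the corresponding call on tensors would be dgl.ops.gather_mm(a, b, idx_b=idx_b); idx_b must be passed by keyword because it follows the `*` in the signature.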

Parameters
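  • a (Tensor) – A 2-D tensor of shape (N, D1).

  • b (Tensor) – A 3-D tensor of shape (R, D1, D2), i.e. a stack of R matrices of shape (D1, D2).

  • idx_b (Tensor) – A 1-D integer tensor of shape (N,). idx_b[i] selects which of the R matrices in b is multiplied with row a[i], so each entry must lie in [0, R).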
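The shape contract in the parameter list can be checked ahead of time on plain shape tuples. The sketch below uses a hypothetical helper, `gather_mm_output_shape`, that does not call DGL. It verifies the ranks, the shared inner dimension D1, and that idx_b has N entries, and returns the output shape (N, D2) implied by the formula. It cannot check that the values of idx_b lie in [0, R), since that depends on the data rather than the shapes.

```python
from typing import Tuple


def gather_mm_output_shape(
    a_shape: Tuple[int, ...],
    b_shape: Tuple[int, ...],
    idx_shape: Tuple[int, ...],
) -> Tuple[int, int]:
    """Hypothetical check of the documented shapes; returns (N, D2)."""
    if len(a_shape) != 2:
        raise ValueError(f"a must be 2-D (N, D1), got {a_shape}")
    if len(b_shape) != 3:
        raise ValueError(f"b must be 3-D (R, D1, D2), got {b_shape}")
    if len(idx_shape) != 1:
        raise ValueError(f"idx_b must be 1-D (N,), got {idx_shape}")
    n, d1 = a_shape
    _, b_d1, d2 = b_shape  # R is not constrained by the other shapes
    if b_d1 != d1:
        raise ValueError(f"inner dimensions differ: a has D1={d1}, b has D1={b_d1}")
    if idx_shape[0] != n:
        raise ValueError(f"idx_b has {idx_shape[0]} entries, expected N={n}")
    return (n, d2)


print(gather_mm_output_shape((5, 16), (3, 16, 8), (5,)))  # (5, 8)
```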

Returns

The output dense matrix of shape (N, D2)

Return type

Tensor