dgl.sparse.matmul

dgl.sparse.matmul(A: Union[torch.Tensor, dgl.sparse.sparse_matrix.SparseMatrix], B: Union[torch.Tensor, dgl.sparse.sparse_matrix.SparseMatrix]) → Union[torch.Tensor, dgl.sparse.sparse_matrix.SparseMatrix]

Multiplies two dense/sparse matrices, equivalent to A @ B.

This function does not support the case where A is a torch.Tensor and B is a SparseMatrix.
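When the dense operand has to come first, one possible workaround is the transpose identity X A = (Aᵀ Xᵀ)ᵀ, which moves the sparse matrix to the left. In DGL this would presumably be written as `dglsp.matmul(A.T, X.T).T`, assuming `SparseMatrix` exposes a `.T` transpose property. The sketch below checks the identity with PyTorch's own sparse COO tensors rather than `SparseMatrix`, so it illustrates the algebra, not DGL's API. Transposing a COO matrix amounts to swapping its row and column index rows.

```python
import torch

# Illustrative inputs: a 2x2 sparse matrix in COO form and a dense 3x2 matrix X.
indices = torch.tensor([[0, 1, 1], [1, 0, 1]])  # first row: rows, second row: columns
val = torch.tensor([1.0, 2.0, 3.0])
X = torch.randn(3, 2)

# A^T in COO form: swap the row and column index rows.
A_t = torch.sparse_coo_tensor(indices.flip(0), val, (2, 2))

# Dense @ sparse rewritten with the sparse operand on the left:
# X @ A == (A^T @ X^T)^T
via_transpose = torch.sparse.mm(A_t, X.t()).t()

# Reference result computed fully densely.
A_dense = torch.sparse_coo_tensor(indices, val, (2, 2)).to_dense()
reference = X @ A_dense

print(torch.allclose(via_transpose, reference, atol=1e-6))  # True
```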

  • If both matrices are torch.Tensor, it calls torch.matmul(). The result is a dense matrix.

  • If both matrices are sparse, it calls dgl.sparse.spspmm(). The result is a sparse matrix.

  • If A is sparse while B is dense, it calls dgl.sparse.spmm(). The result is a dense matrix.

  • The operator supports batched sparse-dense matrix multiplication. In this case, the sparse matrix A should have shape (L, M), where each non-zero value is a vector of length K (that is, its values tensor has shape (nnz, K)). The dense matrix B should have shape (M, N, K). The output is a dense tensor of shape (L, N, K).

  • Sparse-sparse matrix multiplication does not support batched computation.
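To make the batched shapes concrete, the sketch below spells out in plain PyTorch what a batched sparse-dense product computes: value channel k of A multiplies the matching slice B[:, :, k]. It assumes A is given in COO form with values of shape (nnz, K). The function name `batched_spmm_reference` is hypothetical; this is a readable reference for the shape rule above, not DGL's kernel.

```python
import torch

def batched_spmm_reference(indices, val, shape, B):
    """Hypothetical reference for a batched sparse-dense product.

    indices: (2, nnz) COO row/column indices of A
    val:     (nnz, K) non-zero values of A, one column per batch entry
    shape:   (L, M), the shape of A
    B:       (M, N, K) dense tensor
    returns: (L, N, K) dense tensor
    """
    L, M = shape
    assert B.shape[0] == M, "inner dimensions must match"
    row, col = indices
    # Each non-zero A[r, c] (a length-K vector) scales row c of B: (nnz, N, K).
    contrib = val.unsqueeze(1) * B[col]
    out = torch.zeros(L, B.shape[1], B.shape[2], dtype=B.dtype)
    # Accumulate every contribution into its output row r.
    out.index_add_(0, row, contrib)
    return out

indices = torch.tensor([[0, 1, 1], [1, 0, 2]])
val = torch.randn(3, 2)      # nnz = 3, K = 2
B = torch.randn(3, 4, 2)     # M = 3, N = 4, K = 2
out = batched_spmm_reference(indices, val, (2, 3), B)

# Cross-check: densify each batch entry of A and use an ordinary matmul.
A_dense = torch.zeros(2, 3, 2)
A_dense[indices[0], indices[1]] = val
expected = torch.stack([A_dense[:, :, k] @ B[:, :, k] for k in range(2)], dim=-1)

print(out.shape)                                  # torch.Size([2, 4, 2])
print(torch.allclose(out, expected, atol=1e-6))   # True
```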

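Parameters

  • A (torch.Tensor or SparseMatrix): The first matrix.

  • B (torch.Tensor or SparseMatrix): The second matrix.

Returns

The result matrix.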

Return type

torch.Tensor or SparseMatrix

Examples

Multiplies a diagonal matrix with a dense matrix.

>>> import torch
>>> import dgl.sparse as dglsp
>>> val = torch.randn(3)
>>> A = dglsp.diag(val)
>>> B = torch.randn(3, 2)
>>> result = dglsp.matmul(A, B)
>>> type(result)
<class 'torch.Tensor'>
>>> result.shape
torch.Size([3, 2])
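Multiplying by a diagonal matrix on the left scales row i of B by val[i], so the result of the example above should equal a simple broadcast product. A plain-PyTorch sketch of that equivalence, using `torch.diag` as a dense stand-in for the diagonal sparse matrix:

```python
import torch

val = torch.randn(3)
B = torch.randn(3, 2)

# diag(val) @ B, computed with a dense diagonal matrix.
dense_product = torch.diag(val) @ B

# Same result: scale row i of B by val[i], broadcasting over the columns.
row_scaled = val.unsqueeze(1) * B

print(torch.allclose(dense_product, row_scaled))  # True
```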

Multiplies a sparse matrix with a dense matrix.

>>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
>>> val = torch.randn(indices.shape[1])
>>> A = dglsp.spmatrix(indices, val)
>>> X = torch.randn(2, 3)
>>> result = dglsp.matmul(A, X)
>>> type(result)
<class 'torch.Tensor'>
>>> result.shape
torch.Size([2, 3])
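The `indices` tensor in this example is in COO form: the first row holds row indices and the second holds column indices, so A is the 2 x 2 matrix with non-zeros at (0, 1), (1, 0) and (1, 1) (the example relies on the shape being inferred from the indices). A minimal PyTorch sketch that densifies A by hand shows the product the example computes; it uses PyTorch's sparse COO tensor, not `SparseMatrix`.

```python
import torch

indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
val = torch.randn(indices.shape[1])
X = torch.randn(2, 3)

# Scatter the COO values into a dense 2x2 matrix.
A_dense = torch.zeros(2, 2)
A_dense[indices[0], indices[1]] = val

# Same product with PyTorch's own sparse COO tensor (sparse @ dense -> dense).
A_sparse = torch.sparse_coo_tensor(indices, val, (2, 2))
sparse_result = torch.sparse.mm(A_sparse, X)

print(sparse_result.shape)                                     # torch.Size([2, 3])
print(torch.allclose(sparse_result, A_dense @ X, atol=1e-6))   # True
```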

Multiplies a sparse matrix with a sparse matrix.

>>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
>>> val1 = torch.ones(indices1.shape[1])
>>> A = dglsp.spmatrix(indices1, val1)
>>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
>>> val2 = torch.ones(indices2.shape[1])
>>> B = dglsp.spmatrix(indices2, val2)
>>> result = dglsp.matmul(A, B)
>>> type(result)
<class 'dgl.sparse.sparse_matrix.SparseMatrix'>
>>> result.shape
(2, 3)
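Because both operands in the sparse-sparse example have all-ones values, the product can be checked by hand: A = [[0, 1], [1, 1]] and B = [[1, 0, 0], [0, 1, 1]], so A @ B = [[0, 1, 1], [1, 1, 1]]. A dense PyTorch sketch of that check follows; it does not call DGL, and the helper `coo_to_dense` is hypothetical.

```python
import torch

def coo_to_dense(indices, val, shape):
    # Hypothetical helper: scatter COO values into a dense matrix
    # (assumes no duplicate index pairs).
    dense = torch.zeros(shape)
    dense[indices[0], indices[1]] = val
    return dense

indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
A = coo_to_dense(indices1, torch.ones(3), (2, 2))
B = coo_to_dense(indices2, torch.ones(3), (2, 3))

product = A @ B
print(product.tolist())  # [[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
```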