Source code for dgl.sparse.matmul

"""Matmul ops for SparseMatrix"""
# pylint: disable=invalid-name
from typing import Union

import torch

from .diag_matrix import diag, DiagMatrix

from .sparse_matrix import SparseMatrix, val_like

__all__ = ["spmm", "bspmm", "spspmm", "matmul"]


def spmm(
    A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor
) -> torch.Tensor:
    """Multiplies a sparse matrix by a dense matrix, equivalent to ``A @ X``.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Sparse matrix of shape ``(L, M)`` with scalar values
    X : torch.Tensor
        Dense matrix of shape ``(M, N)`` or ``(M)``

    Returns
    -------
    torch.Tensor
        The dense matrix of shape ``(L, N)`` or ``(L)``

    Raises
    ------
    AssertionError
        If ``A`` is neither a SparseMatrix nor a DiagMatrix, or if ``X`` is
        not a torch.Tensor.

    Examples
    --------

    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val = torch.randn(indices.shape[1])
    >>> A = dglsp.spmatrix(indices, val)
    >>> X = torch.randn(2, 3)
    >>> result = dglsp.spmm(A, X)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([2, 3])
    """
    assert isinstance(A, (SparseMatrix, DiagMatrix)), (
        "Expect arg1 to be a SparseMatrix or DiagMatrix object, "
        f"got {type(A)}."
    )
    assert isinstance(X, torch.Tensor), (
        f"Expect arg2 to be a torch.Tensor, got {type(X)}."
    )

    # The op below is handed ``c_sparse_matrix``, so anything that is not
    # already a SparseMatrix (i.e. a DiagMatrix) is converted first.
    sparse_lhs = A if isinstance(A, SparseMatrix) else A.to_sparse()
    return torch.ops.dgl_sparse.spmm(sparse_lhs.c_sparse_matrix, X)
def bspmm(
    A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor
) -> torch.Tensor:
    """Multiplies a sparse matrix by a dense matrix by batches, equivalent to
    ``A @ X``.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Sparse matrix of shape ``(L, M)`` with vector values of length ``K``
    X : torch.Tensor
        Dense matrix of shape ``(M, N, K)``

    Returns
    -------
    torch.Tensor
        Dense matrix of shape ``(L, N, K)``

    Raises
    ------
    AssertionError
        If ``A`` is neither a SparseMatrix nor a DiagMatrix, or if ``X`` is
        not a torch.Tensor.

    Examples
    --------

    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 2]])
    >>> val = torch.randn(indices.shape[1], 2)
    >>> A = dglsp.spmatrix(indices, val, shape=(3, 3))
    >>> X = torch.randn(3, 3, 2)
    >>> result = dglsp.bspmm(A, X)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([3, 3, 2])
    """
    assert isinstance(A, (SparseMatrix, DiagMatrix)), (
        "Expect arg1 to be a SparseMatrix or DiagMatrix object, "
        f"got {type(A)}."
    )
    assert isinstance(X, torch.Tensor), (
        f"Expect arg2 to be a torch.Tensor, got {type(X)}."
    )

    # No batch-specific handling happens here: the call is forwarded to
    # ``spmm`` unchanged.
    batched_product = spmm(A, X)
    return batched_product
def _diag_diag_mm(A: DiagMatrix, B: DiagMatrix) -> DiagMatrix:
    """Internal function for multiplying a diagonal matrix by a diagonal
    matrix.

    Parameters
    ----------
    A : DiagMatrix
        Diagonal matrix of shape ``(L, M)``
    B : DiagMatrix
        Diagonal matrix of shape ``(M, N)``

    Returns
    -------
    DiagMatrix
        Diagonal matrix of shape ``(L, N)``

    Raises
    ------
    AssertionError
        If the second dimension of ``A`` differs from the first dimension of
        ``B``.
    """
    L, M = A.shape
    B_rows, N = B.shape
    assert M == B_rows, (
        "The second dimension of the first DiagMatrix should be equal to the "
        "first dimension of the second DiagMatrix in "
        f"matmul(DiagMatrix, DiagMatrix), but the shapes are {A.shape} and "
        f"{B.shape} respectively."
    )

    # Only the leading min(L, M, N) diagonal positions have a non-zero factor
    # in both operands; the rest of the result's diagonal stays zero.
    common_diag_len = min(L, M, N)
    new_diag_len = min(L, N)
    product = A.val[:common_diag_len] * B.val[:common_diag_len]

    # Allocate with the product's dtype and device so that values are not
    # silently cast to the default dtype or bounced through the CPU.
    diag_val = torch.zeros(
        new_diag_len, dtype=product.dtype, device=product.device
    )
    diag_val[:common_diag_len] = product
    return diag(diag_val.to(A.device), (L, N))


def _sparse_diag_mm(A: SparseMatrix, D: DiagMatrix) -> SparseMatrix:
    """Internal function for multiplying a sparse matrix by a diagonal matrix.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix of shape ``(L, M)``
    D : DiagMatrix
        Diagonal matrix of shape ``(M, N)``; must be square

    Returns
    -------
    SparseMatrix
        Sparse matrix of shape ``(L, N)``

    Raises
    ------
    AssertionError
        If the inner dimensions disagree or ``D`` is not square.
    """
    assert A.shape[1] == D.shape[0], (
        "The second dimension of SparseMatrix should be equal to the first "
        "dimension of DiagMatrix in matmul(SparseMatrix, DiagMatrix), but the "
        f"shapes of SparseMatrix and DiagMatrix are {A.shape} and {D.shape} "
        "respectively."
    )
    assert D.shape[0] == D.shape[1], (
        "The DiagMatrix should be a square in matmul(SparseMatrix, "
        f"DiagMatrix) but got {D.shape}."
    )
    # Right-multiplying by a diagonal matrix scales column j by D[j, j], so
    # every stored value is scaled by the diagonal entry at its column index.
    # NOTE(review): val_like is assumed to keep A's sparsity pattern and swap
    # in the new values -- confirm against sparse_matrix.val_like.
    return val_like(A, D.val[A.col] * A.val)


def _diag_sparse_mm(D: DiagMatrix, A: SparseMatrix) -> SparseMatrix:
    """Internal function for multiplying a diagonal matrix by a sparse matrix.

    Parameters
    ----------
    D : DiagMatrix
        Diagonal matrix of shape ``(L, M)``; must be square
    A : SparseMatrix
        Sparse matrix of shape ``(M, N)``

    Returns
    -------
    SparseMatrix
        Sparse matrix of shape ``(L, N)``

    Raises
    ------
    AssertionError
        If the inner dimensions disagree or ``D`` is not square.
    """
    assert D.shape[1] == A.shape[0], (
        "The second dimension of DiagMatrix should be equal to the first "
        "dimension of SparseMatrix in matmul(DiagMatrix, SparseMatrix), but "
        f"the shapes of DiagMatrix and SparseMatrix are {D.shape} and "
        f"{A.shape} respectively."
    )
    assert D.shape[0] == D.shape[1], (
        "The DiagMatrix should be a square in matmul(DiagMatrix, "
        f"SparseMatrix) but got {D.shape}."
    )
    # Left-multiplying by a diagonal matrix scales row i by D[i, i], so every
    # stored value is scaled by the diagonal entry at its row index.
    # NOTE(review): val_like is assumed to keep A's sparsity pattern and swap
    # in the new values -- confirm against sparse_matrix.val_like.
    return val_like(A, D.val[A.row] * A.val)
def spspmm(
    A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
    """Multiplies a sparse matrix by a sparse matrix, equivalent to ``A @ B``.

    The non-zero values of the two sparse matrices must be 1D.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Sparse matrix of shape ``(L, M)``
    B : SparseMatrix or DiagMatrix
        Sparse matrix of shape ``(M, N)``

    Returns
    -------
    SparseMatrix or DiagMatrix
        Matrix of shape ``(L, N)``. It is a DiagMatrix object if both matrices
        are DiagMatrix objects, otherwise a SparseMatrix object.

    Raises
    ------
    AssertionError
        If either argument is neither a SparseMatrix nor a DiagMatrix.

    Examples
    --------

    >>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val1 = torch.ones(indices1.shape[1])
    >>> A = dglsp.spmatrix(indices1, val1)
    >>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
    >>> val2 = torch.ones(indices2.shape[1])
    >>> B = dglsp.spmatrix(indices2, val2)
    >>> dglsp.spspmm(A, B)
    SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
                                 [1, 2, 0, 1, 2]]),
                 values=tensor([1., 1., 1., 1., 1.]),
                 shape=(2, 3), nnz=5)
    """
    assert isinstance(A, (SparseMatrix, DiagMatrix)), (
        "Expect arg1 to be a SparseMatrix or DiagMatrix object, "
        f"got {type(A)}."
    )
    assert isinstance(B, (SparseMatrix, DiagMatrix)), (
        "Expect arg2 to be a SparseMatrix or DiagMatrix object, "
        f"got {type(B)}."
    )

    # Diagonal operands are handled by the dedicated helpers; the check order
    # matters because the diag-diag case must win over the mixed cases.
    if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
        return _diag_diag_mm(A, B)
    if isinstance(A, DiagMatrix):
        return _diag_sparse_mm(A, B)
    if isinstance(B, DiagMatrix):
        return _sparse_diag_mm(A, B)

    # Both operands are SparseMatrix: wrap the result of the registered op.
    return SparseMatrix(
        torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix)
    )
def matmul(
    A: Union[torch.Tensor, SparseMatrix, DiagMatrix],
    B: Union[torch.Tensor, SparseMatrix, DiagMatrix],
) -> Union[torch.Tensor, SparseMatrix, DiagMatrix]:
    """Multiplies two dense/sparse/diagonal matrices, equivalent to ``A @ B``.

    The supported combinations are shown as follows.

    +--------------+--------+------------+--------------+
    |    A \\ B     | Tensor | DiagMatrix | SparseMatrix |
    +--------------+--------+------------+--------------+
    |    Tensor    |   ✅   |     🚫     |      🚫      |
    +--------------+--------+------------+--------------+
    | SparseMatrix |   ✅   |     ✅     |      ✅      |
    +--------------+--------+------------+--------------+
    |  DiagMatrix  |   ✅   |     ✅     |      ✅      |
    +--------------+--------+------------+--------------+

    * If both matrices are torch.Tensor, it calls
      :func:`torch.matmul()`. The result is a dense matrix.

    * If both matrices are sparse or diagonal, it calls
      :func:`dgl.sparse.spspmm`. The result is a sparse matrix.

    * If :attr:`A` is sparse or diagonal while :attr:`B` is dense, it
      calls :func:`dgl.sparse.spmm`. The result is a dense matrix.

    * The operator supports batched sparse-dense matrix multiplication. In
      this case, the sparse or diagonal matrix :attr:`A` should have shape
      ``(L, M)``, where the non-zero values have a batch dimension ``K``.
      The dense matrix :attr:`B` should have shape ``(M, N, K)``. The output
      is a dense matrix of shape ``(L, N, K)``.

    * Sparse-sparse matrix multiplication does not support batched
      computation.

    Parameters
    ----------
    A : torch.Tensor, SparseMatrix or DiagMatrix
        The first matrix.
    B : torch.Tensor, SparseMatrix, or DiagMatrix
        The second matrix.

    Returns
    -------
    torch.Tensor, SparseMatrix or DiagMatrix
        The result matrix

    Raises
    ------
    AssertionError
        If either argument has an unsupported type, or if ``A`` is a
        torch.Tensor while ``B`` is not.

    Examples
    --------

    Multiplies a diagonal matrix with a dense matrix.

    >>> val = torch.randn(3)
    >>> A = dglsp.diag(val)
    >>> B = torch.randn(3, 2)
    >>> result = dglsp.matmul(A, B)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([3, 2])

    Multiplies a sparse matrix with a dense matrix.

    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val = torch.randn(indices.shape[1])
    >>> A = dglsp.spmatrix(indices, val)
    >>> X = torch.randn(2, 3)
    >>> result = dglsp.matmul(A, X)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([2, 3])

    Multiplies a sparse matrix with a sparse matrix.

    >>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val1 = torch.ones(indices1.shape[1])
    >>> A = dglsp.spmatrix(indices1, val1)
    >>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
    >>> val2 = torch.ones(indices2.shape[1])
    >>> B = dglsp.spmatrix(indices2, val2)
    >>> result = dglsp.matmul(A, B)
    >>> type(result)
    <class 'dgl.sparse.sparse_matrix.SparseMatrix'>
    >>> result.shape
    (2, 3)
    """
    # Each message fragment ends with a space so the implicitly concatenated
    # literals do not run words together.
    assert isinstance(A, (torch.Tensor, SparseMatrix, DiagMatrix)), (
        "Expect arg1 to be a torch.Tensor, SparseMatrix, or DiagMatrix "
        f"object, got {type(A)}."
    )
    assert isinstance(B, (torch.Tensor, SparseMatrix, DiagMatrix)), (
        "Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix "
        f"object, got {type(B)}."
    )

    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return torch.matmul(A, B)

    # Reaching here with a dense A means B is sparse/diagonal, which is the
    # unsupported dense @ sparse combination.
    assert not isinstance(A, torch.Tensor), (
        f"Expect arg2 to be a torch Tensor if arg 1 is torch Tensor, "
        f"got {type(B)}."
    )

    if isinstance(B, torch.Tensor):
        return spmm(A, B)

    # Both operands are sparse or diagonal; spspmm performs the per-type
    # dispatch, including the DiagMatrix @ DiagMatrix case.
    return spspmm(A, B)
# Install ``matmul`` as the ``@`` operator of both matrix classes, so that
# ``A @ B`` with a SparseMatrix or DiagMatrix on the left calls ``matmul(A, B)``.
SparseMatrix.__matmul__ = matmul
DiagMatrix.__matmul__ = matmul