dgl.sparse.bsddmm
dgl.sparse.bsddmm(A: dgl.sparse.sparse_matrix.SparseMatrix, X1: torch.Tensor, X2: torch.Tensor) → dgl.sparse.sparse_matrix.SparseMatrix
Sampled-Dense-Dense Matrix Multiplication (SDDMM) by batches.
sddmm matrix-multiplies two dense matrices X1 and X2, then elementwise-multiplies the result with sparse matrix A at the nonzero locations.

Mathematically, sddmm is formulated as:

\[out = (X1 @ X2) * A\]

The batch dimension is the last dimension of the input dense matrices. In particular, if the sparse matrix has scalar nonzero values, they are broadcast across the batch dimension for bsddmm.
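Read entry by entry, and assuming the @ product is taken independently for each batch index k (an interpretation consistent with the shapes listed under Parameters and with the values in the example below), the value stored at each nonzero location (l, n) of A can be written as:

\[out_{l,n,k} = A_{l,n,k} \sum_{m=0}^{M-1} X1_{l,m,k} \, X2_{m,n,k}, \qquad k = 0, \dots, K-1\]

where \(A_{l,n,k} = A_{l,n}\) for every k when A holds scalar values. The index names l, m, n, k are introduced here for illustration only.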
- Parameters
A (SparseMatrix) – Sparse matrix of shape (L, N) with scalar values or vector values of length K
X1 (Tensor) – Dense matrix of shape (L, M, K)
X2 (Tensor) – Dense matrix of shape (M, N, K)
- Returns
Sparse matrix of shape (L, N) with vector values of length K
- Return type
SparseMatrix
Examples
>>> import torch
>>> import dgl.sparse as dglsp
>>>
>>> indices = torch.tensor([[1, 1, 2], [2, 3, 3]])
>>> val = torch.arange(1, 4).float()
>>> A = dglsp.spmatrix(indices, val, (3, 4))
>>> X1 = torch.arange(0, 3 * 5 * 2).view(3, 5, 2).float()
>>> X2 = torch.arange(0, 5 * 4 * 2).view(5, 4, 2).float()
>>> dglsp.bsddmm(A, X1, X2)
SparseMatrix(indices=tensor([[1, 1, 2],
                             [2, 3, 3]]),
             values=tensor([[1560., 1735.],
                            [3400., 3770.],
                            [8400., 9105.]]),
             shape=(3, 4), nnz=3, val_size=(2,))
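
To make the semantics of the example easier to check by hand, the following is a minimal pure-Python sketch of the same computation. It is not DGL's implementation: `bsddmm_reference` is a hypothetical helper written for illustration, it takes the nonzero coordinates and values of A as plain lists instead of a `SparseMatrix`, and it assumes the matrix product is taken independently for each batch index.

```python
def bsddmm_reference(rows, cols, vals, X1, X2):
    """Illustrative per-entry reference for batched SDDMM (hypothetical helper).

    rows, cols: coordinates of the stored (nonzero) entries of A.
    vals: one scalar per stored entry, or one length-K list per stored entry.
    X1: nested list of shape (L, M, K); X2: nested list of shape (M, N, K).
    Returns one length-K list of values per stored entry of A.
    """
    M = len(X2)
    K = len(X2[0][0])
    out = []
    for l, n, v in zip(rows, cols, vals):
        # A scalar value is broadcast across the K batches.
        v_per_batch = list(v) if isinstance(v, (list, tuple)) else [v] * K
        out.append([
            v_per_batch[k] * sum(X1[l][m][k] * X2[m][n][k] for m in range(M))
            for k in range(K)
        ])
    return out


# Same inputs as the example above, built without torch:
# X1 mirrors arange(30).view(3, 5, 2) and X2 mirrors arange(40).view(5, 4, 2).
X1 = [[[float(l * 10 + m * 2 + k) for k in range(2)] for m in range(5)] for l in range(3)]
X2 = [[[float(m * 8 + n * 2 + k) for k in range(2)] for n in range(4)] for m in range(5)]

result = bsddmm_reference([1, 1, 2], [2, 3, 3], [1.0, 2.0, 3.0], X1, X2)
print(result)
```

Each inner list corresponds to one nonzero entry of A, in the order given by `rows` and `cols`, so the result can be compared row by row with the `values` tensor printed by `dglsp.bsddmm`.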