Source code for dgl.sparse.softmax

"""Softmax op for SparseMatrix"""
# pylint: disable=invalid-name, W0622

import torch

from .sparse_matrix import SparseMatrix

# Names exported by ``from <this module> import *``.
__all__ = ["softmax"]


def softmax(input: SparseMatrix) -> SparseMatrix:
    """Applies row-wise softmax to the non-zero elements of the sparse matrix.

    Equivalently, applies softmax to the non-zero elements of the sparse
    matrix along the column (``dim=1``) dimension.

    If :attr:`input.val` takes shape ``(nnz, D)``, then the output matrix
    :attr:`output` and :attr:`output.val` take the same shape as :attr:`input`
    and :attr:`input.val`. :attr:`output.val[:, i]` is calculated based on
    :attr:`input.val[:, i]`.

    Parameters
    ----------
    input : SparseMatrix
        The input sparse matrix

    Returns
    -------
    SparseMatrix
        The output sparse matrix

    Examples
    --------

    Case1: matrix with values of shape (nnz)

    >>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
    >>> nnz = indices.shape[1]
    >>> val = torch.arange(nnz).float()
    >>> A = dglsp.spmatrix(indices, val)
    >>> dglsp.softmax(A)
    SparseMatrix(indices=tensor([[0, 0, 1, 2],
                                 [1, 2, 2, 0]]),
                 values=tensor([0.2689, 0.7311, 1.0000, 1.0000]),
                 shape=(3, 3), nnz=4)

    Case2: matrix with values of shape (nnz, D)

    >>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
    >>> val = torch.tensor([[0., 7.], [1., 3.], [2., 2.], [3., 1.]])
    >>> A = dglsp.spmatrix(indices, val)
    >>> dglsp.softmax(A)
    SparseMatrix(indices=tensor([[0, 0, 1, 2],
                                 [1, 2, 2, 0]]),
                 values=tensor([[0.2689, 0.9820],
                                [0.7311, 0.0180],
                                [1.0000, 1.0000],
                                [1.0000, 1.0000]]),
                 shape=(3, 3), nnz=4, val_size=(2,))
    """
    # Hand the wrapped ``c_sparse_matrix`` handle to the operator exposed
    # under ``torch.ops.dgl_sparse`` and re-wrap the result in the Python
    # ``SparseMatrix`` class.
    return SparseMatrix(torch.ops.dgl_sparse.softmax(input.c_sparse_matrix))


# Attach the function to the class so it is also callable as a method,
# i.e. ``A.softmax()`` in addition to ``softmax(A)``.
SparseMatrix.softmax = softmax