dgl.sparse.SparseMatrix.reduce

SparseMatrix.reduce(dim: Optional[int] = None, rtype: str = 'sum')

Computes the reduction of non-zero values of the input sparse matrix along the given dimension dim.

The reduction ignores zero elements: only the non-zero values take part in it. If a row or column to be reduced does not have any non-zero elements, the result for it is 0.
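To make these semantics concrete, here is a minimal pure-Python sketch. It is not DGL's implementation. It reduces a scalar-valued matrix given as COO triplets (row indices, column indices, values). The helper `reduce_coo` and the `_OPS` table are hypothetical names used only for illustration. The sketch looks only at the stored entries and yields 0 for any row or column that has none.

```python
import math
from typing import List, Optional, Tuple, Union

# Hypothetical table, not part of DGL: maps each rtype to a Python reduction.
_OPS = {
    "sum": sum,
    "smin": min,
    "smax": max,
    "smean": lambda xs: sum(xs) / len(xs),
    "sprod": math.prod,
}


def reduce_coo(rows: List[int], cols: List[int], vals: List[float],
               shape: Tuple[int, int], dim: Optional[int] = None,
               rtype: str = "sum") -> Union[float, List[float]]:
    """Hypothetical reference helper: reduce only the stored entries of a COO matrix."""
    op = _OPS[rtype]
    if dim is None:
        # Reduce over every stored entry at once; an empty matrix yields 0.
        return op(vals) if vals else 0
    # dim=0 collapses the rows (one result per column);
    # dim=1 collapses the columns (one result per row).
    keys = cols if dim == 0 else rows
    n_out = shape[1] if dim == 0 else shape[0]
    groups: List[List[float]] = [[] for _ in range(n_out)]
    for k, v in zip(keys, vals):
        groups[k].append(v)
    # Rows/columns without stored entries produce 0.
    return [op(g) if g else 0 for g in groups]
```

With the triplets from the Case 1 example (rows `[0, 1, 1]`, columns `[0, 0, 2]`, values `[1, 1, 2]`, shape `(4, 3)`), `reduce_coo(rows, cols, vals, (4, 3), 0, "sum")` gives `[2, 0, 2]` and `reduce_coo(rows, cols, vals, (4, 3), 1, "smin")` gives `[1, 1, 0, 0]`. Both match the DGL outputs shown in that example.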

Parameters
  • input (SparseMatrix) – The input sparse matrix.

  • dim (int, optional) –

    The dimension to reduce. Must be 0 (reduce over rows), 1 (reduce over columns), or None (reduce over both rows and columns simultaneously).

    If dim is None, it reduces both the rows and the columns of the sparse matrix, producing a tensor of shape input.val.shape[1:]. Otherwise, it reduces along the row (dim=0) or column (dim=1) dimension, producing a tensor of shape (input.shape[1],) + input.val.shape[1:] or (input.shape[0],) + input.val.shape[1:], respectively.

  • rtype (str, optional) – Reduction type, one of ['sum', 'smin', 'smax', 'smean', 'sprod'], computing the sum, minimum, maximum, mean, and product of the non-zero elements, respectively. Default: 'sum'.
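The shape rule for dim can be written out as a short helper. This is a sketch that only restates the rule in the parameter description above. `reduced_shape` is a hypothetical name, not a DGL function, and it takes plain tuples standing in for input.shape and input.val.shape.

```python
from typing import Optional, Tuple


def reduced_shape(mat_shape: Tuple[int, int], val_shape: Tuple[int, ...],
                  dim: Optional[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of reduce() for a given dim."""
    # Per-entry shape: () for scalar values, e.g. (2,) for 2-vectors.
    feat = tuple(val_shape[1:])
    if dim is None:
        return feat
    if dim == 0:
        return (mat_shape[1],) + feat  # one slot per column
    if dim == 1:
        return (mat_shape[0],) + feat  # one slot per row
    raise ValueError("dim must be 0, 1, or None")
```

For the vector-valued example (matrix shape `(4, 3)`, three stored entries, so value shape `(3, 2)`), this gives `(2,)` for dim=None, `(3, 2)` for dim=0, and `(4, 2)` for dim=1. These agree with the output shapes shown in Case 2.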

Returns

Reduced tensor

Return type

torch.Tensor

Examples

Case 1: scalar-valued sparse matrix

>>> import torch
>>> import dgl.sparse as dglsp
>>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])
>>> val = torch.tensor([1, 1, 2])
>>> A = dglsp.spmatrix(indices, val, shape=(4, 3))
>>> dglsp.reduce(A, rtype='sum')
tensor(4)
>>> dglsp.reduce(A, 0, 'sum')
tensor([2, 0, 2])
>>> dglsp.reduce(A, 1, 'sum')
tensor([1, 3, 0, 0])
>>> dglsp.reduce(A, 0, 'smax')
tensor([1, 0, 2])
>>> dglsp.reduce(A, 1, 'smin')
tensor([1, 1, 0, 0])

Case 2: vector-valued sparse matrix

>>> indices = torch.tensor([[0, 1, 1], [0, 0, 2]])
>>> val = torch.tensor([[1., 2.], [2., 1.], [2., 2.]])
>>> A = dglsp.spmatrix(indices, val, shape=(4, 3))
>>> dglsp.reduce(A, rtype='sum')
tensor([5., 5.])
>>> dglsp.reduce(A, 0, 'sum')
tensor([[3., 3.],
        [0., 0.],
        [2., 2.]])
>>> dglsp.reduce(A, 1, 'smin')
tensor([[1., 2.],
        [2., 1.],
        [0., 0.],
        [0., 0.]])
>>> dglsp.reduce(A, 0, 'smean')
tensor([[1.5000, 1.5000],
        [0.0000, 0.0000],
        [2.0000, 2.0000]])
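The 'smean' result above averages over the non-zero entries only. Column 0 stores two vectors, so its mean is (1 + 2) / 2 = 1.5 per component. That is not 3 / 4, which you would get by averaging over all four rows of the column. The short arithmetic sketch below contrasts the two. It is plain Python for illustration, not DGL code.

```python
# The two stored value vectors in column 0 of the Case 2 matrix.
col0_entries = [[1.0, 2.0], [2.0, 1.0]]
n_rows = 4  # matrix shape is (4, 3)

# 'smean': divide by the number of non-zero entries (what reduce returns).
smean = [sum(x) / len(col0_entries) for x in zip(*col0_entries)]

# Dense mean for comparison: divide by the full column length,
# i.e. also counting the two implicit zero rows.
dense_mean = [sum(x) / n_rows for x in zip(*col0_entries)]

print(smean)       # [1.5, 1.5]
print(dense_mean)  # [0.75, 0.75]
```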