Quickstart

This tutorial provides a quick walkthrough of the classes and operators in the dgl.sparse package.

[1]:
# Install the required packages.

import os
# Uncomment the following commands to install PyTorch and DGL.
# !pip install torch==2.0.0+cpu torchvision==0.15.1+cpu torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cpu > /dev/null
# !pip install  dgl==1.1.0 -f https://data.dgl.ai/wheels/repo.html > /dev/null
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"


try:
    import dgl.sparse as dglsp
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "DGL not found!")
DGL installed!

Sparse Matrix

The core abstraction of DGL’s sparse package is the SparseMatrix class. Compared with other sparse matrix libraries (such as scipy.sparse and torch.sparse), DGL’s SparseMatrix is specialized for deep learning workloads on structured data (e.g., Graph Neural Networks), with the following features:

  • Automatic sparse format. There is no need to choose among sparse formats: there is a single SparseMatrix class, and it selects the best format for the operation to be performed.

  • Non-zero elements can be scalars or vectors. This makes it easy to model relations (e.g., edges) with vector representations.

  • Fully PyTorch compatible. The package is built upon PyTorch and is natively compatible with other tools in the PyTorch ecosystem.
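
As a quick illustration of the last point, here is a minimal sketch showing gradients flowing through the non-zero values of a sparse matrix (assuming, as the package is designed for training GNNs, that its operators support autograd):

[ ]:
import torch
import dgl.sparse as dglsp

i = torch.tensor([[0, 1], [1, 0]])
val = torch.tensor([1., 2.], requires_grad=True)
A = dglsp.spmatrix(i, val)

# Sparse-dense matrix multiplication participates in autograd like any
# other PyTorch operation.
X = torch.ones(2, 2)
loss = (A @ X).sum()
loss.backward()
print(val.grad)  # Gradients w.r.t. the non-zero values.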

Creating a DGL Sparse Matrix

The simplest way to create a sparse matrix is with the spmatrix API, by providing the indices of the non-zero elements. The indices are stored in a tensor of shape (2, nnz), where the i-th non-zero element is stored at position (indices[0][i], indices[1][i]). The code below creates a 3x3 sparse matrix.

[2]:
import torch
import dgl.sparse as dglsp

i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
A = dglsp.spmatrix(i)  # 1.0 is the default value for non-zero elements.

print(A)
print("")
print("In dense format:")
print(A.to_dense())
SparseMatrix(indices=tensor([[1, 1, 2],
                             [0, 2, 0]]),
             values=tensor([1., 1., 1.]),
             shape=(3, 3), nnz=3)

In dense format:
tensor([[0., 0., 0.],
        [1., 0., 1.],
        [1., 0., 0.]])

If not specified, the shape is inferred automatically from the indices, but you can also set it explicitly.

[3]:
i = torch.tensor([[0, 0, 1],
                  [0, 2, 0]])

A1 = dglsp.spmatrix(i)
print(f"Implicit Shape: {A1.shape}")
print(A1.to_dense())
print("")

A2 = dglsp.spmatrix(i, shape=(3, 3))
print(f"Explicit Shape: {A2.shape}")
print(A2.to_dense())
Implicit Shape: (2, 3)
tensor([[1., 0., 1.],
        [1., 0., 0.]])

Explicit Shape: (3, 3)
tensor([[1., 0., 1.],
        [1., 0., 0.],
        [0., 0., 0.]])

Non-zero elements of a sparse matrix can hold either scalar or vector values.

[4]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
# The number of values should match the number of non-zero elements (nnz)
# represented by the sparse matrix.
scalar_val = torch.tensor([1., 2., 3.])
vector_val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])

print("-----Scalar Values-----")
A = dglsp.spmatrix(i, scalar_val)
print(A)
print("")
print("In dense format:")
print(A.to_dense())
print("")

print("-----Vector Values-----")
A = dglsp.spmatrix(i, vector_val)
print(A)
print("")
print("In dense format:")
print(A.to_dense())
-----Scalar Values-----
SparseMatrix(indices=tensor([[1, 1, 2],
                             [0, 2, 0]]),
             values=tensor([1., 2., 3.]),
             shape=(3, 3), nnz=3)

In dense format:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])

-----Vector Values-----
SparseMatrix(indices=tensor([[1, 1, 2],
                             [0, 2, 0]]),
             values=tensor([[1., 1.],
                            [2., 2.],
                            [3., 3.]]),
             shape=(3, 3), nnz=3, val_size=(2,))

In dense format:
tensor([[[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[1., 1.],
         [0., 0.],
         [2., 2.]],

        [[3., 3.],
         [0., 0.],
         [0., 0.]]])

Duplicate indices

A sparse matrix may store multiple values at the same position. Use has_duplicate() to check for duplicate indices and coalesce() to merge them by summing their values.

[5]:
i = torch.tensor([[0, 0, 0, 1],
                  [0, 2, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)
print(A)
print(f"Whether A contains duplicate indices: {A.has_duplicate()}")
print("")

B = A.coalesce()
print(B)
print(f"Whether B contains duplicate indices: {B.has_duplicate()}")
SparseMatrix(indices=tensor([[0, 0, 0, 1],
                             [0, 2, 2, 0]]),
             values=tensor([1., 2., 3., 4.]),
             shape=(2, 3), nnz=4)
Whether A contains duplicate indices: True

SparseMatrix(indices=tensor([[0, 0, 1],
                             [0, 2, 0]]),
             values=tensor([1., 5., 4.]),
             shape=(2, 3), nnz=3)
Whether B contains duplicate indices: False

val_like

You can create a new sparse matrix that retains the non-zero indices of a given sparse matrix but holds different non-zero values.

[6]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A = dglsp.spmatrix(i, val)

new_val = torch.tensor([4., 5., 6.])
B = dglsp.val_like(A, new_val)
print(B)
SparseMatrix(indices=tensor([[1, 1, 2],
                             [0, 2, 0]]),
             values=tensor([4., 5., 6.]),
             shape=(3, 3), nnz=3)

Create a sparse matrix from various sparse formats

  • from_coo(): Create a sparse matrix from COO format.

  • from_csr(): Create a sparse matrix from CSR format.

  • from_csc(): Create a sparse matrix from CSC format.

[7]:
row = torch.tensor([0, 1, 2, 2, 2])
col = torch.tensor([1, 2, 0, 1, 2])

print("-----Create from COO format-----")
A = dglsp.from_coo(row, col)
print(A)
print("")
print("In dense format:")
print(A.to_dense())
print("")

indptr = torch.tensor([0, 1, 2, 5])
indices = torch.tensor([1, 2, 0, 1, 2])

print("-----Create from CSR format-----")
A = dglsp.from_csr(indptr, indices)
print(A)
print("")
print("In dense format:")
print(A.to_dense())
print("")

print("-----Create from CSC format-----")
B = dglsp.from_csc(indptr, indices)
print(B)
print("")
print("In dense format:")
print(B.to_dense())
-----Create from COO format-----
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
                             [1, 2, 0, 1, 2]]),
             values=tensor([1., 1., 1., 1., 1.]),
             shape=(3, 3), nnz=5)

In dense format:
tensor([[0., 1., 0.],
        [0., 0., 1.],
        [1., 1., 1.]])

-----Create from CSR format-----
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
                             [1, 2, 0, 1, 2]]),
             values=tensor([1., 1., 1., 1., 1.]),
             shape=(3, 3), nnz=5)

In dense format:
tensor([[0., 1., 0.],
        [0., 0., 1.],
        [1., 1., 1.]])

-----Create from CSC format-----
SparseMatrix(indices=tensor([[1, 2, 0, 1, 2],
                             [0, 1, 2, 2, 2]]),
             values=tensor([1., 1., 1., 1., 1.]),
             shape=(3, 3), nnz=5)

In dense format:
tensor([[0., 0., 1.],
        [1., 0., 1.],
        [0., 1., 1.]])

Attributes and methods of a DGL Sparse Matrix

[8]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)

print(f"Shape of sparse matrix: {A.shape}")
print(f"The number of nonzero elements of sparse matrix: {A.nnz}")
print(f"Datatype of sparse matrix: {A.dtype}")
print(f"Device sparse matrix is stored on: {A.device}")
print(f"Get the values of the nonzero elements: {A.val}")
print(f"Get the row indices of the nonzero elements: {A.row}")
print(f"Get the column indices of the nonzero elements: {A.col}")
print(f"Get the coordinate (COO) representation: {A.coo()}")
print(f"Get the compressed sparse row (CSR) representation: {A.csr()}")
print(f"Get the compressed sparse column (CSC) representation: {A.csc()}")
Shape of sparse matrix: (3, 3)
The number of nonzero elements of sparse matrix: 4
Datatype of sparse matrix: torch.float32
Device sparse matrix is stored on: cpu
Get the values of the nonzero elements: tensor([1., 2., 3., 4.])
Get the row indices of the nonzero elements: tensor([0, 1, 1, 2])
Get the column indices of the nonzero elements: tensor([1, 0, 2, 0])
Get the coordinate (COO) representation: (tensor([0, 1, 1, 2]), tensor([1, 0, 2, 0]))
Get the compressed sparse row (CSR) representation: (tensor([0, 1, 3, 4]), tensor([1, 0, 2, 0]), tensor([0, 1, 2, 3]))
Get the compressed sparse column (CSC) representation: (tensor([0, 2, 3, 4]), tensor([1, 2, 0, 1]), tensor([1, 3, 0, 2]))

dtype and/or device conversion

[9]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)

B = A.to(device='cpu', dtype=torch.int32)
print(f"Device sparse matrix is stored on: {B.device}")
print(f"Datatype of sparse matrix: {B.dtype}")
Device sparse matrix is stored on: cpu
Datatype of sparse matrix: torch.int32

Similar to PyTorch, we also provide various fine-grained APIs for dtype and/or device conversion (see the API documentation).
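
For instance, a minimal sketch, assuming the fine-grained helpers mirror their PyTorch tensor counterparts (consult the API documentation for the full list):

[ ]:
# Assumed fine-grained conversion helpers, mirroring torch.Tensor's.
C = A.float()               # dtype-only conversion
print(C.dtype)
C = C.to(dtype=torch.int64)  # dtype conversion via to()
print(C.dtype)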

Diagonal Matrix

A Diagonal Matrix is a special type of Sparse Matrix in which all entries outside the main diagonal are zero.

Initializing a DGL Diagonal Sparse Matrix

A DGL Diagonal Sparse Matrix can be created with dglsp.diag().

An Identity Matrix is a special type of Diagonal Sparse Matrix in which all values on the diagonal are 1.0. Use dglsp.identity() to create one.

[10]:
val = torch.tensor([1., 2., 3., 4.])
D = dglsp.diag(val)
print(D)

I = dglsp.identity(shape=(3, 3))
print(I)
SparseMatrix(indices=tensor([[0, 1, 2, 3],
                             [0, 1, 2, 3]]),
             values=tensor([1., 2., 3., 4.]),
             shape=(4, 4), nnz=4)
SparseMatrix(indices=tensor([[0, 1, 2],
                             [0, 1, 2]]),
             values=tensor([1., 1., 1.]),
             shape=(3, 3), nnz=3)

Operations on Sparse Matrix

  • Elementwise operations

    • A + B

    • A - B

    • A * B

    • A / B

    • A ** scalar

  • Broadcast operations

    • sp_<op>_v()

  • Reduce operations

    • reduce()

    • sum()

    • smax()

    • smin()

    • smean()

  • Matrix transformations

    • SparseMatrix.transpose() or SparseMatrix.T

    • SparseMatrix.neg()

    • SparseMatrix.inv()

  • Matrix multiplication

    • matmul()

    • sddmm()

We print sparse matrices in dense format throughout this tutorial, since it is more intuitive to read.

Elementwise operations

add(A, B), equivalent to A + B

Element-wise addition on two sparse matrices, returning a sparse matrix.

[11]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[0, 1, 2],
                  [0, 2, 1]])
val = torch.tensor([4., 5., 6.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A2:")
print(A2.to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("A1 + A2:")
print((A1 + A2).to_dense())

print("A1 + D1:")
print((A1 + D1).to_dense())

print("D1 + D2:")
print((D1 + D2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A2:
tensor([[4., 0., 0.],
        [0., 0., 5.],
        [0., 6., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
A1 + A2:
tensor([[4., 0., 0.],
        [1., 0., 7.],
        [3., 6., 0.]])
A1 + D1:
tensor([[-1.,  0.,  0.],
        [ 1., -2.,  2.],
        [ 3.,  0., -3.]])
D1 + D2:
tensor([[-5.,  0.,  0.],
        [ 0., -7.,  0.],
        [ 0.,  0., -9.]])

sub(A, B), equivalent to A - B

Element-wise subtraction on two sparse matrices, returning a sparse matrix.

[12]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[0, 1, 2],
                  [0, 2, 1]])
val = torch.tensor([4., 5., 6.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A2:")
print(A2.to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("A1 - A2:")
print((A1 - A2).to_dense())

print("A1 - D1:")
print((A1 - D1).to_dense())

print("D1 - A1:")
print((D1 - A1).to_dense())

print("D1 - D2:")
print((D1 - D2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A2:
tensor([[4., 0., 0.],
        [0., 0., 5.],
        [0., 6., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
A1 - A2:
tensor([[-4.,  0.,  0.],
        [ 1.,  0., -3.],
        [ 3., -6.,  0.]])
A1 - D1:
tensor([[1., 0., 0.],
        [1., 2., 2.],
        [3., 0., 3.]])
D1 - A1:
tensor([[-1.,  0.,  0.],
        [-1., -2., -2.],
        [-3.,  0., -3.]])
D1 - D2:
tensor([[3., 0., 0.],
        [0., 3., 0.],
        [0., 0., 3.]])

mul(A, B), equivalent to A * B

Element-wise multiplication on two sparse matrices or on a sparse matrix and a scalar, returning a sparse matrix.

[13]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[0, 1, 2, 2],
                  [0, 2, 0, 1]])
val = torch.tensor([1., 2., 3., 4.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))

print("A2:")
print(A2.to_dense())

print("A1 * 3:")
print((A1 * 3).to_dense())
print("3 * A1:")
print((3 * A1).to_dense())

print("A1 * A2")
print((A1 * A2).to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

print("D1 * A2")
print((D1 * A2).to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("D1 * -2:")
print((D1 * -2).to_dense())
print("-2 * D1:")
print((-2 * D1).to_dense())

print("D1 * D2:")
print((D1 * D2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A2:
tensor([[1., 0., 0.],
        [0., 0., 2.],
        [3., 4., 0.]])
A1 * 3:
tensor([[0., 0., 0.],
        [3., 0., 6.],
        [9., 0., 0.]])
3 * A1:
tensor([[0., 0., 0.],
        [3., 0., 6.],
        [9., 0., 0.]])
A1 * A2
tensor([[0., 0., 0.],
        [0., 0., 4.],
        [9., 0., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D1 * A2
tensor([[-1.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
D1 * -2:
tensor([[2., 0., 0.],
        [0., 4., 0.],
        [0., 0., 6.]])
-2 * D1:
tensor([[2., 0., 0.],
        [0., 4., 0.],
        [0., 0., 6.]])
D1 * D2:
tensor([[ 4.,  0.,  0.],
        [ 0., 10.,  0.],
        [ 0.,  0., 18.]])

div(A, B), equivalent to A / B

Element-wise division on two sparse matrices or on a sparse matrix and a scalar, returning a sparse matrix. If both A and B are sparse matrices, they must have the same sparsity pattern, and the returned matrix has the same order of non-zero entries as A.

[14]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[1, 2, 1],
                  [0, 0, 2]])
val = torch.tensor([1., 3., 2.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))

print("A1 / 2:")
print((A1 / 2).to_dense())

print("A1 / A2")
print((A1 / A2).to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("D1 / D2:")
print((D1 / D2).to_dense())

print("D1 / 2:")
print((D1 / 2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A1 / 2:
tensor([[0.0000, 0.0000, 0.0000],
        [0.5000, 0.0000, 1.0000],
        [1.5000, 0.0000, 0.0000]])
A1 / A2
tensor([[0., 0., 0.],
        [1., 0., 1.],
        [1., 0., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
D1 / D2:
tensor([[0.2500, 0.0000, 0.0000],
        [0.0000, 0.4000, 0.0000],
        [0.0000, 0.0000, 0.5000]])
D1 / 2:
tensor([[-0.5000,  0.0000,  0.0000],
        [ 0.0000, -1.0000,  0.0000],
        [ 0.0000,  0.0000, -1.5000]])

power(A, scalar), equivalent to A ** scalar

Element-wise power of a sparse matrix and a scalar, returning a sparse matrix.

[15]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A = dglsp.spmatrix(i, val, shape=(3, 3))
print("A:")
print(A.to_dense())

print("A ** 3:")
print((A ** 3).to_dense())

val = torch.tensor([-1., -2., -3.])
D = dglsp.diag(val)
print("D:")
print(D.to_dense())

print("D1 ** 2:")
print((D1 ** 2).to_dense())
A:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A ** 3:
tensor([[ 0.,  0.,  0.],
        [ 1.,  0.,  8.],
        [27.,  0.,  0.]])
D:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D ** 2:
tensor([[1., 0., 0.],
        [0., 4., 0.],
        [0., 0., 9.]])

Broadcast operations

sp_<op>_v(A, v)

Broadcast operations on a sparse matrix and a vector, returning a sparse matrix. v is broadcast to the shape of A, and then the operator is applied to the non-zero values of A. <op> can be add, sub, mul, or div.

There are two cases regarding the shape of v:

  1. v is a vector of shape (1, A.shape[1]) or (A.shape[1]). In this case, v is broadcasted on the row dimension of A.

  2. v is a vector of shape (A.shape[0], 1). In this case, v is broadcasted on the column dimension of A.

[16]:
i = torch.tensor([[1, 0, 2], [0, 3, 2]])
val = torch.tensor([10, 20, 30])
A = dglsp.spmatrix(i, val, shape=(3, 4))

v1 = torch.tensor([1, 2, 3, 4])
print("A:")
print(A.to_dense())

print("v1:")
print(v1)

print("sp_add_v(A, v1)")
print(dglsp.sp_add_v(A, v1).to_dense())

v2 = v1.reshape(1, -1)
print("v2:")
print(v2)

print("sp_add_v(A, v2)")
print(dglsp.sp_add_v(A, v2).to_dense())

v3 = torch.tensor([1, 2, 3]).reshape(-1, 1)
print("v3:")
print(v3)

print("sp_add_v(A, v3)")
print(dglsp.sp_add_v(A, v3).to_dense())
A:
tensor([[ 0,  0,  0, 20],
        [10,  0,  0,  0],
        [ 0,  0, 30,  0]])
v1:
tensor([1, 2, 3, 4])
sp_add_v(A, v1)
tensor([[ 0,  0,  0, 24],
        [11,  0,  0,  0],
        [ 0,  0, 33,  0]])
v2:
tensor([[1, 2, 3, 4]])
sp_add_v(A, v2)
tensor([[ 0,  0,  0, 24],
        [11,  0,  0,  0],
        [ 0,  0, 33,  0]])
v3:
tensor([[1],
        [2],
        [3]])
sp_add_v(A, v3)
tensor([[ 0,  0,  0, 21],
        [12,  0,  0,  0],
        [ 0,  0, 33,  0]])

Reduce operations

All DGL sparse reduce operations consider only the non-zero elements. To distinguish them from dense PyTorch reduce operations, which also take zero elements into account, we use the names smax, smin, and smean (the "s" stands for sparse).

[17]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)
print(A.T.to_dense())
print("")

# O1, O2 will have the same value.
O1 = A.reduce(0, 'sum')
O2 = A.sum(0)
print("Reduce with reducer:sum along dim = 0:")
print(O1)
print("")

# O3, O4 will have the same value.
O3 = A.reduce(0, 'smax')
O4 = A.smax(0)
print("Reduce with reducer:max along dim = 0:")
print(O3)
print("")

# O5, O6 will have the same value.
O5 = A.reduce(0, 'smin')
O6 = A.smin(0)
print("Reduce with reducer:min along dim = 0:")
print(O5)
print("")

# O7, O8 will have the same value.
O7 = A.reduce(0, 'smean')
O8 = A.smean(0)
print("Reduce with reducer:smean along dim = 0:")
print(O7)
print("")
tensor([[0., 2., 4.],
        [1., 0., 0.],
        [0., 3., 0.]])

Reduce with reducer:sum along dim = 0:
tensor([6., 1., 3.])

Reduce with reducer:max along dim = 0:
tensor([4., 1., 3.])

Reduce with reducer:min along dim = 0:
tensor([2., 1., 3.])

Reduce with reducer:smean along dim = 0:
tensor([3., 1., 3.])

[W TensorAdvancedIndexing.cpp:1615] Warning: scatter_reduce() is in beta and the API may change at any time. (function operator())

Matrix transformations

Sparse Matrix

[18]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)
print(A.to_dense())
print("")

print("Get transpose of sparse matrix.")
print(A.T.to_dense())
# Alias
# A.transpose()
# A.t()
print("")

print("Get a sparse matrix with the negation of the original nonzero values.")
print(A.neg().to_dense())
print("")
tensor([[0., 1., 0.],
        [2., 0., 3.],
        [4., 0., 0.]])

Get transpose of sparse matrix.
tensor([[0., 2., 4.],
        [1., 0., 0.],
        [0., 3., 0.]])

Get a sparse matrix with the negation of the original nonzero values.
tensor([[ 0., -1.,  0.],
        [-2.,  0., -3.],
        [-4.,  0.,  0.]])
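
The overview above also lists inv(), which the cell does not demonstrate. A minimal sketch, assuming inv() computes the matrix inverse of a square, invertible sparse matrix such as a diagonal matrix (the inverse of diag(d) being diag(1/d)):

[ ]:
# Assumed behavior of inv() on a diagonal sparse matrix.
val = torch.tensor([1., 2., 4.])
D = dglsp.diag(val)
print(D.inv().to_dense())  # expect diag(1.0, 0.5, 0.25)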

Matrix multiplication

matmul(A, B), equivalent to A @ B

Matrix multiplication on sparse and/or dense matrices. There are two cases, as follows.

SparseMatrix @ SparseMatrix -> SparseMatrix:

For an \(L \times M\) sparse matrix A and an \(M \times N\) sparse matrix B, A @ B is an \(L \times N\) sparse matrix.

[19]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[0, 1, 2],
                  [0, 2, 1]])
val = torch.tensor([4., 5., 6.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A2:")
print(A2.to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("A1 @ A2:")
print((A1 @ A2).to_dense())

print("A1 @ D1:")
print((A1 @ D1).to_dense())

print("D1 @ A1:")
print((D1 @ A1).to_dense())

print("D1 @ D2:")
print((D1 @ D2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A2:
tensor([[4., 0., 0.],
        [0., 0., 5.],
        [0., 6., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
A1 @ A2:
tensor([[ 0.,  0.,  0.],
        [ 4., 12.,  0.],
        [12.,  0.,  0.]])
A1 @ D1:
tensor([[ 0.,  0.,  0.],
        [-1.,  0., -6.],
        [-3.,  0.,  0.]])
D1 @ A1:
tensor([[ 0.,  0.,  0.],
        [-2.,  0., -4.],
        [-9.,  0.,  0.]])
D1 @ D2:
tensor([[ 4.,  0.,  0.],
        [ 0., 10.,  0.],
        [ 0.,  0., 18.]])

SparseMatrix @ Tensor -> Tensor:

For an \(L \times M\) sparse matrix A and an \(M \times N\) dense matrix B, A @ B is an \(L \times N\) dense matrix.

[20]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A = dglsp.spmatrix(i, val, shape=(3, 3))
print("A:")
print(A.to_dense())

val = torch.tensor([-1., -2., -3.])
D = dglsp.diag(val)
print("D:")
print(D.to_dense())

X = torch.tensor([[11., 22.], [33., 44.], [55., 66.]])
print("X:")
print(X)

print("A @ X:")
print(A @ X)

print("D @ X:")
print(D @ X)
A:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
D:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
X:
tensor([[11., 22.],
        [33., 44.],
        [55., 66.]])
A @ X:
tensor([[  0.,   0.],
        [121., 154.],
        [ 33.,  66.]])
D @ X:
tensor([[ -11.,  -22.],
        [ -66.,  -88.],
        [-165., -198.]])

This operator also supports batched sparse-dense matrix multiplication. The sparse matrix A should have shape \(L \times M\), where the non-zero values are vectors of length \(K\). The dense matrix B should have shape \(M \times N \times K\). The output is a dense matrix of shape \(L \times N \times K\).

[21]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
A = dglsp.spmatrix(i, val, shape=(3, 3))
print("A:")
print(A.to_dense())

X = torch.tensor([[[1., 1.], [1., 2.]],
                  [[1., 3.], [1., 4.]],
                  [[1., 5.], [1., 6.]]])
print("X:")
print(X)

print("A @ X:")
print(A @ X)
A:
tensor([[[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[1., 1.],
         [0., 0.],
         [2., 2.]],

        [[3., 3.],
         [0., 0.],
         [0., 0.]]])
X:
tensor([[[1., 1.],
         [1., 2.]],

        [[1., 3.],
         [1., 4.]],

        [[1., 5.],
         [1., 6.]]])
A @ X:
tensor([[[ 0.,  0.],
         [ 0.,  0.]],

        [[ 3., 11.],
         [ 3., 14.]],

        [[ 3.,  3.],
         [ 3.,  6.]]])

Sampled-Dense-Dense Matrix Multiplication (SDDMM)

sddmm matrix-multiplies two dense matrices X1 and X2, then element-wise multiplies the result with the sparse matrix A at the non-zero locations. This operator is designed for sparse matrices with scalar values.

\[out = (X_1 @ X_2) * A\]

For an \(L \times N\) sparse matrix A, an \(L \times M\) dense matrix X1, and an \(M \times N\) dense matrix X2, sddmm(A, X1, X2) is an \(L \times N\) sparse matrix.

[22]:
i = torch.tensor([[1, 1, 2],
                  [2, 3, 3]])
val = torch.tensor([1., 2., 3.])
A = dglsp.spmatrix(i, val, (3, 4))
print("A:")
print(A.to_dense())

X1 = torch.randn(3, 5)
X2 = torch.randn(5, 4)
print("X1:")
print(X1)
print("X2:")
print(X2)

O = dglsp.sddmm(A, X1, X2)
print("dglsp.sddmm(A, X1, X2):")
print(O.to_dense())
A:
tensor([[0., 0., 0., 0.],
        [0., 0., 1., 2.],
        [0., 0., 0., 3.]])
X1:
tensor([[-0.0497, -0.8706, -0.0910,  1.6743,  0.6626],
        [ 0.0915, -0.4665, -0.2994, -0.0594,  0.2922],
        [ 0.6908, -0.1638, -0.8707, -1.4926,  0.1577]])
X2:
tensor([[-1.3606,  0.4170,  0.2496,  0.8273],
        [ 0.5804, -0.4700, -0.5544, -0.4639],
        [ 2.5965, -1.0452,  0.5284, -0.3195],
        [-1.2631, -0.1298, -0.9593,  0.7437],
        [ 0.4211,  1.2600,  1.5007, -0.1720]])
dglsp.sddmm(A, X1, X2):
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.6188,  0.5866],
        [ 0.0000,  0.0000,  0.0000, -0.6344]])

This operator also supports batched sampled-dense-dense matrix multiplication. For an \(L \times N\) sparse matrix A with non-zero vector values of length \(K\), an \(L \times M \times K\) dense matrix X1, and an \(M \times N \times K\) dense matrix X2, sddmm(A, X1, X2) is an \(L \times N\) sparse matrix with vector values of length \(K\).

[23]:
i = torch.tensor([[1, 1, 2],
                  [2, 3, 3]])
val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
A = dglsp.spmatrix(i, val, (3, 4))
print("A:")
print(A.to_dense())

X1 = torch.randn(3, 5, 2)
X2 = torch.randn(5, 4, 2)
print("X1:")
print(X1)
print("X2:")
print(X2)

O = dglsp.sddmm(A, X1, X2)
print("dglsp.sddmm(A, X1, X2):")
print(O.to_dense())
A:
tensor([[[0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [1., 1.],
         [2., 2.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.],
         [3., 3.]]])
X1:
tensor([[[ 1.8465,  0.7536],
         [-0.1319, -0.4549],
         [-0.8570,  0.2051],
         [-0.6256,  1.0429],
         [ 1.0595, -1.5016]],

        [[-0.6338,  0.2042],
         [-0.0170, -0.6389],
         [ 1.7856, -0.3104],
         [ 1.3565, -0.5085],
         [ 0.1957,  1.3869]],

        [[-1.0676, -0.9831],
         [ 0.9669, -1.6019],
         [ 0.1466,  1.3700],
         [-1.5049,  1.1810],
         [-0.3946,  0.4833]]])
X2:
tensor([[[-1.2415,  0.5989],
         [-0.6700, -0.8538],
         [ 0.0672, -0.3136],
         [ 0.1921, -0.2544]],

        [[-0.8444,  0.2385],
         [-0.9543,  0.7039],
         [-0.3139, -0.3555],
         [-0.4176,  2.9281]],

        [[-1.0143,  0.8043],
         [ 1.7255, -0.2926],
         [-1.2465,  0.8661],
         [ 0.4121, -0.8509]],

        [[-1.3338,  0.4926],
         [-1.1459, -0.3358],
         [-0.3888,  0.9792],
         [-0.8571,  0.2947]],

        [[ 0.6581,  0.5601],
         [ 2.8436,  0.1624],
         [-1.1534,  0.3686],
         [-0.2494,  1.5919]]])
dglsp.sddmm(A, X1, X2):
tensor([[[  0.0000,   0.0000],
         [  0.0000,   0.0000],
         [  0.0000,   0.0000],
         [  0.0000,   0.0000]],

        [[  0.0000,   0.0000],
         [  0.0000,   0.0000],
         [ -3.0162,  -0.0925],
         [ -1.1804,   0.7989]],

        [[  0.0000,   0.0000],
         [  0.0000,   0.0000],
         [  0.0000,   0.0000],
         [  2.5194, -13.4661]]])

Non-linear activation functions

Element-wise functions

Most activation functions are element-wise and can be further grouped into two categories:

Sparse-preserving functions, such as sin(), tanh(), sigmoid(), and relu(), which map zero to zero. You can apply them directly to the val tensor of the sparse matrix and then recreate a matrix with the same sparsity using val_like().

[24]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.randn(4)
A = dglsp.spmatrix(i, val)
print(A.to_dense())

print("Apply tanh.")
A_new = dglsp.val_like(A, torch.tanh(A.val))
print(A_new.to_dense())
tensor([[ 0.0000,  0.5613,  0.0000],
        [-0.1622,  0.0000,  0.0549],
        [-1.1691,  0.0000,  0.0000]])
Apply tanh.
tensor([[ 0.0000,  0.5089,  0.0000],
        [-0.1608,  0.0000,  0.0548],
        [-0.8240,  0.0000,  0.0000]])

Non-sparse-preserving functions, such as exp() and cos(), which map zero to non-zero values. Convert the sparse matrix to dense before applying them.

[25]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.randn(4)
A = dglsp.spmatrix(i, val)
print(A.to_dense())

print("Apply exp.")
A_new = A.to_dense().exp()
print(A_new)
tensor([[ 0.0000,  2.3190,  0.0000],
        [ 0.3391,  0.0000, -0.0727],
        [ 0.7010,  0.0000,  0.0000]])
Apply exp.
tensor([[ 1.0000, 10.1654,  1.0000],
        [ 1.4037,  1.0000,  0.9299],
        [ 2.0157,  1.0000,  1.0000]])

Softmax

Apply row-wise softmax to the nonzero entries of the sparse matrix.

[26]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)

print(A.softmax())
print("In dense format:")
print(A.softmax().to_dense())
print("\n")
SparseMatrix(indices=tensor([[0, 1, 1, 2],
                             [1, 0, 2, 0]]),
             values=tensor([1.0000, 0.2689, 0.7311, 1.0000]),
             shape=(3, 3), nnz=4)
In dense format:
tensor([[0.0000, 1.0000, 0.0000],
        [0.2689, 0.0000, 0.7311],
        [1.0000, 0.0000, 0.0000]])


Exercise #1

Let’s test what you’ve learned. Feel free to try it out in Colab.

Given a symmetric sparse adjacency matrix \(A\), calculate its symmetrically normalized adjacency matrix:

\[norm = \bar{D}^{-\frac{1}{2}}\bar{A}\bar{D}^{-\frac{1}{2}}\]

where \(\bar{A} = A + I\), \(I\) is the identity matrix, and \(\bar{D}\) is the diagonal node degree matrix of \(\bar{A}\).

[27]:
i = torch.tensor([[0, 0, 1, 1, 2, 2, 3],
                  [1, 3, 2, 5, 3, 5, 4]])
asym_A = dglsp.spmatrix(i, shape=(6, 6))
# Step 1: create symmetrical adjacency matrix A from asym_A.
# A =

# Step 2: calculate A_hat from A.
# A_hat =

# Step 3: diagonal node degree matrix of A_hat
# D_hat =

# Step 4: calculate the norm from D_hat and A_hat.
# norm =
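
If you want to check your work, here is one possible solution, sketched using only the operators introduced in this tutorial (your approach may differ):

[ ]:
# A possible solution sketch.
A = asym_A + asym_A.T               # Step 1: symmetrize asym_A.
I = dglsp.identity(A.shape)
A_hat = A + I                       # Step 2: A_hat = A + I.
D_hat = dglsp.diag(A_hat.sum(1))    # Step 3: diagonal degree matrix of A_hat.
D_hat_invsqrt = D_hat ** -0.5       # Step 4: symmetric normalization.
norm = D_hat_invsqrt @ A_hat @ D_hat_invsqrt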

Exercise #2

Let’s implement a simplified version of the Graph Attention Network (GAT) layer.

A GAT layer has two inputs: the adjacency matrix \(A\) and the node input features \(X\). The idea of the GAT layer is to update each node’s representation with a weighted average of the node’s own representation and its neighbors’ representations. In particular, when computing the output for node \(i\), the GAT layer does the following:

  1. Compute the scores \(S_{ij}\) representing the attention logit from neighbor \(j\) to node \(i\). \(S_{ij}\) is a function of \(i\)’s and \(j\)’s input features \(X_i\) and \(X_j\):

\[S_{ij} = LeakyReLU(X_i^\top v_1 + X_j^\top v_2)\]

where \(v_1\) and \(v_2\) are trainable vectors.

  2. Compute a softmax attention \(R_{ij} = \exp(S_{ij}) / \sum_{j' \in \mathcal{N}_i} \exp(S_{ij'})\), where \(\mathcal{N}_i\) denotes the neighbors of node \(i\). This means that \(R\) is a row-wise softmax attention of \(S\).

  3. Compute the weighted average \(H_i = \sum_{j \in \mathcal{N}_i} R_{ij} X_j W\), where \(W\) is a trainable matrix.

The following code defines all the parameters you need but only completes step 1. Could you implement steps 2 and 3?

[28]:
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedGAT(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()

        self.W = nn.Parameter(torch.randn(in_size, out_size))
        self.v1 = nn.Parameter(torch.randn(in_size))
        self.v2 = nn.Parameter(torch.randn(in_size))

    def forward(self, A, X):
        # A: A sparse matrix with size (N, N). A[i, j] represents the edge from j to i.
        # X: A dense matrix with size (N, D)
        # Step 1: compute S[i, j]
        Xv1 = X @ self.v1
        Xv2 = X @ self.v2
        s = F.leaky_relu(Xv1[A.col] + Xv2[A.row])
        S = dglsp.val_like(A, s)

        # Step 2: compute R[i, j], which is the row-wise softmax attention of S.
        # EXERCISE: replace the statement below.
        R = S

        # Step 3: compute H.
        # EXERCISE: replace the statement below.
        H = X

        return H
[29]:
# Test:
# Let's use the symmetric A created above.
X = torch.randn(6, 20)
module = SimplifiedGAT(20, 10)
Y = module(A, X)
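
If you get stuck, here is one possible completion of steps 2 and 3, sketched with the row-wise softmax() and matrix multiplication operators from this tutorial. These two lines would replace the EXERCISE statements inside forward():

[ ]:
# Possible replacements for the EXERCISE statements in forward().
# Step 2: R is the row-wise softmax of S over its non-zero entries.
R = S.softmax()
# Step 3: weighted average of neighbor features, then a linear projection.
H = R @ X @ self.W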