Building a Graph Convolutional Network Using Sparse Matrices

This tutorial illustrates step by step how to write and train a Graph Convolutional Network (Kipf et al., 2017) using DGL’s sparse matrix APIs.


[ ]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Uncomment below to install required packages. If your CUDA version is not 11.6,
# check https://www.dgl.ai/pages/start.html to find the supported CUDA
# versions and the corresponding command to install DGL.
#!pip install dgl -f https://data.dgl.ai/wheels/cu116/repo.html > /dev/null

try:
    import dgl
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "DGL not found!")
DGL installed!

Graph Convolutional Layer

Mathematically, the graph convolutional layer is defined as:

\[f(X^{(l)}, A) = \sigma(\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}}X^{(l)}W^{(l)})\]

where \(\hat{A} = A + I\) is the adjacency matrix \(A\) augmented with self-loops by adding the identity matrix \(I\), \(\hat{D}\) is the diagonal degree matrix of \(\hat{A}\) (that is, \(\hat{D}_{ii} = \sum_j \hat{A}_{ij}\)), \(X^{(l)}\) is the node feature matrix at layer \(l\), \(W^{(l)}\) is a trainable weight matrix, and \(\sigma\) is a non-linear activation (e.g., ReLU).
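As a worked example of the formula, the following dependency-free sketch computes \(\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}}\) densely for a hypothetical 3-node path graph with edges 0-1 and 1-2. The helper name normalized_adjacency is illustrative only and is not part of DGL. Each entry works out to \(\hat{A}_{ij}/\sqrt{\hat{d}_i\hat{d}_j}\), where \(\hat{d}_i\) is the degree of node \(i\) in \(\hat{A}\).

```python
import math

def normalized_adjacency(A):
    """Return D_hat^{-1/2} (A + I) D_hat^{-1/2} for a dense square matrix A (list of lists)."""
    n = len(A)
    # A_hat = A + I: add a self-loop to every node.
    A_hat = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # Degree of each node in the augmented graph (row sums of A_hat).
    deg = [sum(row) for row in A_hat]
    inv_sqrt = [1.0 / math.sqrt(d) for d in deg]
    # Scaling by a diagonal matrix on both sides multiplies entry (i, j)
    # by inv_sqrt[i] * inv_sqrt[j].
    return [[inv_sqrt[i] * A_hat[i][j] * inv_sqrt[j] for j in range(n)] for i in range(n)]

# Hypothetical path graph 0 - 1 - 2, each undirected edge stored in both directions.
A = [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]]
S = normalized_adjacency(A)
for row in S:
    print([round(x, 3) for x in row])
# [0.5, 0.408, 0.0]
# [0.408, 0.333, 0.408]
# [0.0, 0.408, 0.5]
```

Node 1 has augmented degree 3 and nodes 0 and 2 have augmented degree 2, so the weight between nodes 0 and 1 is \(1/\sqrt{6} \approx 0.408\). The symmetric normalization keeps the matrix symmetric, but unlike \(\hat{D}^{-1}\hat{A}\) its rows do not generally sum to 1.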

The code below shows how to implement it using the dgl.sparse package. The core operations are:

  • dgl.sparse.identity creates the identity matrix \(I\).

  • The augmented adjacency matrix \(\hat{A}\) is then computed by adding the identity matrix to the adjacency matrix \(A\).

  • A_hat.sum(0) sums the augmented adjacency matrix \(\hat{A}\) along the first dimension, which gives the degree vector of the augmented graph (for an undirected graph \(\hat{A}\) is symmetric, so column sums and row sums coincide). The diagonal degree matrix \(\hat{D}\) is then created by dgl.sparse.diag.

  • D_hat ** -0.5 computes \(\hat{D}^{-\frac{1}{2}}\) by raising each diagonal entry to the power \(-\frac{1}{2}\).

  • D_hat_invsqrt @ A_hat @ D_hat_invsqrt computes the normalized convolution matrix, which is then multiplied by the linearly transformed node features self.W(X).
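The point of a sparse representation is that the normalized matrix never has to be materialized as a dense \(N \times N\) array: only the nonzero entries of \(\hat{A}\) are stored, and each stored value \(\hat{A}_{ij}\) is rescaled by \(\hat{d}_i^{-\frac{1}{2}}\hat{d}_j^{-\frac{1}{2}}\). The stdlib sketch below mirrors the bullet steps on a COO list of (row, column) pairs with unit edge weights. It only illustrates the arithmetic; gcn_norm_coo is a hypothetical name, and this is not how dgl.sparse is implemented internally.

```python
from collections import defaultdict

def gcn_norm_coo(rows, cols, num_nodes):
    """Symmetrically normalize A + I, with A given in COO form (unit edge weights).

    Returns a dict mapping (i, j) -> normalized value for every nonzero of A_hat.
    """
    # A_hat = A + I; duplicate (i, j) pairs are accumulated.
    A_hat = defaultdict(float)
    for i, j in zip(rows, cols):
        A_hat[(i, j)] += 1.0
    for i in range(num_nodes):
        A_hat[(i, i)] += 1.0
    # Degree vector: sum over the first dimension (column sums), like A_hat.sum(0).
    deg = [0.0] * num_nodes
    for (i, j), v in A_hat.items():
        deg[j] += v
    # D_hat^{-1/2}, kept as a vector of diagonal entries.
    inv_sqrt = [d ** -0.5 for d in deg]
    # D_hat^{-1/2} A_hat D_hat^{-1/2} only rescales the stored entries.
    return {(i, j): inv_sqrt[i] * v * inv_sqrt[j] for (i, j), v in A_hat.items()}

# Hypothetical 3-node path graph, edges stored in both directions.
rows = [0, 1, 1, 2]
cols = [1, 0, 2, 1]
S = gcn_norm_coo(rows, cols, 3)
print(len(S))               # 7 (4 edges + 3 self-loops)
print(round(S[(0, 1)], 3))  # 0.408
```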

[ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.sparse as dglsp

class GCNLayer(nn.Module):
    def __init__(self, in_size, out_size):
        super(GCNLayer, self).__init__()
        self.W = nn.Linear(in_size, out_size)

    def forward(self, A, X):
        ########################################################################
        # (HIGHLIGHT) Compute the symmetrically normalized adjacency matrix with
        # Sparse Matrix API
        ########################################################################
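        I = dglsp.identity(A.shape)          # Identity matrix I.
        A_hat = A + I                        # Add self-loops: A_hat = A + I.
        D_hat = dglsp.diag(A_hat.sum(0))     # Diagonal degree matrix of A_hat.
        D_hat_invsqrt = D_hat ** -0.5        # D_hat^{-1/2}, applied to the diagonal entries.
        # Normalize A_hat and propagate the linearly transformed features.
        return D_hat_invsqrt @ A_hat @ D_hat_invsqrt @ self.W(X)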
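Because matrix multiplication is associative, \(\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}}(XW)\) equals \((\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}}X)W\), and forward applies self.W first. On Cora the first layer maps 1433 input features to 16 hidden features, so transforming first means the graph propagation works on 16-dimensional rather than 1433-dimensional rows. The dense plain-Python sketch below, using hypothetical helper names and hypothetical values, checks that both orderings agree. One detail it leaves out: nn.Linear includes a bias by default, so self.W(X) actually computes \(XW + b\), and that bias is propagated along with the features.

```python
def matmul(A, B):
    """Dense matrix product of two lists of lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def gcn_layer(S, X, W):
    # Transform features first (N x in -> N x out), then propagate over the graph.
    return matmul(S, matmul(X, W))

# Hypothetical normalized matrix S (3 nodes), features X (4 dims), weights W (4 -> 2).
S = [[0.5, 0.4, 0.0],
     [0.4, 0.3, 0.4],
     [0.0, 0.4, 0.5]]
X = [[1, 0, 2, 0],
     [0, 1, 0, 1],
     [3, 0, 0, 1]]
W = [[1, 0],
     [0, 1],
     [1, 1],
     [0, 2]]
H1 = gcn_layer(S, X, W)
H2 = matmul(matmul(S, X), W)  # Same result, but propagates the wider 4-dim features.
print(len(H1), len(H1[0]))    # 3 2
```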

A Graph Convolutional Network is then defined by stacking this layer.

[ ]:
# A two-layer GCN built by stacking GCNLayer modules.
class GCN(nn.Module):
    def __init__(self, in_size, out_size, hidden_size):
        super(GCN, self).__init__()
        self.conv1 = GCNLayer(in_size, hidden_size)
        self.conv2 = GCNLayer(hidden_size, out_size)

    def forward(self, A, X):
        X = self.conv1(A, X)
        X = F.relu(X)
        return self.conv2(A, X)
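Stacking layers widens each node's receptive field. Every layer multiplies by the normalized \(\hat{A}\) once, so after two layers a node's output can depend on nodes up to two hops away. The sketch below traces this on a hypothetical 4-node path graph using a hypothetical helper. It tracks only which nodes can contribute (the nonzero pattern), not the actual weights.

```python
def reachable_after(adj, start, num_layers):
    """Nodes whose input features can influence `start` after `num_layers` GCN layers.

    Each layer multiplies by A_hat = A + I, so it adds the 1-hop neighbours of the
    current set, and the self-loop keeps the nodes already in it.
    """
    frontier = {start}
    for _ in range(num_layers):
        frontier = frontier | {j for i in frontier for j in adj[i]}
    return frontier

# Hypothetical path graph 0 - 1 - 2 - 3 as an adjacency list.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
print(sorted(reachable_after(adj, 0, 1)))  # [0, 1]
print(sorted(reachable_after(adj, 0, 2)))  # [0, 1, 2]
```

GCN.forward applies F.relu only between the two layers. The second layer's output is returned as raw logits, which is the form nn.CrossEntropyLoss expects during training.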

Training the GCN

We then train the GCN model on the Cora dataset for node classification. Since the model expects an adjacency matrix as its first argument, we first build the adjacency matrix from the graph's edges using the dgl.sparse.spmatrix API, which returns a DGL SparseMatrix object.
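torch.stack(g.edges()) produces an index tensor of shape (2, nnz): the first row holds source (row) indices and the second row holds destination (column) indices, which is the layout dgl.sparse.spmatrix accepts. When no values are passed, every stored entry is 1. The symmetric normalization above is meant for undirected graphs, so it can be worth checking that each edge is stored in both directions. The stdlib sketch below performs that check on hypothetical edge lists. to_coo and is_symmetric are illustrative names, not DGL functions.

```python
def to_coo(src, dst):
    """Stack source and destination lists into the (2, nnz) layout: [rows, cols]."""
    assert len(src) == len(dst)
    return [list(src), list(dst)]

def is_symmetric(indices):
    """True if every stored (i, j) pair also appears as (j, i)."""
    pairs = set(zip(indices[0], indices[1]))
    return all((j, i) in pairs for i, j in pairs)

# Hypothetical edge lists in the same (src, dst) form that g.edges() returns.
undirected = to_coo([0, 1, 1, 2], [1, 0, 2, 1])
directed = to_coo([0, 1], [1, 2])
print(is_symmetric(undirected))  # True
print(is_symmetric(directed))    # False
```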

[ ]:
def evaluate(g, pred):
    label = g.ndata["label"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    # Compute accuracy on validation/test set.
    val_acc = (pred[val_mask] == label[val_mask]).float().mean()
    test_acc = (pred[test_mask] == label[test_mask]).float().mean()
    return val_acc, test_acc

def train(model, g):
    features = g.ndata["feat"]
    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    loss_fcn = nn.CrossEntropyLoss()

    # Preprocess to get the adjacency matrix of the graph.
    indices = torch.stack(g.edges())
    N = g.num_nodes()
    A = dglsp.spmatrix(indices, shape=(N, N))

    for epoch in range(100):
        model.train()

        # Forward.
        logits = model(A, features)

        # Compute loss with nodes in the training set.
        loss = loss_fcn(logits[train_mask], label[train_mask])

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute prediction.
        pred = logits.argmax(dim=1)

        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        if epoch % 5 == 0:
            print(
                f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}"
                f", test acc: {test_acc:.3f}"
            )


# Load graph from the existing dataset.
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

# Create model.
feature = g.ndata['feat']
in_size = feature.shape[1]
out_size = dataset.num_classes
gcn_model = GCN(in_size, out_size, 16)

# Kick off training.
train(gcn_model, g)
Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /root/.dgl/cora_v2
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.
In epoch 0, loss: 1.954, val acc: 0.114, test acc: 0.103
In epoch 5, loss: 1.921, val acc: 0.158, test acc: 0.147
In epoch 10, loss: 1.878, val acc: 0.288, test acc: 0.283
In epoch 15, loss: 1.822, val acc: 0.344, test acc: 0.353
In epoch 20, loss: 1.751, val acc: 0.388, test acc: 0.389
In epoch 25, loss: 1.663, val acc: 0.406, test acc: 0.410
In epoch 30, loss: 1.562, val acc: 0.472, test acc: 0.481
In epoch 35, loss: 1.450, val acc: 0.558, test acc: 0.573
In epoch 40, loss: 1.333, val acc: 0.636, test acc: 0.641
In epoch 45, loss: 1.216, val acc: 0.684, test acc: 0.683
In epoch 50, loss: 1.102, val acc: 0.726, test acc: 0.713
In epoch 55, loss: 0.996, val acc: 0.740, test acc: 0.740
In epoch 60, loss: 0.899, val acc: 0.754, test acc: 0.760
In epoch 65, loss: 0.813, val acc: 0.762, test acc: 0.771
In epoch 70, loss: 0.737, val acc: 0.768, test acc: 0.781
In epoch 75, loss: 0.671, val acc: 0.776, test acc: 0.786
In epoch 80, loss: 0.614, val acc: 0.784, test acc: 0.790
In epoch 85, loss: 0.566, val acc: 0.780, test acc: 0.788
In epoch 90, loss: 0.524, val acc: 0.780, test acc: 0.791
In epoch 95, loss: 0.489, val acc: 0.772, test acc: 0.795
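In evaluate, each boolean mask selects the nodes whose predictions are scored. The reported validation and test accuracies are therefore averages over the 500 validation and 1000 test nodes listed in the dataset summary, while the loss in train uses only the 140 training nodes. The following stdlib sketch shows the same masked-accuracy arithmetic on hypothetical predictions and labels. masked_accuracy is an illustrative name.

```python
def masked_accuracy(pred, label, mask):
    """Fraction of masked positions where pred == label.

    Mirrors (pred[mask] == label[mask]).float().mean() in evaluate.
    """
    selected = [p == y for p, y, m in zip(pred, label, mask) if m]
    return sum(selected) / len(selected)

# Hypothetical predictions, labels, and masks for 6 nodes.
pred = [2, 0, 1, 1, 3, 0]
label = [2, 1, 1, 0, 3, 0]
val_mask = [False, True, True, True, False, False]
test_mask = [False, False, False, False, True, True]
print(masked_accuracy(pred, label, val_mask))   # 0.3333333333333333
print(masked_accuracy(pred, label, test_mask))  # 1.0
```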

Check out the full example script here.