Graph Convolutional Network

Author: Qi Huang, Minjie Wang, Yu Gai, Quan Gan, Zheng Zhang

Warning

The tutorial aims at gaining insights into the paper, with code as a means of explanation. The implementation is thus NOT optimized for running efficiency. For a recommended implementation, please refer to the official examples.

This is a gentle introduction to using DGL to implement Graph Convolutional Networks (Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks). We explain what is under the hood of the GraphConv module. The reader is expected to learn how to define a new GNN layer using DGL’s message passing APIs.

Model Overview

GCN from the perspective of message passing

We describe a layer of a graph convolutional network from a message passing perspective; the math can be found here. It boils down to the following steps for each node \(u\):

1) Aggregate the neighbors’ representations \(h_{v}\) to produce an intermediate representation \(\hat{h}_u\).
2) Transform the aggregated representation \(\hat{h}_{u}\) with a linear projection followed by a non-linearity: \(h_{u} = f(W \hat{h}_u)\), where the weight matrix \(W\) is shared by all nodes.
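The two steps can be traced by hand for every node of a toy graph. The plain-Python sketch below is an illustration only, not DGL code: the graph, the feature values, the shared weight matrix `W`, and the helpers `aggregate` and `gcn_update` are hypothetical, and ReLU is assumed as the non-linearity \(f\).

```python
# Toy undirected graph as adjacency lists: node -> list of neighbors.
neighbors = {0: [1, 2], 1: [0], 2: [0]}

# Hypothetical 2-dimensional input representations h_v for each node.
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 3.0]}

# Hypothetical weight matrix W (2 outputs x 2 inputs), shared by all nodes.
W = [[1.0, -1.0],
     [0.5, 0.5]]


def aggregate(u):
    """Step 1: sum the neighbors' representations h_v into hat_h_u."""
    hat_h = [0.0] * len(h[u])
    for v in neighbors[u]:
        hat_h = [a + b for a, b in zip(hat_h, h[v])]
    return hat_h


def gcn_update(u):
    """Step 2: linear projection W @ hat_h_u followed by ReLU as f."""
    hat_h = aggregate(u)
    projected = [sum(w * x for w, x in zip(row, hat_h)) for row in W]
    return [max(0.0, z) for z in projected]


for u in neighbors:
    print(u, aggregate(u), gcn_update(u))
```

Node \(u\)’s own representation does not enter \(\hat{h}_u\) unless the graph contains an edge from \(u\) to itself; the training code later in this tutorial adds such self-loops for that reason.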

We will implement step 1 with DGL message passing, and step 2 with a PyTorch nn.Module.

GCN implementation with DGL

We first define the message and reduce functions as usual. Since the aggregation on a node \(u\) only involves summing over the neighbors’ representations \(h_v\), we can simply use builtin functions:

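import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F

# Message function: each edge copies its source node's feature 'h' into message 'm'.
gcn_msg = fn.copy_u(u='h', out='m')
# Reduce function: each node sums its incoming messages 'm' into its new feature 'h'.
gcn_reduce = fn.sum(msg='m', out='h')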
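To make the semantics of these builtins concrete, the sketch below emulates in plain Python what `g.update_all(gcn_msg, gcn_reduce)` computes. It is not how DGL executes builtins, and the edge list and the helper `update_all_copy_u_sum` are hypothetical. Messages travel along directed edges from source to destination, so each node ends up with the sum of its in-neighbors’ features; an undirected graph stores each edge in both directions.

```python
from collections import defaultdict

# Directed edges (src, dst); an undirected graph stores both directions.
edges = [(0, 1), (1, 0), (0, 2), (2, 0)]
# Node feature 'h' for each of the 4 nodes (node 3 has no edges).
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 3.0], [5.0, 5.0]]


def update_all_copy_u_sum(edges, h):
    # Message function, like fn.copy_u(u='h', out='m'): each edge carries
    # a copy of its source node's feature into the destination's mailbox.
    mailbox = defaultdict(list)
    for src, dst in edges:
        mailbox[dst].append(h[src])
    # Reduce function, like fn.sum(msg='m', out='h'): sum the mailbox.
    # In this sketch, a node with no incoming messages gets zeros.
    dim = len(h[0])
    new_h = []
    for node in range(len(h)):
        total = [0.0] * dim
        for m in mailbox[node]:
            total = [a + b for a, b in zip(total, m)]
        new_h.append(total)
    return new_h


print(update_all_copy_u_sum(edges, h))
```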

We then proceed to define the GCNLayer module. A GCNLayer essentially performs message passing on all the nodes and then applies a fully-connected layer.

Note

This section shows how to implement a GCN from scratch. DGL provides a more efficient built-in GCN layer module, GraphConv.

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        # Creating a local scope so that all the stored ndata and edata
        # (such as the `'h'` ndata below) are automatically popped out
        # when the scope exits.
        with g.local_scope():
            g.ndata['h'] = feature
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata['h']
            return self.linear(h)
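Putting both steps together, the layer computes \(\hat{H} = A^{\top} H\) followed by the affine map of `nn.Linear`, which stores its weight with shape `(out_feats, in_feats)` and computes \(\hat{H} W^{\top} + b\). Note that `GCNLayer` itself applies no non-linearity; the activation is left to the model that stacks the layers. The plain-Python sketch below assumes a dense adjacency matrix with `A[s][d] = 1` for an edge from `s` to `d`; the helper names and numbers are hypothetical.

```python
def matmul(X, Y):
    """Dense matrix product of nested lists X (n x k) and Y (k x m)."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(M):
    return [list(col) for col in zip(*M)]


def gcn_layer_forward(A, H, weight, bias):
    """Matrix form of GCNLayer.forward, assuming A[s][d] = 1 for an edge s -> d.

    Step 1 (update_all with copy_u + sum): aggregated = A^T H.
    Step 2 (nn.Linear convention): out = aggregated @ weight^T + bias,
    where weight has shape (out_feats, in_feats).
    """
    aggregated = matmul(transpose(A), H)
    out = matmul(aggregated, transpose(weight))
    return [[o + b for o, b in zip(row, bias)] for row in out]


# Hypothetical 3-node graph with edges 0<->1 and 0<->2 (both directions stored).
A = [[0, 1, 1],
     [1, 0, 0],
     [1, 0, 0]]
H = [[1.0, 0.0], [0.0, 1.0], [1.0, 3.0]]
weight = [[1.0, -1.0]]   # out_feats=1, in_feats=2
bias = [0.5]
print(gcn_layer_forward(A, H, weight, bias))
```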

The forward function is essentially the same as in any other commonly seen neural network model in PyTorch, and we can initialize a GCN like any nn.Module. For example, let’s define a simple neural network consisting of two GCN layers. Suppose we are training a classifier for the Cora dataset (the input feature size is 1433 and the number of classes is 7). The last GCN layer outputs one score (logit) per class for each node, so it does not apply an activation; the softmax is applied later, in the loss.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = GCNLayer(1433, 16)
        self.layer2 = GCNLayer(16, 7)

    def forward(self, g, features):
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)
        return x
net = Net()
print(net)

Out:

Net(
  (layer1): GCNLayer(
    (linear): Linear(in_features=1433, out_features=16, bias=True)
  )
  (layer2): GCNLayer(
    (linear): Linear(in_features=16, out_features=7, bias=True)
  )
)
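As a quick sanity check on these sizes, the following arithmetic counts the learnable parameters of Net and traces the shape of the node representations. It only relies on `nn.Linear(in_feats, out_feats)` holding `in_feats * out_feats` weights plus `out_feats` biases (`bias=True`, as the printout shows); the node count 2708 is taken from the Cora loading output later in this tutorial.

```python
# Layer sizes used by Net: 1433 input features -> 16 hidden -> 7 classes.
sizes = [1433, 16, 7]
num_nodes = 2708  # number of nodes in Cora, as reported when the dataset is loaded

# Each GCNLayer wraps nn.Linear(in_feats, out_feats) with bias=True, so it
# holds in_feats * out_feats weights plus out_feats biases.
params_per_layer = [i * o + o for i, o in zip(sizes[:-1], sizes[1:])]
print(params_per_layer)        # [22944, 119]
print(sum(params_per_layer))   # 23063

# Node representation shapes flowing through the network: N x 1433 -> N x 16 -> N x 7.
shapes = [(num_nodes, d) for d in sizes]
print(shapes)  # [(2708, 1433), (2708, 16), (2708, 7)]
```

With PyTorch available, `sum(p.numel() for p in net.parameters())` should report the same total.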

We load the Cora dataset using DGL’s built-in data module.

from dgl.data import CoraGraphDataset
def load_cora_data():
    dataset = CoraGraphDataset()
    g = dataset[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    test_mask = g.ndata['test_mask']
    return g, features, labels, train_mask, test_mask

Once the model is trained, we can use the following function to evaluate its accuracy on the test set:

def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
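The same computation in plain Python makes the masking and argmax explicit: only nodes whose mask entry is `True` are scored, and the predicted class is the index of the largest logit. The helper `masked_accuracy` and the toy values are hypothetical, and the example avoids tied logits so that the argmax is unambiguous.

```python
def masked_accuracy(logits, labels, mask):
    """Accuracy over the nodes whose mask entry is True (plain-Python analogue of evaluate)."""
    # Keep only the rows selected by the boolean mask, like logits[mask].
    selected = [(row, y) for row, y, m in zip(logits, labels, mask) if m]
    # argmax over classes, like the indices returned by th.max(logits, dim=1).
    correct = sum(1 for row, y in selected
                  if max(range(len(row)), key=row.__getitem__) == y)
    return correct / len(selected)


logits = [[2.0, 0.1, -1.0],   # predicts class 0
          [0.3, 0.2, 0.9],    # predicts class 2
          [1.5, 2.5, 0.0],    # predicts class 1
          [0.0, 0.0, 4.0]]    # predicts class 2
labels = [0, 1, 1, 2]
mask = [True, True, False, True]   # node 2 is excluded from evaluation
print(masked_accuracy(logits, labels, mask))
```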

We then train the network as follows:

import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()
# Add edges between each node and itself to preserve old node representations
g.add_edges(g.nodes(), g.nodes())
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
    if epoch >=3:
        t0 = time.time()

    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >=3:
        dur.append(time.time() - t0)

    acc = evaluate(net, g, features, labels, test_mask)
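    # Report nan until timed epochs exist (avoids np.mean of an empty list).
    mean_dur = np.mean(dur) if dur else float('nan')
    print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, mean_dur))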

Out:

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch 00000 | Loss 1.9731 | Test Acc 0.3550 | Time(s) nan
Epoch 00001 | Loss 1.8475 | Test Acc 0.4060 | Time(s) nan
Epoch 00002 | Loss 1.7285 | Test Acc 0.5510 | Time(s) nan
Epoch 00003 | Loss 1.6094 | Test Acc 0.6480 | Time(s) 0.0139
Epoch 00004 | Loss 1.4938 | Test Acc 0.7100 | Time(s) 0.0139
Epoch 00005 | Loss 1.3797 | Test Acc 0.7280 | Time(s) 0.0150
Epoch 00006 | Loss 1.2679 | Test Acc 0.7430 | Time(s) 0.0148
Epoch 00007 | Loss 1.1637 | Test Acc 0.7370 | Time(s) 0.0147
Epoch 00008 | Loss 1.0670 | Test Acc 0.7410 | Time(s) 0.0152
Epoch 00009 | Loss 0.9761 | Test Acc 0.7430 | Time(s) 0.0155
Epoch 00010 | Loss 0.8908 | Test Acc 0.7450 | Time(s) 0.0158
Epoch 00011 | Loss 0.8130 | Test Acc 0.7480 | Time(s) 0.0157
Epoch 00012 | Loss 0.7421 | Test Acc 0.7520 | Time(s) 0.0158
Epoch 00013 | Loss 0.6772 | Test Acc 0.7570 | Time(s) 0.0159
Epoch 00014 | Loss 0.6179 | Test Acc 0.7590 | Time(s) 0.0160
Epoch 00015 | Loss 0.5635 | Test Acc 0.7610 | Time(s) 0.0160
Epoch 00016 | Loss 0.5133 | Test Acc 0.7610 | Time(s) 0.0162
Epoch 00017 | Loss 0.4668 | Test Acc 0.7590 | Time(s) 0.0160
Epoch 00018 | Loss 0.4240 | Test Acc 0.7620 | Time(s) 0.0161
Epoch 00019 | Loss 0.3855 | Test Acc 0.7630 | Time(s) 0.0160
Epoch 00020 | Loss 0.3505 | Test Acc 0.7590 | Time(s) 0.0159
Epoch 00021 | Loss 0.3186 | Test Acc 0.7590 | Time(s) 0.0159
Epoch 00022 | Loss 0.2897 | Test Acc 0.7580 | Time(s) 0.0160
Epoch 00023 | Loss 0.2636 | Test Acc 0.7530 | Time(s) 0.0159
Epoch 00024 | Loss 0.2397 | Test Acc 0.7560 | Time(s) 0.0158
Epoch 00025 | Loss 0.2182 | Test Acc 0.7540 | Time(s) 0.0159
Epoch 00026 | Loss 0.1987 | Test Acc 0.7550 | Time(s) 0.0159
Epoch 00027 | Loss 0.1807 | Test Acc 0.7540 | Time(s) 0.0159
Epoch 00028 | Loss 0.1643 | Test Acc 0.7540 | Time(s) 0.0158
Epoch 00029 | Loss 0.1493 | Test Acc 0.7540 | Time(s) 0.0159
Epoch 00030 | Loss 0.1356 | Test Acc 0.7530 | Time(s) 0.0158
Epoch 00031 | Loss 0.1233 | Test Acc 0.7540 | Time(s) 0.0158
Epoch 00032 | Loss 0.1120 | Test Acc 0.7560 | Time(s) 0.0158
Epoch 00033 | Loss 0.1019 | Test Acc 0.7560 | Time(s) 0.0157
Epoch 00034 | Loss 0.0928 | Test Acc 0.7570 | Time(s) 0.0158
Epoch 00035 | Loss 0.0845 | Test Acc 0.7580 | Time(s) 0.0158
Epoch 00036 | Loss 0.0771 | Test Acc 0.7590 | Time(s) 0.0159
Epoch 00037 | Loss 0.0704 | Test Acc 0.7580 | Time(s) 0.0158
Epoch 00038 | Loss 0.0644 | Test Acc 0.7570 | Time(s) 0.0159
Epoch 00039 | Loss 0.0590 | Test Acc 0.7580 | Time(s) 0.0158
Epoch 00040 | Loss 0.0542 | Test Acc 0.7580 | Time(s) 0.0158
Epoch 00041 | Loss 0.0498 | Test Acc 0.7560 | Time(s) 0.0158
Epoch 00042 | Loss 0.0458 | Test Acc 0.7570 | Time(s) 0.0158
Epoch 00043 | Loss 0.0422 | Test Acc 0.7580 | Time(s) 0.0157
Epoch 00044 | Loss 0.0390 | Test Acc 0.7550 | Time(s) 0.0158
Epoch 00045 | Loss 0.0361 | Test Acc 0.7530 | Time(s) 0.0157
Epoch 00046 | Loss 0.0334 | Test Acc 0.7520 | Time(s) 0.0157
Epoch 00047 | Loss 0.0310 | Test Acc 0.7510 | Time(s) 0.0157
Epoch 00048 | Loss 0.0289 | Test Acc 0.7510 | Time(s) 0.0157
Epoch 00049 | Loss 0.0269 | Test Acc 0.7510 | Time(s) 0.0157
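The loss line `F.nll_loss(logp[train_mask], labels[train_mask])` with `logp = F.log_softmax(logits, 1)` averages \(-\log p(\text{label})\) over the training nodes only. The plain-Python sketch below mirrors that computation with hypothetical helpers `log_softmax` and `masked_nll`; it also shows that a uniform prediction over 7 classes yields a loss of \(\ln 7\).

```python
import math


def log_softmax(row):
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def masked_nll(logits, labels, mask):
    """Mean of -log p(label) over masked nodes, mirroring
    F.nll_loss(F.log_softmax(logits, 1)[mask], labels[mask])."""
    losses = [-log_softmax(row)[y] for row, y, m in zip(logits, labels, mask) if m]
    return sum(losses) / len(losses)


# With all-equal logits over 7 classes the prediction is uniform,
# so the loss is ln(7) regardless of the labels.
uniform = [[0.0] * 7 for _ in range(3)]
print(masked_nll(uniform, [0, 3, 6], [True, True, False]))  # ln 7, about 1.9459
print(math.log(7))
```

For comparison, the first recorded loss above (1.9731 at epoch 0) is close to \(\ln 7 \approx 1.946\), which is consistent with a freshly initialized network making nearly uniform predictions over the 7 Cora classes.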

GCN in one formula

Mathematically, the GCN model follows this formula:

\(H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})\)

Here, \(H^{(l)}\) denotes the node representations at the \(l^{th}\) layer of the network, \(\sigma\) is the non-linearity, and \(W^{(l)}\) is the weight matrix for this layer. \(\tilde{D}\) and \(\tilde{A}\) are the degree and adjacency matrices of the graph, respectively. The tilde indicates the variant where an additional edge is added between each node and itself, so that a node’s old representation is preserved in the graph convolution; that is, \(\tilde{A} = A + I\) and \(\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}\). The input \(H^{(0)}\) has shape \(N \times D\), where \(N\) is the number of nodes and \(D\) is the number of input features. Chaining multiple such layers produces a node-level output with shape \(N \times F\), where \(F\) is the dimension of the output node feature vector.
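For a concrete feel for the propagation matrix \(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}\), the plain-Python sketch below builds it for a hypothetical three-node path graph, using \(\tilde{A} = A + I\) and the row sums of \(\tilde{A}\) as degrees. The helper `normalized_adjacency` is illustrative, not part of DGL.

```python
import math


def normalized_adjacency(A):
    """Compute D~^{-1/2} A~ D~^{-1/2} for a dense symmetric 0/1 adjacency A.

    A~ = A + I adds a self-loop to every node, and D~ is the diagonal
    degree matrix of A~ (row sums).
    """
    n = len(A)
    A_tilde = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in A_tilde]
    return [[A_tilde[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
            for i in range(n)]


# Hypothetical path graph 0 - 1 - 2.
A = [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]]
for row in normalized_adjacency(A):
    print([round(x, 4) for x in row])
```

With self-loops the degrees are 2, 3 and 2, so entry \((0, 1)\) is \(1/\sqrt{2 \cdot 3} \approx 0.4082\). Each layer then multiplies \(H^{(l)}W^{(l)}\) by this fixed matrix before applying \(\sigma\).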

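The equation can be implemented efficiently using sparse matrix multiplication kernels (as in Kipf’s pygcn code). The DGL implementation above already benefits from this: because update_all is called with builtin message and reduce functions, DGL can run the aggregation as a fused sparse operation instead of materializing a message for every edge.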
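To see why sparse kernels help, here is a minimal compressed sparse row (CSR) sketch of multiplying an adjacency matrix by a dense feature matrix: the inner loop visits only the stored nonzeros (the edges) rather than all \(N^2\) entries. It illustrates the general technique only; it is not DGL’s or pygcn’s kernel, and `to_csr` and `csr_spmm` are hypothetical helpers.

```python
def to_csr(dense):
    """Convert a dense matrix (nested lists) to CSR arrays (indptr, indices, data)."""
    indptr, indices, data = [0], [], []
    for row in dense:
        for j, value in enumerate(row):
            if value != 0:
                indices.append(j)
                data.append(value)
        indptr.append(len(indices))
    return indptr, indices, data


def csr_spmm(indptr, indices, data, H):
    """Compute A @ H where A is given in CSR form and H is dense (N x F)."""
    out = []
    for i in range(len(indptr) - 1):
        acc = [0.0] * len(H[0])
        # Only the nonzeros of row i (the neighbors of node i) are visited.
        for k in range(indptr[i], indptr[i + 1]):
            j, a = indices[k], data[k]
            acc = [x + a * h for x, h in zip(acc, H[j])]
        out.append(acc)
    return out


# A~ for the path graph 0 - 1 - 2 with self-loops (hypothetical example).
A_tilde = [[1, 1, 0],
           [1, 1, 1],
           [0, 1, 1]]
H = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
indptr, indices, data = to_csr(A_tilde)
print(indptr, indices)         # [0, 2, 5, 7] [0, 1, 0, 1, 2, 1, 2]
print(csr_spmm(indptr, indices, data, H))
```

For a symmetric graph with self-loops, this computes the \(\tilde{A}H\) product that the simplified GCN in this tutorial uses before its linear layer.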

Note that the tutorial code implements a simplified version of GCN where we replace \(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}\) with \(\tilde{A}\). For a full implementation, see our example here.
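The full normalization can still be expressed with the same message passing pattern. Since entry \((i, j)\) of \(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}\) is \(\tilde{A}_{ij}/\sqrt{\tilde{d}_i \tilde{d}_j}\), one can scale each node \(j\)’s features by \(\tilde{d}_j^{-1/2}\) before summing and then scale the sum at node \(i\) by \(\tilde{d}_i^{-1/2}\); this is the same idea as the `norm='both'` option of DGL’s GraphConv. The plain-Python sketch below checks the equivalence on a hypothetical undirected graph with self-loops already added; the helper names are illustrative, not DGL’s own code.

```python
import math

# Hypothetical path graph 0 - 1 - 2, with both edge directions and self-loops.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 0), (1, 1), (2, 2)]
H = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
n, dim = len(H), len(H[0])

# Degree of each node in A~ (self-loops included); for an undirected
# graph the in-degree and out-degree coincide.
deg = [0] * n
for _, dst in edges:
    deg[dst] += 1


def normalized_message_passing(edges, H):
    # Scale source features by deg^{-1/2}, sum along edges, scale by deg^{-1/2}.
    scaled = [[x / math.sqrt(deg[j]) for x in H[j]] for j in range(n)]
    agg = [[0.0] * dim for _ in range(n)]
    for src, dst in edges:
        agg[dst] = [a + s for a, s in zip(agg[dst], scaled[src])]
    return [[a / math.sqrt(deg[i]) for a in agg[i]] for i in range(n)]


def dense_normalized_product(edges, H):
    # Reference: build D~^{-1/2} A~ D~^{-1/2} explicitly and multiply by H.
    S = [[0.0] * n for _ in range(n)]
    for src, dst in edges:
        S[dst][src] += 1.0 / math.sqrt(deg[dst] * deg[src])
    return [[sum(S[i][j] * H[j][k] for j in range(n)) for k in range(dim)]
            for i in range(n)]


mp = normalized_message_passing(edges, H)
ref = dense_normalized_product(edges, H)
print(all(abs(a - b) < 1e-12 for r1, r2 in zip(mp, ref) for a, b in zip(r1, r2)))  # True
```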
