Graph Convolutional Network

Authors: Qi Huang, Minjie Wang, Yu Gai, Quan Gan, Zheng Zhang

This is a gentle introduction to using DGL to implement Graph Convolutional Networks (Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks). We build upon the earlier tutorial on DGLGraph and demonstrate how DGL combines graphs with deep neural networks to learn structural representations.

Model Overview

GCN from the perspective of message passing

We describe a layer of the graph convolutional network from a message passing perspective; the math can be found here. It boils down to the following steps for each node \(u\):

1) Aggregate the neighbors’ representations \(h_{v}\) to produce an intermediate representation \(\hat{h}_u\).

2) Transform the aggregated representation \(\hat{h}_{u}\) with a linear projection followed by a non-linearity: \(h_{u} = f(W_{u} \hat{h}_u)\).

We will implement step 1 with DGL message passing, and step 2 with the apply_nodes method, whose node UDF will be a PyTorch nn.Module.

GCN implementation with DGL

We first define the message and reduce functions as usual. Since the aggregation on a node \(u\) only involves summing over the neighbors’ representations \(h_v\), we can simply use DGL's built-in functions:

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')
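
For intuition, the two built-ins above correspond roughly to the following user-defined functions (a sketch for illustration only; in practice the built-ins are preferred because DGL can fuse them into optimized kernels):

# Roughly what the built-ins do, written as UDFs (illustration only).
def gcn_msg_udf(edges):
    # copy the source node feature 'h' onto each edge as the message 'm'
    return {'m': edges.src['h']}

def gcn_reduce_udf(nodes):
    # sum the incoming messages 'm' into the new node feature 'h'
    return {'h': th.sum(nodes.mailbox['m'], dim=1)}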

We then define the node UDF for apply_nodes, which is a fully-connected layer:

class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h' : h}

We then proceed to define the GCN module. A GCN layer essentially performs message passing on all the nodes, then applies the NodeApplyModule. Note that we omit the dropout used in the paper for simplicity.

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        g.ndata['h'] = feature
        g.update_all(gcn_msg, gcn_reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
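
As a quick sanity check (not part of the original code), a single GCN layer can be run on a small toy graph to verify the output shape; the graph and feature sizes below are arbitrary:

# Toy example: a 5-node directed cycle with random 10-dimensional features.
toy_g = DGLGraph()
toy_g.add_nodes(5)
toy_g.add_edges([0, 1, 2, 3, 4], [1, 2, 3, 4, 0])
toy_layer = GCN(10, 4, F.relu)
toy_feat = th.randn(5, 10)
print(toy_layer(toy_g, toy_feat).shape)  # expected: torch.Size([5, 4])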

The forward function is essentially the same as that of any other NN model commonly seen in PyTorch. We can initialize GCN like any nn.Module. For example, let’s define a simple neural network consisting of two GCN layers. Suppose we are training the classifier on the cora dataset (the input feature size is 1433 and the number of classes is 7).

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gcn1 = GCN(1433, 16, F.relu)
        self.gcn2 = GCN(16, 7, F.relu)

    def forward(self, g, features):
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        return x
net = Net()
print(net)

Out:

Net(
  (gcn1): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=1433, out_features=16, bias=True)
    )
  )
  (gcn2): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=16, out_features=7, bias=True)
    )
  )
)

We load the cora dataset using DGL’s built-in data module.

from dgl.data import citation_graph as citegrh
def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    mask = th.ByteTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask

We then train the network as follows:

import time
import numpy as np
g, features, labels, mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(30):
    if epoch >= 3:  # start timing after a few warm-up epochs
        t0 = time.time()

    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >=3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), np.mean(dur)))

Out:

Epoch 00000 | Loss 1.9175 | Time(s) nan
Epoch 00001 | Loss 1.9068 | Time(s) nan
Epoch 00002 | Loss 1.8958 | Time(s) nan
Epoch 00003 | Loss 1.8847 | Time(s) 0.2614
Epoch 00004 | Loss 1.8731 | Time(s) 0.2599
Epoch 00005 | Loss 1.8615 | Time(s) 0.2601
Epoch 00006 | Loss 1.8499 | Time(s) 0.2610
Epoch 00007 | Loss 1.8377 | Time(s) 0.2610
Epoch 00008 | Loss 1.8258 | Time(s) 0.2605
Epoch 00009 | Loss 1.8135 | Time(s) 0.2604
Epoch 00010 | Loss 1.8011 | Time(s) 0.2604
Epoch 00011 | Loss 1.7887 | Time(s) 0.2603
Epoch 00012 | Loss 1.7764 | Time(s) 0.2604
Epoch 00013 | Loss 1.7641 | Time(s) 0.2603
Epoch 00014 | Loss 1.7519 | Time(s) 0.2606
Epoch 00015 | Loss 1.7397 | Time(s) 0.2607
Epoch 00016 | Loss 1.7274 | Time(s) 0.2606
Epoch 00017 | Loss 1.7155 | Time(s) 0.2605
Epoch 00018 | Loss 1.7039 | Time(s) 0.2606
Epoch 00019 | Loss 1.6925 | Time(s) 0.2604
Epoch 00020 | Loss 1.6814 | Time(s) 0.2604
Epoch 00021 | Loss 1.6705 | Time(s) 0.2604
Epoch 00022 | Loss 1.6601 | Time(s) 0.2604
Epoch 00023 | Loss 1.6501 | Time(s) 0.2606
Epoch 00024 | Loss 1.6402 | Time(s) 0.2606
Epoch 00025 | Loss 1.6305 | Time(s) 0.2605
Epoch 00026 | Loss 1.6209 | Time(s) 0.2604
Epoch 00027 | Loss 1.6115 | Time(s) 0.2605
Epoch 00028 | Loss 1.6023 | Time(s) 0.2605
Epoch 00029 | Loss 1.5932 | Time(s) 0.2605
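
The loop above only tracks the training loss; to also gauge accuracy, one could add an evaluation helper along the following lines (a sketch; the test mask is an assumption here, e.g. data.test_mask from the same dataset object, analogous to the train mask above):

def evaluate(model, g, features, labels, mask):
    # classification accuracy over the nodes selected by `mask`
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        _, preds = th.max(logits[mask], dim=1)
        correct = th.sum(preds == labels[mask])
        return correct.item() * 1.0 / len(labels[mask])

# Example usage with a hypothetical test mask:
# test_mask = th.ByteTensor(data.test_mask)
# print("Test accuracy {:.4f}".format(evaluate(net, g, features, labels, test_mask)))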

GCN in one formula

Mathematically, the GCN model follows this formula:

\(H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})\)

Here, \(H^{(l)}\) denotes the matrix of node representations at the \(l^{th}\) layer, \(\sigma\) is the non-linearity, and \(W^{(l)}\) is the weight matrix of that layer. \(D\) and \(A\), as commonly seen, are the degree matrix and adjacency matrix, respectively. The tilde denotes the renormalization trick, in which we add a self-connection to each node of the graph (\(\tilde{A} = A + I\)) and build the corresponding degree matrix \(\tilde{D}\). The shape of the input \(H^{(0)}\) is \(N \times D\), where \(N\) is the number of nodes and \(D\) is the number of input features. We can chain up multiple such layers to produce a node-level representation output of shape \(N \times F\), where \(F\) is the dimension of the output node feature vector.
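
To make the formula concrete, here is a small dense-matrix sketch of a single layer (illustration only, with random inputs and weights; it computes \(\sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})\) directly):

# Dense sketch of one GCN layer: H' = sigma(D~^-1/2 A~ D~^-1/2 H W).
# All sizes and values are arbitrary, for illustration only.
N, D_in, F_out = 5, 8, 4
A = (th.rand(N, N) > 0.5).float()
A = ((A + A.t()) > 0).float() * (1 - th.eye(N))  # symmetrize, zero the diagonal
A_tilde = A + th.eye(N)                    # add self-connections
D_tilde = A_tilde.sum(dim=1)               # degrees under the renormalization trick
D_inv_sqrt = th.diag(D_tilde.pow(-0.5))    # D~^(-1/2)
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt  # normalized adjacency
H = th.randn(N, D_in)                      # input node features H^(0)
W = th.randn(D_in, F_out)                  # layer weight matrix
H_next = th.relu(A_hat @ H @ W)            # one layer of propagation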

In practice, the equation is implemented efficiently with sparse matrix multiplication kernels (as in Kipf’s pygcn code). The DGL implementation above in fact already uses this trick, thanks to the use of built-in functions. To understand what is under the hood, please read our tutorial on PageRank.
