"""
.. currentmodule:: dgl

Batched Graph Classification with DGL
=====================================
**Author**: `Mufei Li `_,
`Minjie Wang `_,
`Zheng Zhang `_.

Graph classification is an important problem
with applications across many fields -- bioinformatics, chemoinformatics, social
network analysis, urban computing and cyber-security. Applying graph neural
networks to this problem has been a popular approach recently (
`Ying et al., 2018 `_,
`Cangea et al., 2018 `_,
`Knyazev et al., 2018 `_,
`Bianchi et al., 2019 `_,
`Liao et al., 2019 `_,
`Gao et al., 2019 `_).

This tutorial demonstrates:

* batching multiple graphs of variable size and shape with DGL
* training a graph neural network for a simple graph classification task
"""
###############################################################################
# Simple Graph Classification Task
# --------------------------------
# In this tutorial, we will learn how to perform batched graph classification
# with DGL via a toy example of classifying 8 types of regular graphs, as shown below:
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/dataset_overview.png
#    :align: center
#
# We implement a synthetic dataset :class:`data.MiniGCDataset` in DGL. The dataset has 8
# different types of graphs and each class has the same number of graph samples.
from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# A dataset with 80 samples, where each graph
# has between 10 and 20 nodes.
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()
###############################################################################
# Form a graph mini-batch
# -----------------------
# To train neural networks more efficiently, a common practice is to **batch**
# multiple samples together to form a mini-batch. Batching fixed-shaped tensor
# inputs is quite easy (for example, batching two images of size :math:`28\times 28`
# gives a tensor of shape :math:`2\times 28\times 28`). By contrast, batching graph inputs
# has two challenges:
#
# * Graphs are sparse.
# * Graphs can vary in size (e.g., in the number of nodes and edges).
#
# To address this, DGL provides a :func:`dgl.batch` API. It leverages the trick that
# a batch of graphs can be viewed as a large graph that has many disjoint
# connected components. Below is a visualization that gives the general idea:
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/batch.png
#    :width: 400pt
#    :align: center
#
# We define the following ``collate`` function to form a mini-batch from a given
# list of graph and label pairs.
import dgl
import torch

def collate(samples):
    # The input `samples` is a list of pairs (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)
###############################################################################
# The return type of :func:`dgl.batch` is still a graph (similar to the fact that
# a batch of tensors is still a tensor). This means that any code that works
# for one graph immediately works for a batch of graphs. More importantly,
# since DGL processes messages on all nodes and edges in parallel, this greatly
# improves efficiency.
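#
# As a quick sanity check, we can build two toy graphs (made up purely for
# illustration; they are not part of the dataset above) and inspect what
# :func:`dgl.batch` returns:

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 1], [1, 2])
g2 = dgl.DGLGraph()
g2.add_nodes(4)
g2.add_edges([0, 1, 2], [1, 2, 3])
bg_demo = dgl.batch([g1, g2])
# The batched graph holds all member nodes and edges in one structure.
print(bg_demo.batch_size)          # 2
print(bg_demo.number_of_nodes())   # 7, i.e. 3 + 4
# dgl.unbatch splits the batched graph back into the original list.
g1_back, g2_back = dgl.unbatch(bg_demo)

###############################################################################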
#
# Graph Classifier
# ----------------
# Graph classification proceeds as follows:
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/graph_classifier.png
#
# Given a batch of graphs, we first perform message passing/graph convolution
# so that nodes can "communicate" with their neighbors. After message passing, we
# compute a tensor representing each graph from its node (and edge) attributes.
# This step is commonly called "readout" or "aggregation". Finally, the graph
# representations are fed into a classifier :math:`g` to predict the graph labels.
#
# Graph Convolution
# -----------------
# Our graph convolution operation is basically the same as that for GCN (check out our
# `tutorial `_). The only difference is
# that we replace :math:`h_{v}^{(l+1)} = \text{ReLU}\left(b^{(l)}+\sum_{u\in\mathcal{N}(v)}h_{u}^{(l)}W^{(l)}\right)` by
# :math:`h_{v}^{(l+1)} = \text{ReLU}\left(b^{(l)}+\frac{1}{|\mathcal{N}(v)|}\sum_{u\in\mathcal{N}(v)}h_{u}^{(l)}W^{(l)}\right)`.
# We replace the summation with an average to balance nodes of different
# degrees, which gives better performance for this experiment.
#
# Note that the self edges added in the dataset initialization allow us to
# include the original node feature :math:`h_{v}^{(l)}` when taking the average.
import dgl.function as fn
import torch
import torch.nn as nn
# Sends a message of node feature h.
msg = fn.copy_src(src='h', out='m')
def reduce(nodes):
    """Take an average over all neighbor node features hu and use it to
    overwrite the original node feature."""
    accum = torch.mean(nodes.mailbox['m'], 1)
    return {'h': accum}
class NodeApplyModule(nn.Module):
    """Update the node feature hv with ReLU(Whv + b)."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h': h}
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # Initialize the node features with h.
        g.ndata['h'] = feature
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
###############################################################################
# Readout and Classification
# --------------------------
# For this demonstration, we use the degree of each node as its initial feature.
# After two rounds of graph convolution, we perform a graph readout by averaging
# over all node features for each graph in the batch
#
# .. math::
#
#    h_g = \frac{1}{|\mathcal{V}|}\sum_{v\in\mathcal{V}} h_{v}
#
# In DGL, :func:`dgl.mean_nodes` handles this task for a batch of
# graphs with variable size. We then feed our graph representations into a
# classifier with one linear layer to obtain pre-softmax logits.
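#
# As a small, self-contained check of the readout semantics (a sketch; the
# graphs and features below are made up for illustration), note that
# :func:`dgl.mean_nodes` returns one row per graph in the batch:
demo_g1 = dgl.DGLGraph()
demo_g1.add_nodes(2)
demo_g2 = dgl.DGLGraph()
demo_g2.add_nodes(3)
demo_bg = dgl.batch([demo_g1, demo_g2])
demo_bg.ndata['h'] = torch.ones(5, 4)
print(dgl.mean_nodes(demo_bg, 'h').shape)  # torch.Size([2, 4])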
import torch.nn.functional as F
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.layers = nn.ModuleList([
            GCN(in_dim, hidden_dim, F.relu),
            GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # For undirected graphs, in_degree is the same as out_degree.
        h = g.in_degrees().view(-1, 1).float()
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
###############################################################################
# Setup and Training
# ------------------
# We create a synthetic dataset of :math:`400` graphs with :math:`10` ~
# :math:`20` nodes each. :math:`320` graphs constitute the training set and
# :math:`80` graphs constitute the test set.
import torch.optim as optim
from torch.utils.data import DataLoader
# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function
# defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)
# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
###############################################################################
# The learning curve of a run is presented below:
plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()
###############################################################################
# The trained model is evaluated on the test set created above. Note that to keep
# the tutorial's running time short, we restrict training, so you are likely to get
# a higher accuracy (:math:`80` % ~ :math:`90` %) than the ones printed below.
model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))
###############################################################################
# Below is an animation in which we plot each graph together with the probability
# that the trained model assigns to its ground-truth label:
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/test_eval4.gif
#
# To understand the node and graph representations that a trained model learned,
# we use `t-SNE `_ for dimensionality reduction
# and visualization.
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/tsne_node2.png
#    :align: center
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/tsne_graph2.png
#    :align: center
#
# The two small figures on the top separately visualize the node representations
# after :math:`1` and :math:`2` layers of graph convolution, and the figure on the
# bottom visualizes the pre-softmax logits for graphs as graph representations.
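#
# For readers who want to reproduce such plots, below is a rough sketch of how
# these embeddings could be extracted (assuming scikit-learn is available and
# reusing ``model`` and ``test_bg`` from above; the variable names are ours):

from sklearn.manifold import TSNE

with torch.no_grad():
    h = test_bg.in_degrees().view(-1, 1).float()
    # Node representations after the first GCN layer.
    h1 = model.layers[0](test_bg, h)
    node_emb_2d = TSNE(n_components=2).fit_transform(h1.numpy())
    # Pre-softmax logits serve as graph representations.
    graph_emb_2d = TSNE(n_components=2).fit_transform(model(test_bg).numpy())

###############################################################################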
#
# While the visualization does suggest some clustering of the node features,
# a perfect result is not expected, since our node features are initialized from
# node degrees alone. Meanwhile, the graph features are much better separated.
#
# What's Next?
# ------------
# Graph classification with graph neural networks is still a young field
# waiting for more exciting discoveries! The task is not easy, as it
# requires mapping different graphs to different embeddings while preserving
# their structural similarity in the embedding space. To learn more about it,
# `"How Powerful Are Graph Neural Networks?" `_
# in ICLR 2019 might be a good starting point.
#
# For more examples of batched graph processing, see
#
# * our tutorials on `Tree LSTM `_ and `Deep Generative Models of Graphs `_
# * an example implementation of `Junction Tree VAE `_