Write your own GNN module
Sometimes your model goes beyond simply stacking existing GNN modules. For example, you might want to invent a new way of aggregating neighbor information that takes node importance or edge weights into account.
By the end of this tutorial you will be able to:
Understand DGL’s message passing APIs.
Implement a GraphSAGE convolution module on your own.
This tutorial assumes that you already know the basics of training a GNN for node classification.
(Time estimate: 10 minutes)
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
Message passing and GNNs
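DGL follows the message passing paradigm inspired by the Message Passing Neural Network proposed by Gilmer et al. Essentially, they found that many GNN models fit into the following framework:
\[
m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)
\]
\[
m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)} m_{u\to v}^{(l)}
\]
\[
h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
\]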
where DGL calls \(M^{(l)}\) the message function, \(\sum\) the reduce function and \(U^{(l)}\) the update function. Note that \(\sum\) here can represent any function and is not necessarily a summation.
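The following is a minimal plain-Python sketch that makes the roles of the message, reduce, and update functions concrete. It is not DGL's implementation. The helper name `message_passing_step`, the scalar node features, and the dictionary of edge features are illustrative assumptions. The three functions are passed in as ordinary callables.

```python
# Framework-free sketch of one message-passing step (not DGL's API).
# Node features are scalars; edges maps (u, v), meaning u -> v, to an edge feature.

def message_passing_step(h, edges, message_fn, reduce_fn, update_fn):
    """Return h_v^{(l)} = U(h_v, reduce_{u in N(v)} M(h_v, h_u, e_uv)) for every node v."""
    # Collect the messages arriving at each destination node.
    inbox = {v: [] for v in range(len(h))}
    for (u, v), e_uv in edges.items():
        inbox[v].append(message_fn(h[v], h[u], e_uv))
    # Reduce each node's messages, then update the node.
    return [update_fn(h[v], reduce_fn(inbox[v])) for v in range(len(h))]

# Messages scale the source feature by the edge feature, the reducer is a sum,
# and the update adds the aggregated message to the old feature.
h0 = [1.0, 2.0, 3.0]
edges = {(0, 1): 0.5, (2, 1): 1.0, (1, 2): 2.0}
h1 = message_passing_step(
    h0,
    edges,
    message_fn=lambda h_v, h_u, e_uv: h_u * e_uv,
    reduce_fn=sum,
    update_fn=lambda h_v, m_v: h_v + m_v,
)
print(h1)  # [1.0, 5.5, 7.0]
```

Swapping `reduce_fn=sum` for a mean or max is exactly the freedom the "not necessarily a summation" remark refers to.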
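For example, the GraphSAGE convolution (Hamilton et al., 2017) takes the following mathematical form:
\[
h_{\mathcal{N}(v)}^k \leftarrow \text{Average}\{h_u^{k-1}, \forall u \in \mathcal{N}(v)\}
\]
\[
h_v^k \leftarrow \text{ReLU}\left(W^k \cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k)\right)
\]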
You can see that message passing is directional: the message sent from a node \(u\) to another node \(v\) is not necessarily the same as the message sent from \(v\) to \(u\) in the opposite direction.
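Here is a tiny plain-Python illustration of directionality. It assumes a hypothetical message function that depends on both endpoints (\(h_u - h_v\)). Evaluating this function in both directions gives different messages. With an edge list that contains only u -> v, only v receives anything.

```python
# Hypothetical message function that depends on both endpoints: m_{src->dst} = h_src - h_dst.
h = {"u": 3.0, "v": 1.0}

def message(h_src, h_dst):
    return h_src - h_dst

m_u_to_v = message(h["u"], h["v"])
m_v_to_u = message(h["v"], h["u"])
print(m_u_to_v, m_v_to_u)  # 2.0 -2.0

# With only the directed edge u -> v, node u receives no message at all.
edges = [("u", "v")]
received = {n: [] for n in h}
for src, dst in edges:
    received[dst].append(message(h[src], h[dst]))
print(received)  # {'u': [], 'v': [2.0]}
```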
Although DGL has built-in support for GraphSAGE via dgl.nn.SAGEConv, here is how you can implement GraphSAGE convolution in DGL on your own.
import dgl.function as fn

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        # local_scope() keeps the features assigned below from leaking into the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            # update_all is a message passing API.
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            # Concatenate each node's own feature with its neighbor average.
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
The central piece in this code is the g.update_all function, which gathers and averages the neighbor features. There are three concepts here:
Message function fn.copy_u('h', 'm'), which copies the source node feature stored under the name 'h' onto each edge as the message 'm' sent to the neighbor.
Reduce function fn.mean('m', 'h_N'), which averages all the received messages stored under the name 'm' and saves the result as a new node feature 'h_N'.
update_all, which tells DGL to trigger the message and reduce functions for all the nodes and edges.
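As a rough mental model, the copy_u/mean pair can be written out by hand. The following plain-Python sketch uses a hypothetical helper, `copy_u_mean`, which is not part of DGL. It computes the same quantity on lists: for each node v, the mean of `h[u]` over all edges u -> v. In this sketch, nodes without incoming edges get zeros. We believe this matches DGL's built-in reducers, but check the documentation for your DGL version.

```python
# Plain-Python sketch of update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_N')).
# h is a list of feature vectors; (src[i], dst[i]) is edge i, directed src -> dst.

def copy_u_mean(h, src, dst):
    num_nodes, dim = len(h), len(h[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    # Message function: each edge carries a copy of its source node's feature.
    for u, v in zip(src, dst):
        for k in range(dim):
            sums[v][k] += h[u][k]
        counts[v] += 1
    # Reduce function: average the messages each node received (zeros if none).
    return [[s / counts[v] if counts[v] else 0.0 for s in sums[v]] for v in range(num_nodes)]

h = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
src, dst = [0, 1, 0], [2, 2, 1]  # edges 0->2, 1->2, 0->1
h_N = copy_u_mean(h, src, dst)
print(h_N)  # [[0.0, 0.0], [1.0, 0.0], [2.0, 1.0]]
```

SAGEConv then concatenates each node's own feature with this neighbor average. The linear layer therefore sees 2 * in_feat input columns, which is why it is built as nn.Linear(in_feat * 2, out_feat).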
Afterwards, you can stack your own GraphSAGE convolution layers to form a multi-layer GraphSAGE network.
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
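With two stacked layers, each node's output can depend on nodes two hops away along the edge directions. The sketch below isolates that effect in plain Python. It keeps only the neighbor-mean path and drops the self-feature concatenation, the linear layer, and the ReLU, so it is a simplification of the model above. The helper name is hypothetical.

```python
# Simplified sketch: neighbor averaging only, applied once or twice.

def mean_of_in_neighbors(h, edges):
    out = []
    for v in range(len(h)):
        msgs = [h[u] for u, dst in edges if dst == v]
        out.append(sum(msgs) / len(msgs) if msgs else 0.0)
    return out

edges = [(0, 1), (1, 2)]  # path 0 -> 1 -> 2
h = [1.0, 0.0, 0.0]       # only node 0 carries a non-zero feature
one_layer = mean_of_in_neighbors(h, edges)
two_layers = mean_of_in_neighbors(one_layer, edges)
print(one_layer)   # [0.0, 1.0, 0.0]: node 0's signal reaches its 1-hop neighbor
print(two_layers)  # [0.0, 0.0, 1.0]: after two layers it reaches node 2
```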
Training loop
The following code for data loading and the training loop is copied directly from the introduction tutorial.
import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        all_logits.append(logits.detach())

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)
Out:
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.951, val acc: 0.156 (best 0.156), test acc: 0.144 (best 0.144)
In epoch 5, loss: 1.879, val acc: 0.156 (best 0.156), test acc: 0.144 (best 0.144)
In epoch 10, loss: 1.738, val acc: 0.538 (best 0.540), test acc: 0.545 (best 0.532)
In epoch 15, loss: 1.522, val acc: 0.420 (best 0.540), test acc: 0.454 (best 0.532)
In epoch 20, loss: 1.243, val acc: 0.456 (best 0.540), test acc: 0.511 (best 0.532)
In epoch 25, loss: 0.931, val acc: 0.564 (best 0.564), test acc: 0.588 (best 0.588)
In epoch 30, loss: 0.635, val acc: 0.634 (best 0.634), test acc: 0.674 (best 0.674)
In epoch 35, loss: 0.398, val acc: 0.678 (best 0.678), test acc: 0.710 (best 0.710)
In epoch 40, loss: 0.236, val acc: 0.706 (best 0.706), test acc: 0.721 (best 0.721)
In epoch 45, loss: 0.138, val acc: 0.708 (best 0.710), test acc: 0.729 (best 0.726)
In epoch 50, loss: 0.083, val acc: 0.714 (best 0.714), test acc: 0.726 (best 0.726)
In epoch 55, loss: 0.053, val acc: 0.718 (best 0.720), test acc: 0.727 (best 0.726)
In epoch 60, loss: 0.036, val acc: 0.718 (best 0.720), test acc: 0.726 (best 0.726)
In epoch 65, loss: 0.026, val acc: 0.714 (best 0.720), test acc: 0.724 (best 0.726)
In epoch 70, loss: 0.021, val acc: 0.714 (best 0.720), test acc: 0.723 (best 0.726)
In epoch 75, loss: 0.017, val acc: 0.712 (best 0.720), test acc: 0.719 (best 0.726)
In epoch 80, loss: 0.014, val acc: 0.712 (best 0.720), test acc: 0.718 (best 0.726)
In epoch 85, loss: 0.012, val acc: 0.714 (best 0.720), test acc: 0.717 (best 0.726)
In epoch 90, loss: 0.011, val acc: 0.714 (best 0.720), test acc: 0.715 (best 0.726)
In epoch 95, loss: 0.010, val acc: 0.714 (best 0.720), test acc: 0.715 (best 0.726)
In epoch 100, loss: 0.009, val acc: 0.714 (best 0.720), test acc: 0.719 (best 0.726)
In epoch 105, loss: 0.008, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 110, loss: 0.008, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 115, loss: 0.007, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 120, loss: 0.007, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 125, loss: 0.006, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 130, loss: 0.006, val acc: 0.712 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 135, loss: 0.005, val acc: 0.712 (best 0.720), test acc: 0.719 (best 0.726)
In epoch 140, loss: 0.005, val acc: 0.712 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 145, loss: 0.005, val acc: 0.714 (best 0.720), test acc: 0.719 (best 0.726)
In epoch 150, loss: 0.005, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 155, loss: 0.004, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 160, loss: 0.004, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 165, loss: 0.004, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 170, loss: 0.004, val acc: 0.714 (best 0.720), test acc: 0.720 (best 0.726)
In epoch 175, loss: 0.004, val acc: 0.714 (best 0.720), test acc: 0.719 (best 0.726)
In epoch 180, loss: 0.003, val acc: 0.714 (best 0.720), test acc: 0.718 (best 0.726)
In epoch 185, loss: 0.003, val acc: 0.714 (best 0.720), test acc: 0.719 (best 0.726)
In epoch 190, loss: 0.003, val acc: 0.714 (best 0.720), test acc: 0.719 (best 0.726)
In epoch 195, loss: 0.003, val acc: 0.714 (best 0.720), test acc: 0.718 (best 0.726)
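The bookkeeping in train() explains a detail of this log. The "best" test accuracy is not the maximum test accuracy. It is the test accuracy recorded at the epoch with the highest validation accuracy so far, which is why the current test accuracy can exceed it. Below is a plain-Python sketch of that logic and of masked accuracy. The helper names are hypothetical, and lists stand in for tensors.

```python
# Sketch of train()'s accuracy bookkeeping with plain lists instead of tensors.

def masked_accuracy(pred, labels, mask):
    # Only positions where the mask is True count, like pred[mask] == labels[mask].
    hits = [p == y for p, y, m in zip(pred, labels, mask) if m]
    return sum(hits) / len(hits)

def track_best(history):
    """history: list of (val_acc, test_acc) pairs, one per epoch."""
    best_val, best_test = 0.0, 0.0
    for val_acc, test_acc in history:
        if best_val < val_acc:  # same strict comparison as in train()
            best_val, best_test = val_acc, test_acc
    return best_val, best_test

acc = masked_accuracy([0, 1, 1, 2], [0, 1, 2, 2], [True, True, True, False])
best = track_best([(0.50, 0.40), (0.70, 0.60), (0.65, 0.75)])
print(acc)   # 0.6666666666666666 (2 of the 3 masked nodes are correct)
print(best)  # (0.7, 0.6): the later test accuracy 0.75 is not selected
```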
More customization
DGL provides many built-in message and reduce functions in the dgl.function package. You can find more details in the API doc.
These APIs let you quickly implement new graph convolution modules. For example, the following implements a new SAGEConv that aggregates neighbor representations by averaging the neighbor features scaled by edge weights. Note that the edata member of a graph can hold edge features, which can also take part in message passing.
class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            # Each message is the source feature times the edge weight; messages are then averaged.
            g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
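For scalar features, the built-in pair fn.u_mul_e('h', 'w', 'm') and fn.mean('m', 'h_N') can be sketched in plain Python as below. The helper is hypothetical and is not DGL code. The result is the mean of weight-scaled neighbor features. The weights are not normalized, so with all-ones weights the result reduces to the plain mean used by SAGEConv.

```python
# Plain-Python sketch of u_mul_e followed by mean, for scalar node features.

def u_mul_e_mean(h, src, dst, w):
    sums = [0.0] * len(h)
    counts = [0] * len(h)
    for u, v, w_uv in zip(src, dst, w):
        sums[v] += h[u] * w_uv  # message: source feature times edge weight
        counts[v] += 1
    # Mean of the messages; nodes without in-edges get 0.0 in this sketch.
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]

h = [2.0, 4.0, 0.0]
src, dst = [0, 1], [2, 2]  # edges 0->2 and 1->2
print(u_mul_e_mean(h, src, dst, [1.0, 1.0]))  # [0.0, 0.0, 3.0], same as the plain mean
print(u_mul_e_mean(h, src, dst, [0.5, 2.0]))  # [0.0, 0.0, 4.5]
```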
Because the graph in this dataset does not have edge weights, we manually set all edge weights to one in the forward() function of the model. You can replace them with your own edge weights.
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        # All-ones edge weights of shape (num_edges, 1), placed on the graph's device.
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h

model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)
Out:
In epoch 0, loss: 1.951, val acc: 0.072 (best 0.072), test acc: 0.091 (best 0.091)
In epoch 5, loss: 1.881, val acc: 0.118 (best 0.118), test acc: 0.127 (best 0.127)
In epoch 10, loss: 1.742, val acc: 0.458 (best 0.458), test acc: 0.423 (best 0.423)
In epoch 15, loss: 1.528, val acc: 0.568 (best 0.568), test acc: 0.563 (best 0.563)
In epoch 20, loss: 1.246, val acc: 0.630 (best 0.630), test acc: 0.614 (best 0.614)
In epoch 25, loss: 0.926, val acc: 0.660 (best 0.660), test acc: 0.643 (best 0.643)
In epoch 30, loss: 0.620, val acc: 0.690 (best 0.690), test acc: 0.673 (best 0.660)
In epoch 35, loss: 0.377, val acc: 0.734 (best 0.734), test acc: 0.716 (best 0.716)
In epoch 40, loss: 0.215, val acc: 0.750 (best 0.750), test acc: 0.746 (best 0.741)
In epoch 45, loss: 0.121, val acc: 0.744 (best 0.750), test acc: 0.751 (best 0.741)
In epoch 50, loss: 0.071, val acc: 0.744 (best 0.750), test acc: 0.753 (best 0.741)
In epoch 55, loss: 0.044, val acc: 0.748 (best 0.750), test acc: 0.752 (best 0.741)
In epoch 60, loss: 0.030, val acc: 0.752 (best 0.752), test acc: 0.754 (best 0.752)
In epoch 65, loss: 0.022, val acc: 0.756 (best 0.756), test acc: 0.755 (best 0.755)
In epoch 70, loss: 0.017, val acc: 0.756 (best 0.756), test acc: 0.754 (best 0.755)
In epoch 75, loss: 0.014, val acc: 0.754 (best 0.756), test acc: 0.755 (best 0.755)
In epoch 80, loss: 0.012, val acc: 0.754 (best 0.756), test acc: 0.755 (best 0.755)
In epoch 85, loss: 0.011, val acc: 0.758 (best 0.758), test acc: 0.757 (best 0.757)
In epoch 90, loss: 0.009, val acc: 0.758 (best 0.758), test acc: 0.758 (best 0.757)
In epoch 95, loss: 0.008, val acc: 0.758 (best 0.758), test acc: 0.758 (best 0.757)
In epoch 100, loss: 0.008, val acc: 0.758 (best 0.758), test acc: 0.758 (best 0.757)
In epoch 105, loss: 0.007, val acc: 0.760 (best 0.760), test acc: 0.758 (best 0.758)
In epoch 110, loss: 0.007, val acc: 0.760 (best 0.760), test acc: 0.759 (best 0.758)
In epoch 115, loss: 0.006, val acc: 0.762 (best 0.762), test acc: 0.759 (best 0.759)
In epoch 120, loss: 0.006, val acc: 0.762 (best 0.762), test acc: 0.759 (best 0.759)
In epoch 125, loss: 0.005, val acc: 0.762 (best 0.762), test acc: 0.759 (best 0.759)
In epoch 130, loss: 0.005, val acc: 0.762 (best 0.762), test acc: 0.759 (best 0.759)
In epoch 135, loss: 0.005, val acc: 0.764 (best 0.764), test acc: 0.759 (best 0.759)
In epoch 140, loss: 0.005, val acc: 0.764 (best 0.764), test acc: 0.759 (best 0.759)
In epoch 145, loss: 0.004, val acc: 0.764 (best 0.764), test acc: 0.759 (best 0.759)
In epoch 150, loss: 0.004, val acc: 0.764 (best 0.764), test acc: 0.759 (best 0.759)
In epoch 155, loss: 0.004, val acc: 0.764 (best 0.764), test acc: 0.759 (best 0.759)
In epoch 160, loss: 0.004, val acc: 0.764 (best 0.764), test acc: 0.759 (best 0.759)
In epoch 165, loss: 0.004, val acc: 0.764 (best 0.764), test acc: 0.760 (best 0.759)
In epoch 170, loss: 0.003, val acc: 0.764 (best 0.764), test acc: 0.760 (best 0.759)
In epoch 175, loss: 0.003, val acc: 0.764 (best 0.764), test acc: 0.761 (best 0.759)
In epoch 180, loss: 0.003, val acc: 0.764 (best 0.764), test acc: 0.762 (best 0.759)
In epoch 185, loss: 0.003, val acc: 0.764 (best 0.764), test acc: 0.762 (best 0.759)
In epoch 190, loss: 0.003, val acc: 0.764 (best 0.764), test acc: 0.762 (best 0.759)
In epoch 195, loss: 0.003, val acc: 0.764 (best 0.764), test acc: 0.762 (best 0.759)
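A weighted average in the strict sense (sum of \(w_{uv} h_u\) divided by sum of \(w_{uv}\)) needs the normalization added explicitly. Below is a plain-Python sketch that assumes non-negative weights; the helper name is hypothetical. In DGL, you could presumably compute the numerator with fn.u_mul_e plus fn.sum and the denominator with fn.copy_e plus fn.sum. Treat that as an untested suggestion rather than part of this tutorial.

```python
# Sketch of a normalized weighted average over incoming edges.
# Assumption: weights are non-negative; nodes with zero total weight get 0.0.

def weighted_average(h, src, dst, w):
    num = [0.0] * len(h)
    den = [0.0] * len(h)
    for u, v, w_uv in zip(src, dst, w):
        num[v] += h[u] * w_uv  # numerator: weighted sum of neighbor features
        den[v] += w_uv         # denominator: total incoming weight
    return [n / d if d else 0.0 for n, d in zip(num, den)]

h = [2.0, 4.0, 0.0]
src, dst = [0, 1], [2, 2]
print(weighted_average(h, src, dst, [0.5, 2.0]))  # [0.0, 0.0, 3.6]
```

Unlike the unnormalized mean (4.5 for the same inputs), this result always lies between the smallest and largest neighbor feature.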
Even more customization with user-defined functions
DGL allows user-defined message and reduce functions for maximal expressiveness. Here is a user-defined message function that is equivalent to fn.u_mul_e('h', 'w', 'm').
def u_mul_e_udf(edges):
    return {'m' : edges.src['h'] * edges.data['w']}
edges has three members: src, data and dst, representing the source node features, edge features, and destination node features for all edges.
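To picture what src, dst, and data hold, here is a plain-Python stand-in for the edges argument. It uses a hypothetical make_edge_batch helper built on types.SimpleNamespace, not DGL's own edge-batch class. Entry i of each field belongs to edge i. edges.src gathers the features of the source nodes, edges.dst gathers those of the destination nodes, and edges.data holds the edge features themselves.

```python
# Hypothetical stand-in for the `edges` argument, built from plain lists.
from types import SimpleNamespace

def make_edge_batch(src, dst, ndata, edata):
    """For edge i (src[i] -> dst[i]): edges.src[k][i] == ndata[k][src[i]],
    edges.dst[k][i] == ndata[k][dst[i]], edges.data[k][i] == edata[k][i]."""
    return SimpleNamespace(
        src={k: [vals[u] for u in src] for k, vals in ndata.items()},
        dst={k: [vals[v] for v in dst] for k, vals in ndata.items()},
        data=dict(edata),
    )

edges = make_edge_batch(src=[0, 1], dst=[2, 2],
                        ndata={'h': [2.0, 4.0, 0.0]}, edata={'w': [0.5, 2.0]})
# Per-edge version of u_mul_e_udf: source feature times edge weight.
m = [h_u * w for h_u, w in zip(edges.src['h'], edges.data['w'])]
print(edges.src['h'], edges.dst['h'], m)  # [2.0, 4.0] [0.0, 0.0] [1.0, 8.0]
```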
You can also write your own reduce function. For example, the following is equivalent to the built-in fn.mean('m', 'h_N') function that averages the incoming messages:
def mean_udf(nodes):
    return {'h_N': nodes.mailbox['m'].mean(1)}
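In short, DGL groups the nodes by their in-degrees, and for each group it stacks the incoming messages along the second dimension. As a result, nodes.mailbox['m'] has the shape (number of nodes in the group, in-degree, feature dimensions...). You can then perform a reduction along the second dimension, as mean(1) does above, to aggregate the messages.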
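The degree bucketing described above can be sketched in plain Python. The helper is hypothetical and does not reflect DGL internals. Nodes are grouped by the number of messages they received, each group forms a rectangular mailbox, and the mean is taken along the message axis. For simplicity, this sketch skips nodes that received no messages.

```python
# Sketch of degree bucketing followed by a mean along the message axis.
from collections import defaultdict

def degree_bucketed_mean(messages_per_node):
    buckets = defaultdict(list)
    for v, msgs in messages_per_node.items():
        buckets[len(msgs)].append(v)
    out = {}
    for degree, nodes in sorted(buckets.items()):
        if degree == 0:
            continue  # sketch assumption: nodes without messages are skipped
        # Rectangular mailbox for this bucket: shape (len(nodes), degree).
        mailbox = [messages_per_node[v] for v in nodes]
        for v, row in zip(nodes, mailbox):
            out[v] = sum(row) / degree  # per-row equivalent of mailbox.mean(1)
    return out

result = degree_bucketed_mean({0: [1.0, 3.0], 1: [5.0], 2: [2.0, 4.0], 3: []})
print(result)  # {1: 5.0, 0: 2.0, 2: 3.0}
```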
For more details on customizing message and reduce function with user-defined function, please refer to the API reference.
Best practices for writing custom GNN modules
DGL recommends the following practices, ranked by preference:
1. Use dgl.nn modules.
2. Use dgl.nn.functional functions, which contain lower-level complex operations such as computing a softmax for each node over incoming edges.
3. Use update_all with built-in message and reduce functions.
4. Use user-defined message or reduce functions.
What’s next?