Graph Diffusion in Graph Neural Networks

This tutorial first briefly introduces the diffusion process on graphs. It then shows how Graph Neural Networks can use this concept to improve their predictive performance.


[ ]:
# Set up the environment and check whether DGL is available.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Uncomment the lines below to install the required packages. If your CUDA
# version is not 11.8, check https://www.dgl.ai/pages/start.html for the
# supported CUDA versions and the corresponding DGL install command.
#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null
#!pip install --upgrade scipy networkx > /dev/null

try:
    import dgl
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "DGL is not installed!")

Graph Diffusion

Diffusion describes the process of substances moving from one region to another. In the context of graphs, the diffusing substances (e.g., real-valued signals) travel along edges from node to node.

Mathematically, let \(\vec x\) be the vector of node signals. A graph diffusion operation can then be defined as:

\[\vec{y} = \tilde{A} \vec{x}\]

where \(\tilde{A}\) is the diffusion matrix, typically derived from the adjacency matrix of the graph. Although the choice of diffusion matrix varies, it is usually sparse, so computing \(\tilde{A} \vec{x}\) is a sparse-dense matrix multiplication.
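To make the sparse-dense product concrete, here is a minimal NumPy sketch (not DGL code). It stores a hypothetical 4-node path graph 0-1-2-3 as a coordinate (COO) edge list and computes \(\vec y = \tilde{A}\vec x\) by visiting only the stored nonzero entries. For simplicity it assumes \(\tilde{A} = A\); the function name sparse_diffuse is illustrative only.

```python
import numpy as np

# Hypothetical 4-node path graph 0-1-2-3 in COO form.
# Each undirected edge is stored in both directions.
rows = np.array([0, 1, 1, 2, 2, 3])
cols = np.array([1, 0, 2, 1, 3, 2])
vals = np.ones(len(rows))  # diffusion weights; here simply A itself


def sparse_diffuse(rows, cols, vals, x):
    """Compute y = A_tilde @ x, touching only the stored nonzeros."""
    y = np.zeros_like(x)
    # For every stored entry (i, j): y[i] += A_tilde[i, j] * x[j].
    # np.add.at accumulates correctly when a row index repeats.
    np.add.at(y, rows, vals * x[cols])
    return y


x = np.array([1.0, 0.0, 0.0, 0.0])  # all signal starts on node 0
y = sparse_diffuse(rows, cols, vals, x)
print(y)  # [0. 1. 0. 0.]

# Cross-check against the equivalent dense product.
A_dense = np.zeros((4, 4))
A_dense[rows, cols] = vals
assert np.allclose(y, A_dense @ x)
```

Each output entry \(y_i\) is a weighted sum of the signals on the neighbors of node \(i\), and the work grows with the number of stored entries rather than with \(n^2\), which is why a sparse representation pays off on large graphs.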

Let us illustrate this with an example. First, we obtain the adjacency matrix of the well-known Karate Club network.

[ ]:
import dgl
import dgl.sparse as dglsp
from dgl.data import KarateClubDataset

# Get the graph from DGL's builtin dataset.
dataset = KarateClubDataset()
dgl_g = dataset[0]

# Get its adjacency matrix.
indices = torch.stack(dgl_g.edges())
N = dgl_g.num_nodes()
A = dglsp.spmatrix(indices, shape=(N, N))
print(A.to_dense())
tensor([[0., 1., 1.,  ..., 1., 0., 0.],
        [1., 0., 1.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 1., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 1.,  ..., 1., 0., 1.],
        [0., 0., 0.,  ..., 1., 1., 0.]])

In this example, we use the graph convolution matrix from Graph Convolutional Networks (GCN) as the diffusion matrix. It is defined as:

\[\tilde{A} = \bar{D}^{-\frac{1}{2}}\bar{A}\bar{D}^{-\frac{1}{2}}\]

with \(\bar{A} = A + I\), where \(A\) is the adjacency matrix, \(I\) is the identity matrix, and \(\bar{D}\) is the diagonal node degree matrix of \(\bar{A}\).
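As a sanity check on the formula, the dense NumPy sketch below (the helper name gcn_norm is hypothetical, not a DGL function) computes \(\tilde{A}\) for a hypothetical 3-node path graph. Multiplying by the diagonal matrix \(\bar{D}^{-\frac{1}{2}}\) on both sides simply scales entry \((i, j)\) of \(\bar{A}\) by \(1/\sqrt{\bar{d}_i \bar{d}_j}\), where \(\bar{d}_i\) is the degree of node \(i\) in \(\bar{A}\).

```python
import numpy as np


def gcn_norm(A):
    """Dense sketch of D_bar^{-1/2} (A + I) D_bar^{-1/2}."""
    A_bar = A + np.eye(A.shape[0])   # add self-loops
    deg = A_bar.sum(axis=1)          # diagonal of D_bar; >= 1, so the power is safe
    d_inv_sqrt = deg ** -0.5
    # Scaling row i and column j by d_inv_sqrt equals D_bar^{-1/2} A_bar D_bar^{-1/2}.
    return d_inv_sqrt[:, None] * A_bar * d_inv_sqrt[None, :]


# Hypothetical 3-node path graph 0-1-2 (degrees in A_bar: 2, 3, 2).
A = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
A_tilde = gcn_norm(A)
# Diagonal entries are 1/2, 1/3, 1/2; each edge entry is 1/sqrt(2 * 3).
print(np.round(A_tilde, 4))
```

The same rule accounts for the Karate Club output printed by the next cell: node 0 has 16 neighbors and node 1 has 9, so \(\tilde{A}_{00} = 1/17 \approx 0.0588\) and \(\tilde{A}_{01} = 1/\sqrt{17 \cdot 10} \approx 0.0767\).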

[ ]:
# Compute graph convolution matrix.
I = dglsp.identity(A.shape)
A_hat = A + I
D_hat = dglsp.diag(A_hat.sum(dim=1))
D_hat_invsqrt = D_hat ** -0.5
A_tilde = D_hat_invsqrt @ A_hat @ D_hat_invsqrt
print(A_tilde.to_dense())
tensor([[0.0588, 0.0767, 0.0731,  ..., 0.0917, 0.0000, 0.0000],
        [0.0767, 0.1000, 0.0953,  ..., 0.0000, 0.0000, 0.0000],
        [0.0731, 0.0953, 0.0909,  ..., 0.0000, 0.0836, 0.0000],
        ...,
        [0.0917, 0.0000, 0.0000,  ..., 0.1429, 0.1048, 0.0891],
        [0.0000, 0.0000, 0.0836,  ..., 0.1048, 0.0769, 0.0654],
        [0.0000, 0.0000, 0.0000,  ..., 0.0891, 0.0654, 0.0556]])

For the node signals, we set every node to zero except node 0, which starts with a value of 5.

[ ]:
# Initial node signals. All nodes except one are set to zero.
X = torch.zeros(N)
X[0] = 5.

# Number of diffusion steps.
r = 8

# Record the signals after each diffusion step.
results = [X]
for _ in range(r):
    X = A_tilde @ X
    results.append(X)

The program below visualizes the diffusion process as an animation. To play it, click the “play” icon. You will see the signal spread from node 0 to the rest of the graph while the node values gradually converge.

[ ]:
import matplotlib.pyplot as plt
import networkx as nx
from IPython.display import HTML
from matplotlib import animation

nx_g = dgl_g.to_networkx().to_undirected()
pos = nx.spring_layout(nx_g)

fig, ax = plt.subplots()
plt.close()

def animate(i):
    ax.cla()
    # Color nodes based on their features.
    nodes = nx.draw_networkx_nodes(nx_g, pos, ax=ax, node_size=200, node_color=results[i].tolist(), cmap=plt.cm.Blues)
    # Set boundary color of the nodes.
    nodes.set_edgecolor("#000000")
    nx.draw_networkx_edges(nx_g, pos, ax=ax)

ani = animation.FuncAnimation(fig, animate, frames=len(results), interval=1000)
HTML(ani.to_jshtml())
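
The convergence seen in the animation can also be checked numerically. For a connected graph, \(\tilde{A}\) has largest eigenvalue 1 with an eigenvector proportional to \(\sqrt{\bar{d}_i}\), and the self-loops ensure every other eigenvalue has magnitude below 1, so repeated diffusion drives the signal toward that eigenvector. The NumPy sketch below assumes a hypothetical 4-node graph (a triangle 0-1-2 with node 3 attached to node 2) and defines its own dense gcn_norm helper for illustration.

```python
import numpy as np


def gcn_norm(A):
    """Dense sketch of D_bar^{-1/2} (A + I) D_bar^{-1/2}."""
    A_bar = A + np.eye(A.shape[0])
    d_inv_sqrt = A_bar.sum(axis=1) ** -0.5
    return d_inv_sqrt[:, None] * A_bar * d_inv_sqrt[None, :]


# Hypothetical graph: triangle 0-1-2 plus node 3 attached to node 2.
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
A = np.zeros((4, 4))
for u, v in edges:
    A[u, v] = A[v, u] = 1.0
A_tilde = gcn_norm(A)
deg_bar = A.sum(axis=1) + 1.0  # degrees including the self-loop

# Same setup as the tutorial: all signal starts on node 0.
x = np.zeros(4)
x[0] = 5.0
for _ in range(100):
    x = A_tilde @ x  # one diffusion step

# In the limit, x_i / sqrt(deg_bar_i) is the same constant for every node.
ratio = x / np.sqrt(deg_bar)
print(np.round(ratio, 4))
```

This illustrates why stacking many diffusion steps tends to wash out node-specific information (often called over-smoothing), and one reason it can help to keep the undiffused features alongside a few diffused versions.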

Graph Diffusion in GNNs

Scalable Inception Graph Neural Networks (SIGN) apply multiple diffusion operators to the node features in parallel and combine the results. Formally, the model is defined as:

\[\begin{split}Z=\sigma([X\Theta_{0},A_1X\Theta_{1},\cdots,A_rX\Theta_{r}])\\ Y=\xi(Z\Omega)\end{split}\]

where:

  • \(\sigma\) and \(\xi\) are nonlinear activation functions.

  • \([\cdot,\cdots,\cdot]\) is the concatenation operation.

  • \(X\in\mathbb{R}^{n\times d}\) is the input node feature matrix with \(n\) nodes and a \(d\)-dimensional feature vector per node.

  • \(\Theta_0,\cdots,\Theta_r\in\mathbb{R}^{d\times d'}\) are learnable weight matrices.

  • \(A_1,\cdots, A_r\in\mathbb{R}^{n\times n}\) are linear diffusion operators. In the example below, we use \(A_i = A^i\), where \(A\) is the graph convolution matrix.

  • \(\Omega\in\mathbb{R}^{d'(r+1)\times c}\) is a learnable weight matrix and \(c\) is the number of classes.
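
A shape-level NumPy sketch of the two equations above may help. It assumes \(A_i = A^i\), \(\sigma\) = ReLU and \(\xi\) = softmax, and uses random weights and hypothetical sizes purely to check dimensions; the names sign_forward and softmax are illustrative and this is not the DGL implementation.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def sign_forward(A, X, thetas, omega):
    """Y = softmax(relu([X T0, A X T1, ..., A^r X Tr]) Omega), with A_i = A^i."""
    blocks = []
    AX = X
    for k, theta in enumerate(thetas):
        if k > 0:
            AX = A @ AX              # A^k X computed from A^(k-1) X
        blocks.append(AX @ theta)    # shape (n, d')
    Z = np.maximum(np.concatenate(blocks, axis=1), 0.0)  # sigma = ReLU, shape (n, d'(r+1))
    return softmax(Z @ omega)        # xi = softmax, shape (n, c)


# Hypothetical sizes: n nodes, d input features, d' hidden units, c classes, r hops.
n, d, d_prime, c, r = 5, 4, 3, 2, 2
rng = np.random.default_rng(0)
A = rng.random((n, n)) / n          # stand-in for a normalized diffusion matrix
X = rng.standard_normal((n, d))
thetas = [rng.standard_normal((d, d_prime)) for _ in range(r + 1)]
omega = rng.standard_normal((d_prime * (r + 1), c))

Y = sign_forward(A, X, thetas, omega)
print(Y.shape)  # (5, 2)
```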

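The code below implements a diffusion function that computes \(X, AX, A^2X, \cdots, A^rX\), and a module that transforms and combines all the diffused node features. In the module, \(\sigma\) is ReLU and \(d'\) is hidden_size; \(\xi\) is left out because the module returns logits and the softmax is applied inside the cross-entropy loss during training.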

[ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F


################################################################################
# (HIGHLIGHT) Take advantage of DGL sparse APIs to implement the feature
# diffusion in SIGN concisely.
################################################################################
def sign_diffusion(A, X, r):
    # Return [X, AX, A^2 X, ..., A^r X], i.e., r + 1 feature matrices.
    X_sign = [X]
    for _ in range(r):
        # One more diffusion step: A^k X = A (A^(k-1) X).
        X = A @ X
        X_sign.append(X)
    return X_sign

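class SIGN(nn.Module):
    def __init__(self, in_size, out_size, r, hidden_size=256):
        super().__init__()
        # One linear transform Theta_k per diffused feature matrix A^k X.
        self.theta = nn.ModuleList(
            [nn.Linear(in_size, hidden_size) for _ in range(r + 1)]
        )
        # Omega maps the concatenated hidden features to class logits.
        self.omega = nn.Linear(hidden_size * (r + 1), out_size)

    def forward(self, X_sign):
        # Transform each A^k X with its own Theta_k.
        results = [theta(X_k) for theta, X_k in zip(self.theta, X_sign)]
        # sigma = ReLU, applied to the concatenation.
        Z = F.relu(torch.cat(results, dim=1))
        # Return logits; xi (softmax) is applied inside the cross-entropy loss.
        return self.omega(Z)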
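
Because every \(\Theta_k\) and \(\Omega\) is an nn.Linear layer with a bias, the size of the model follows directly from the constructor arguments. The helper below is a hypothetical illustration (not part of DGL or PyTorch) that evaluates this count for the Cora settings used later: 1433 input features, 7 classes, r = 2 and the default hidden size of 256.

```python
def sign_num_params(in_size, out_size, r, hidden_size=256):
    """Count the parameters of SIGN(in_size, out_size, r, hidden_size) as defined above."""
    # r + 1 layers Theta_k = nn.Linear(in_size, hidden_size): weight + bias.
    theta = (r + 1) * (in_size * hidden_size + hidden_size)
    # Omega = nn.Linear(hidden_size * (r + 1), out_size): weight + bias.
    omega = hidden_size * (r + 1) * out_size + out_size
    return theta + omega


print(sign_num_params(1433, 7, r=2))  # 1106695
```

For an actual model instance, sum(p.numel() for p in model.parameters()) should give the same number. Almost all parameters sit in the \(\Theta_k\) layers, so each extra hop adds roughly \(d \cdot d'\) weights, while the diffusion itself adds none.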

Training

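We train the SIGN model on the Cora dataset. Since the diffusion step has no learnable parameters, the node features are diffused once during pre-processing, and training only updates the linear layers.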

[ ]:
from dgl.data import CoraGraphDataset
from torch.optim import Adam


def evaluate(g, pred):
    label = g.ndata["label"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    # Compute accuracy on validation/test set.
    val_acc = (pred[val_mask] == label[val_mask]).float().mean()
    test_acc = (pred[test_mask] == label[test_mask]).float().mean()
    return val_acc, test_acc


def train(model, g, X_sign):
    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    optimizer = Adam(model.parameters(), lr=3e-3)

    for epoch in range(10):
        # Switch the model to training mode.
        model.train()

        # Forward.
        logits = model(X_sign)

        # Compute loss with nodes in training set.
        loss = F.cross_entropy(logits[train_mask], label[train_mask])

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Switch the model to evaluation mode and compute predictions
        # without tracking gradients.
        model.eval()
        with torch.no_grad():
            logits = model(X_sign)
        pred = logits.argmax(1)

        # Evaluate the prediction.
        val_acc, test_acc = evaluate(g, pred)
        print(
            f"In epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f}, test"
            f" acc: {test_acc:.3f}"
        )


# If CUDA is available, use GPU to accelerate the training, use CPU
# otherwise.
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load graph from the existing dataset.
dataset = CoraGraphDataset()
g = dataset[0].to(dev)

# Create the sparse adjacency matrix A (note that W was used as the notation
# for adjacency matrix in the original paper).
indices = torch.stack(g.edges())
N = g.num_nodes()
A = dglsp.spmatrix(indices, shape=(N, N))

# Calculate the graph convolution matrix A_tilde = D_hat^(-1/2) A_hat D_hat^(-1/2).
I = dglsp.identity(A.shape, device=dev)
A_hat = A + I
D_hat_invsqrt = dglsp.diag(A_hat.sum(dim=1)) ** -0.5
A_tilde = D_hat_invsqrt @ A_hat @ D_hat_invsqrt

# 2-hop diffusion, precomputed once before training.
r = 2
X = g.ndata["feat"]
X_sign = sign_diffusion(A_tilde, X, r)

# Create SIGN model.
in_size = X.shape[1]
out_size = dataset.num_classes
model = SIGN(in_size, out_size, r).to(dev)

# Kick off training.
train(model, g, X_sign)
Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /root/.dgl/cora_v2
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.
In epoch 0, loss: 1.946, val acc: 0.164, test acc: 0.200
In epoch 1, loss: 1.937, val acc: 0.712, test acc: 0.690
In epoch 2, loss: 1.926, val acc: 0.610, test acc: 0.595
In epoch 3, loss: 1.914, val acc: 0.656, test acc: 0.640
In epoch 4, loss: 1.898, val acc: 0.724, test acc: 0.726
In epoch 5, loss: 1.880, val acc: 0.734, test acc: 0.753
In epoch 6, loss: 1.859, val acc: 0.730, test acc: 0.746
In epoch 7, loss: 1.834, val acc: 0.732, test acc: 0.743
In epoch 8, loss: 1.807, val acc: 0.734, test acc: 0.746
In epoch 9, loss: 1.776, val acc: 0.734, test acc: 0.745

Check out the full example script here. Learn more about how graph diffusion is used in other GNN models:

  • Predict then Propagate: Graph Neural Networks meet Personalized PageRank (paper, code)

  • Combining Label Propagation and Simple Models Out-performs Graph Neural Networks (paper, code)

  • Simplifying Graph Convolutional Networks (paper, code)

  • Graph Neural Networks Inspired by Classical Iterative Algorithms (paper, code)