Graph Transformer in a Nutshell

The Transformer (Vaswani et al. 2017) has proven to be an effective architecture in natural language processing and computer vision. Recently, researchers have begun to explore applying Transformers to graph learning, with initial success on many practical tasks such as graph property prediction. Dwivedi et al. (2020) were the first to generalize the Transformer architecture to graph-structured data. Here, we show how to build such a graph transformer with DGL’s sparse matrix APIs.


[ ]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Uncomment the lines below to install the required packages. If your CUDA
# version is not 11.8, see https://www.dgl.ai/pages/start.html for the
# supported CUDA versions and the matching DGL install command.
#!pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html > /dev/null
#!pip install ogb >/dev/null

try:
    import dgl
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "Failed to install DGL!")

Sparse Multi-head Attention

Recall the all-pairs scaled dot-product attention mechanism of the vanilla Transformer:

\[\text{Attn}=\text{softmax}(\dfrac{QK^T} {\sqrt{d}})V,\]
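For reference, here is a minimal single-head PyTorch sketch of this dense attention. `dense_attention` is a hypothetical helper, not part of DGL. Note that it materializes the full \(N\times N\) score matrix, which is exactly the cost the sparse variant below avoids.

```python
import math
import torch

def dense_attention(Q, K, V):
    """All-pairs scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    d = Q.shape[-1]
    scores = Q @ K.T / math.sqrt(d)          # [N, N]: every node attends to every node
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ V, weights              # [N, d], [N, N]

torch.manual_seed(0)
N, d = 4, 8
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
out, weights = dense_attention(Q, K, V)
print(out.shape)  # torch.Size([4, 8])
```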

The graph transformer (GT) model employs a Sparse Multi-head Attention block:

\[\text{SparseAttn}(Q, K, V, A) = \text{softmax}(\frac{(QK^T) \circ A}{\sqrt{d}})V,\]

where \(Q, K, V \in\mathbb{R}^{N\times d}\) are the query, key, and value features, respectively, and \(A\in[0,1]^{N\times N}\) is the adjacency matrix of the input graph. \((QK^T)\circ A\) denotes the product of the query and key matrices followed by a Hadamard (element-wise) product with the sparse adjacency matrix, as illustrated in the figure below:

[Figure: computing \((QK^T)\circ A\) only at the nonzero entries of \(A\)]

Essentially, attention scores are computed only between connected nodes, following the sparsity pattern of \(A\), and the softmax normalizes each node's scores over its neighbors only. Computing \((QK^T)\circ A\) this way is known as Sampled Dense-Dense Matrix Multiplication (SDDMM).
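To make the semantics concrete, the following plain-PyTorch sketch implements single-head SparseAttn from an edge list `(row, col)` of the nonzeros of \(A\): SDDMM gives one score per edge, a per-row softmax normalizes over each node's neighbors, and a scatter-add performs the final sparse-dense product. It then checks the result against a dense reference that sets non-edge scores to \(-\infty\) before the softmax. The helper names and the toy graph are hypothetical; this is an illustration of the math, not DGL's kernel implementation.

```python
import math
import torch

def sparse_attention(row, col, Q, K, V):
    """Single-head SparseAttn over the nonzeros (row[e], col[e]) of A, in plain PyTorch."""
    N, d = Q.shape
    # SDDMM: one score per edge, q_i . k_j / sqrt(d), without forming the N x N matrix.
    scores = (Q[row] * K[col]).sum(dim=-1) / math.sqrt(d)  # [E]
    # Softmax over the nonzeros of each row (the neighbors of node i).
    row_max = torch.full((N,), float("-inf")).scatter_reduce(
        0, row, scores, reduce="amax"
    )
    exp = torch.exp(scores - row_max[row])  # subtract the row max for stability
    denom = torch.zeros(N).index_add(0, row, exp)
    attn = exp / denom[row]  # [E]
    # Sparse-dense product: out_i = sum over neighbors j of attn_ij * v_j.
    out = torch.zeros(N, V.shape[1]).index_add(0, row, attn[:, None] * V[col])
    return out, attn

def masked_dense_attention(A, Q, K, V):
    """Reference: dense scores with non-edges set to -inf before the softmax."""
    scores = Q @ K.T / math.sqrt(Q.shape[1])
    scores = scores.masked_fill(A == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
N, d = 5, 4
# Toy graph: ring 0-1-2-3-4-0 in both directions, plus self-loops so no row is empty.
src = torch.tensor([0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0, 1, 2, 3, 4])
dst = torch.tensor([1, 0, 2, 1, 3, 2, 4, 3, 0, 4, 0, 1, 2, 3, 4])
A = torch.zeros(N, N)
A[src, dst] = 1.0
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)

out_sparse, _ = sparse_attention(src, dst, Q, K, V)
out_dense = masked_dense_attention(A, Q, K, V)
print(torch.allclose(out_sparse, out_dense, atol=1e-6))  # True
```

The sparse version touches only the \(E\) edges, so its cost grows with the number of edges rather than with \(N^2\).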

With DGL’s batched SDDMM API (`dglsp.bsddmm`), we can parallelize the computation across multiple attention heads (different representation subspaces).
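As an illustration of what "batched" means here, the sketch below computes per-edge, per-head scores for features laid out as `[N, dh, nh]`, the layout the `SparseMHA` module uses, and checks that one batched call matches looping over heads. `batched_sddmm` is a hypothetical plain-PyTorch helper that gathers per-edge rows; it shows the semantics only, not how `dglsp.bsddmm` is implemented.

```python
import torch

def batched_sddmm(row, col, q, k):
    """Per-edge, per-head scores for q, k of shape [N, dh, nh].

    Returns [E, nh] with scores[e, h] = sum_c q[row[e], c, h] * k[col[e], c, h].
    """
    return (q[row] * k[col]).sum(dim=1)

torch.manual_seed(0)
N, dh, nh = 4, 3, 2
row = torch.tensor([0, 1, 2, 3, 0])  # edge list of a toy graph
col = torch.tensor([1, 2, 3, 0, 0])
q, k = torch.randn(N, dh, nh), torch.randn(N, dh, nh)

batched = batched_sddmm(row, col, q, k)  # all heads in one call: [E, nh]
# The same scores computed one head at a time.
looped = torch.stack(
    [(q[row, :, h] * k[col, :, h]).sum(dim=1) for h in range(nh)], dim=1
)
print(batched.shape)  # torch.Size([5, 2])
```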

[ ]:
import dgl
import dgl.nn as dglnn
import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from dgl.data import AsGraphPredDataset
from dgl.dataloading import GraphDataLoader
from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder
from tqdm import tqdm


class SparseMHA(nn.Module):
    """Sparse Multi-head Attention Module"""

    def __init__(self, hidden_size=80, num_heads=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim**-0.5

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, A, h):
        N = len(h)
        # [N, dh, nh]
        q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)
        q *= self.scaling
        # [N, dh, nh]
        k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)
        # [N, dh, nh]
        v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)

        ######################################################################
        # (HIGHLIGHT) Compute the multi-head attention with Sparse Matrix API
        ######################################################################
        attn = dglsp.bsddmm(A, q, k.transpose(1, 0))  # (sparse) [N, N, nh]
        # Sparse softmax by default applies on the last sparse dimension.
        attn = attn.softmax()  # (sparse) [N, N, nh]
        out = dglsp.bspmm(attn, v)  # [N, dh, nh]

        return self.out_proj(out.reshape(N, -1))

Graph Transformer Layer

The GT layer consists of multi-head attention, batch normalization, and a feed-forward network, connected by residual links as in the vanilla Transformer.

[Figure: structure of a Graph Transformer layer]
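Read directly off the `GTLayer` code, a layer with hidden size \(d\) maps node features \(h\) to \(h'\) with two post-norm residual blocks, where the feed-forward network widens to \(2d\) (\(W_1\in\mathbb{R}^{2d\times d}\), \(W_2\in\mathbb{R}^{d\times 2d}\)) and \(\text{BN}_1, \text{BN}_2\) are the two batch-norm layers:

\[\hat{h} = \text{BN}_1\big(h + \text{SparseMHA}(A, h)\big), \qquad h' = \text{BN}_2\big(\hat{h} + W_2\,\text{ReLU}(W_1\hat{h} + b_1) + b_2\big).\]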

[ ]:
class GTLayer(nn.Module):
    """Graph Transformer Layer"""

    def __init__(self, hidden_size=80, num_heads=8):
        super().__init__()
        self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)
        self.batchnorm1 = nn.BatchNorm1d(hidden_size)
        self.batchnorm2 = nn.BatchNorm1d(hidden_size)
        self.FFN1 = nn.Linear(hidden_size, hidden_size * 2)
        self.FFN2 = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, A, h):
        h1 = h
        h = self.MHA(A, h)
        h = self.batchnorm1(h + h1)

        h2 = h
        h = self.FFN2(F.relu(self.FFN1(h)))
        h = h2 + h

        return self.batchnorm2(h)

Graph Transformer Model

The GT model is built by stacking GT layers. The input positional encoding of the vanilla Transformer is replaced with a Laplacian positional encoding (Dwivedi et al. 2020), which is projected to the hidden size and added to the atom embeddings. For graph-level prediction, a pooling layer on top of the GT layers aggregates the node features of each graph, and an MLP predictor maps the pooled vector to the output.
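`dglnn.SumPooling` reduces a batched graph to one vector per graph by summing its node features. A plain-PyTorch sketch of that readout, assuming a hypothetical per-node `graph_ids` vector that records which graph each node belongs to (DGL tracks this batch membership internally, so the model code does not need it):

```python
import torch

def sum_pool(h, graph_ids, num_graphs):
    """Sum node features per graph: out[g] = sum of h[i] over nodes with graph_ids[i] == g."""
    out = torch.zeros(num_graphs, h.shape[1], dtype=h.dtype)
    return out.index_add(0, graph_ids, h)

# Two graphs batched together: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
h = torch.tensor([[1.0, 0.0], [2.0, 1.0], [3.0, 1.0], [4.0, 2.0], [5.0, 3.0]])
graph_ids = torch.tensor([0, 0, 0, 1, 1])
pooled = sum_pool(h, graph_ids, num_graphs=2)  # -> [[6., 2.], [9., 5.]]
```

Summation, unlike mean pooling, keeps information about graph size, which can matter for molecular properties.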

[ ]:
class GTModel(nn.Module):
    def __init__(
        self,
        out_size,
        hidden_size=80,
        pos_enc_size=2,
        num_layers=8,
        num_heads=8,
    ):
        super().__init__()
        self.atom_encoder = AtomEncoder(hidden_size)
        self.pos_linear = nn.Linear(pos_enc_size, hidden_size)
        self.layers = nn.ModuleList(
            [GTLayer(hidden_size, num_heads) for _ in range(num_layers)]
        )
        self.pooler = dglnn.SumPooling()
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Linear(hidden_size // 4, out_size),
        )

    def forward(self, g, X, pos_enc):
        indices = torch.stack(g.edges())
        N = g.num_nodes()
        A = dglsp.spmatrix(indices, shape=(N, N))
        h = self.atom_encoder(X) + self.pos_linear(pos_enc)
        for layer in self.layers:
            h = layer(A, h)
        h = self.pooler(g, h)

        return self.predictor(h)

Training

We train the GT model on the ogbg-molhiv benchmark. The Laplacian positional encoding of each graph is precomputed with `dgl.laplacian_pe` and passed to the model as part of its input.
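The idea behind the Laplacian positional encoding can be sketched in plain PyTorch: take the eigenvectors of the symmetric normalized Laplacian \(L = I - D^{-1/2} A D^{-1/2}\) with the smallest non-trivial eigenvalues, and zero-pad when a graph has too few nodes. This follows the Dwivedi et al. formulation; `laplacian_pe_sketch` is a hypothetical helper, and we assume rather than guarantee that `dgl.laplacian_pe` matches it exactly (eigenvector signs are arbitrary, and the library's padding details should be checked in its documentation).

```python
import torch

def laplacian_pe_sketch(A, k):
    """k eigenvectors of L = I - D^{-1/2} A D^{-1/2} with the smallest non-trivial eigenvalues.

    Skips the first (trivial) eigenvector and zero-pads to k columns for small graphs.
    """
    N = A.shape[0]
    deg = A.sum(dim=1)
    d_inv_sqrt = deg.clamp(min=1).pow(-0.5)  # clamp avoids dividing by zero for isolated nodes
    L = torch.eye(N) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    eigvals, eigvecs = torch.linalg.eigh(L)  # eigenvalues in ascending order
    pe = eigvecs[:, 1 : k + 1]               # drop the trivial first eigenvector
    if pe.shape[1] < k:                      # pad small graphs with zero columns
        pe = torch.cat([pe, torch.zeros(N, k - pe.shape[1])], dim=1)
    return pe, eigvals, L

# Undirected path graph 0-1-2-3, asking for more eigenvectors than it has.
A = torch.zeros(4, 4)
for i in range(3):
    A[i, i + 1] = A[i + 1, i] = 1.0
pe, eigvals, L = laplacian_pe_sketch(A, k=8)
print(pe.shape)  # torch.Size([4, 8])
```

Because eigenvector signs are arbitrary, the same graph can yield \(\pm\) versions of each column.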

Note that we down-sample the dataset so that this demo runs faster. See the example script for the performance on the full dataset.

[ ]:
@torch.no_grad()
def evaluate(model, dataloader, evaluator, device):
    model.eval()
    y_true = []
    y_pred = []
    for batched_g, labels in dataloader:
        batched_g, labels = batched_g.to(device), labels.to(device)
        y_hat = model(batched_g, batched_g.ndata["feat"], batched_g.ndata["PE"])
        y_true.append(labels.view(y_hat.shape).detach().cpu())
        y_pred.append(y_hat.detach().cpu())
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    return evaluator.eval(input_dict)["rocauc"]


def train(model, dataset, evaluator, device):
    train_dataloader = GraphDataLoader(
        dataset[dataset.train_idx],
        batch_size=256,
        shuffle=True,
        collate_fn=collate_dgl,
    )
    valid_dataloader = GraphDataLoader(
        dataset[dataset.val_idx], batch_size=256, collate_fn=collate_dgl
    )
    test_dataloader = GraphDataLoader(
        dataset[dataset.test_idx], batch_size=256, collate_fn=collate_dgl
    )
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    num_epochs = 20
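    # Note: with step_size=num_epochs, StepLR only halves the learning rate
    # after the final epoch, so the rate stays at 0.001 throughout this run.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=num_epochs, gamma=0.5
    )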
    loss_fcn = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batched_g, labels in train_dataloader:
            batched_g, labels = batched_g.to(device), labels.to(device)
            logits = model(
                batched_g, batched_g.ndata["feat"], batched_g.ndata["PE"]
            )
            loss = loss_fcn(logits, labels.float())
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        avg_loss = total_loss / len(train_dataloader)
        val_metric = evaluate(model, valid_dataloader, evaluator, device)
        test_metric = evaluate(model, test_dataloader, evaluator, device)
        print(
            f"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, "
            f"Val: {val_metric:.4f}, Test: {test_metric:.4f}"
        )


# Training device.
dev = torch.device("cpu")
# Uncomment the code below to train on GPU. Be sure to install DGL with CUDA support.
#dev = torch.device("cuda:0")

# Load dataset.
pos_enc_size = 8
dataset = AsGraphPredDataset(
    DglGraphPropPredDataset("ogbg-molhiv", "./data/OGB")
)
evaluator = Evaluator("ogbg-molhiv")

# Down-sample the dataset to make the tutorial run faster.
import random
random.seed(42)
train_size = len(dataset.train_idx)
val_size = len(dataset.val_idx)
test_size = len(dataset.test_idx)
dataset.train_idx = dataset.train_idx[
    torch.LongTensor(random.sample(range(train_size), 2000))
]
dataset.val_idx = dataset.val_idx[
    torch.LongTensor(random.sample(range(val_size), 1000))
]
dataset.test_idx = dataset.test_idx[
    torch.LongTensor(random.sample(range(test_size), 1000))
]

# Laplacian positional encoding.
indices = torch.cat([dataset.train_idx, dataset.val_idx, dataset.test_idx])
for idx in tqdm(indices, desc="Computing Laplacian PE"):
    g, _ = dataset[idx]
    g.ndata["PE"] = dgl.laplacian_pe(g, k=pos_enc_size, padding=True)

# Create model.
out_size = dataset.num_tasks
model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)

# Kick off training.
train(model, dataset, evaluator, dev)
Computing Laplacian PE:   1%|          | 25/4000 [00:00<00:16, 244.77it/s]/usr/local/lib/python3.8/dist-packages/dgl/backend/pytorch/tensor.py:52: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:250.)
  return th.as_tensor(data, dtype=dtype)
Computing Laplacian PE: 100%|██████████| 4000/4000 [00:13<00:00, 296.04it/s]
Epoch: 000, Loss: 0.2486, Val: 0.3082, Test: 0.3068
Epoch: 001, Loss: 0.1695, Val: 0.4684, Test: 0.4572
Epoch: 002, Loss: 0.1428, Val: 0.5887, Test: 0.4721
Epoch: 003, Loss: 0.1237, Val: 0.6375, Test: 0.5010
Epoch: 004, Loss: 0.1127, Val: 0.6628, Test: 0.4854
Epoch: 005, Loss: 0.1047, Val: 0.6811, Test: 0.4983
Epoch: 006, Loss: 0.0949, Val: 0.6751, Test: 0.5409
Epoch: 007, Loss: 0.0901, Val: 0.6340, Test: 0.5357
Epoch: 008, Loss: 0.0811, Val: 0.6717, Test: 0.5543
Epoch: 009, Loss: 0.0643, Val: 0.7861, Test: 0.5628
Epoch: 010, Loss: 0.0489, Val: 0.7319, Test: 0.5341
Epoch: 011, Loss: 0.0340, Val: 0.7884, Test: 0.5299
Epoch: 012, Loss: 0.0285, Val: 0.5887, Test: 0.4293
Epoch: 013, Loss: 0.0361, Val: 0.5514, Test: 0.3419
Epoch: 014, Loss: 0.0451, Val: 0.6795, Test: 0.4964
Epoch: 015, Loss: 0.0429, Val: 0.7405, Test: 0.5527
Epoch: 016, Loss: 0.0331, Val: 0.7859, Test: 0.4994
Epoch: 017, Loss: 0.0177, Val: 0.6544, Test: 0.4457
Epoch: 018, Loss: 0.0201, Val: 0.8250, Test: 0.6073
Epoch: 019, Loss: 0.0093, Val: 0.7356, Test: 0.5561