Source code for dgl.nn.pytorch.graph_transformer

"""Torch modules for graph transformers."""
import math

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from ...batch import unbatch
from ...convert import to_homogeneous
from ...transforms import shortest_dist

__all__ = [
    "DegreeEncoder",
    "BiasedMultiheadAttention",
    "PathEncoder",
    "GraphormerLayer",
    "SpatialEncoder",
    "SpatialEncoder3d",
]


class DegreeEncoder(nn.Module):
    r"""Degree Encoder, as introduced in
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable degree embedding module.

    Parameters
    ----------
    max_degree : int
        Upper bound of degrees to be encoded.
        Each degree will be clamped into the range [0, ``max_degree``].
    embedding_dim : int
        Output dimension of embedding vectors.
    direction : str, optional
        Degrees of which direction to be encoded,
        selected from ``in``, ``out`` and ``both``.
        ``both`` encodes degrees from both directions
        and outputs their sum. Default : ``both``.

    Example
    -------
    >>> import dgl
    >>> from dgl.nn import DegreeEncoder

    >>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
    >>> degree_encoder = DegreeEncoder(5, 16)
    >>> degree_embedding = degree_encoder(g)
    """

    def __init__(self, max_degree, embedding_dim, direction="both"):
        super(DegreeEncoder, self).__init__()
        self.direction = direction
        if direction == "both":
            self.degree_encoder_1 = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
            self.degree_encoder_2 = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
        else:
            self.degree_encoder = nn.Embedding(
                max_degree + 1, embedding_dim, padding_idx=0
            )
        self.max_degree = max_degree

    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded. If it is a heterogeneous one,
            it will be transformed into a homogeneous one first.

        Returns
        -------
        Tensor
            Return degree embedding vectors of shape :math:`(N, embedding_dim)`,
            where :math:`N` is the number of nodes in the input graph.
        """
        if len(g.ntypes) > 1 or len(g.etypes) > 1:
            g = to_homogeneous(g)
        in_degree = th.clamp(g.in_degrees(), min=0, max=self.max_degree)
        out_degree = th.clamp(g.out_degrees(), min=0, max=self.max_degree)

        if self.direction == "in":
            degree_embedding = self.degree_encoder(in_degree)
        elif self.direction == "out":
            degree_embedding = self.degree_encoder(out_degree)
        elif self.direction == "both":
            degree_embedding = self.degree_encoder_1(
                in_degree
            ) + self.degree_encoder_2(out_degree)
        else:
            raise ValueError(
                f'Supported direction options: "in", "out" and "both", '
                f"but got {self.direction}"
            )

        return degree_embedding
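
# A minimal usage sketch for DegreeEncoder (not part of the original module;
# the toy graph and sizes below are illustrative assumptions). It shows that
# the encoder returns one embedding vector per node:
#
# >>> import dgl
# >>> from dgl.nn import DegreeEncoder
# >>> g = dgl.graph(([0, 1], [1, 2]))          # 3 nodes
# >>> enc = DegreeEncoder(max_degree=5, embedding_dim=16)
# >>> enc(g).shape
# torch.Size([3, 16])
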

class PathEncoder(nn.Module):
    r"""Path Encoder, as introduced in Edge Encoding of
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable path embedding module and encodes the shortest
    path between each pair of nodes as attention bias.

    Parameters
    ----------
    max_len : int
        Maximum number of edges in each path to be encoded.
        The exceeding part of each path will be truncated, i.e.
        edges whose index in the path is no less than :attr:`max_len`
        are dropped.
    feat_dim : int
        Dimension of edge features in the input graph.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism
        is applied. Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import PathEncoder

    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> edata = th.rand(8, 16)
    >>> path_encoder = PathEncoder(2, 16, num_heads=8)
    >>> out = path_encoder(g, edata)
    """

    def __init__(self, max_len, feat_dim, num_heads=1):
        super().__init__()
        self.max_len = max_len
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)

    def forward(self, g, edge_feat):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.
        edge_feat : torch.Tensor
            The input edge feature of shape :math:`(E, feat_dim)`,
            where :math:`E` is the number of edges in the input graph.

        Returns
        -------
        torch.Tensor
            Return attention bias as path encoding,
            of shape :math:`(batch_size, N, N, num_heads)`,
            where :math:`N` is the maximum number of nodes
            and batch_size is the batch size of the input graph.
        """
        g_list = unbatch(g)
        sum_num_edges = 0
        max_num_nodes = th.max(g.batch_num_nodes())
        path_encoding = []

        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            num_edges = ubg.num_edges()
            edata = edge_feat[sum_num_edges : (sum_num_edges + num_edges)]
            sum_num_edges = sum_num_edges + num_edges
            # append a zero row so that padded path entries (-1) index it
            edata = th.cat(
                (edata, th.zeros(1, self.feat_dim).to(edata.device)), dim=0
            )
            dist, path = shortest_dist(ubg, root=None, return_paths=True)
            path_len = max(1, min(self.max_len, path.size(dim=2)))

            # shape: [n, n, l], n = num_nodes, l = path_len
            shortest_path = path[:, :, 0:path_len]
            # shape: [n, n]
            shortest_distance = th.clamp(dist, min=1, max=path_len)
            # shape: [n, n, l, d], d = feat_dim
            path_data = edata[shortest_path]
            # shape: [l, h, d]
            edge_embedding = self.embedding_table.weight[
                0 : path_len * self.num_heads
            ].reshape(path_len, self.num_heads, -1)

            # [n, n, l, d] einsum [l, h, d] -> [n, n, h]
            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
            # create the padded tensor on the same device as the input features
            sub_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads),
                float("-inf"),
                device=edge_feat.device,
            )
            sub_encoding[0:num_nodes, 0:num_nodes] = th.div(
                th.einsum("xyld,lhd->xyh", path_data, edge_embedding).permute(
                    2, 0, 1
                ),
                shortest_distance,
            ).permute(1, 2, 0)
            path_encoding.append(sub_encoding)

        return th.stack(path_encoding, dim=0)
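
# A minimal usage sketch for PathEncoder (not part of the original module;
# the toy graph, feature size and head count are illustrative assumptions).
# Per-edge features of dimension feat_dim are turned into a per-head bias,
# padded to the largest graph in the batch:
#
# >>> import torch as th
# >>> import dgl
# >>> from dgl.nn import PathEncoder
# >>> g = dgl.graph((th.tensor([0, 0, 1]), th.tensor([1, 2, 2])))
# >>> edata = th.rand(3, 16)                   # one feature row per edge
# >>> path_encoder = PathEncoder(max_len=2, feat_dim=16, num_heads=8)
# >>> path_encoder(g, edata).shape
# torch.Size([1, 3, 3, 8])
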

class BiasedMultiheadAttention(nn.Module):
    r"""Dense Multi-Head Attention Module with Graph Attention Bias.

    Compute attention between nodes with attention bias obtained from graph
    structures, as introduced in `Do Transformers Really Perform Bad for
    Graph Representation? <https://arxiv.org/pdf/2106.05234>`__

    .. math::

        \text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)

    :math:`Q` and :math:`K` are feature representations of nodes. :math:`d`
    is the corresponding :attr:`feat_size`. :math:`b` is attention bias, which
    can be additive or multiplicative according to the operator :math:`\circ`.

    Parameters
    ----------
    feat_size : int
        Feature size.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    bias : bool, optional
        If True, it uses bias for linear projection. Default: True.
    attn_bias_type : str, optional
        The type of attention bias used for modifying attention. Selected from
        'add' or 'mul'. Default: 'add'.

        * 'add' is for additive attention bias.
        * 'mul' is for multiplicative attention bias.
    attn_drop : float, optional
        Dropout probability on attention weights. Default: 0.1.

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import BiasedMultiheadAttention

    >>> ndata = th.rand(16, 100, 512)
    >>> bias = th.rand(16, 100, 100, 8)
    >>> net = BiasedMultiheadAttention(feat_size=512, num_heads=8)
    >>> out = net(ndata, bias)
    """

    def __init__(
        self,
        feat_size,
        num_heads,
        bias=True,
        attn_bias_type="add",
        attn_drop=0.1,
    ):
        super().__init__()
        self.feat_size = feat_size
        self.num_heads = num_heads
        self.head_dim = feat_size // num_heads
        assert (
            self.head_dim * num_heads == feat_size
        ), "feat_size must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.attn_bias_type = attn_bias_type

        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)

        self.dropout = nn.Dropout(p=attn_drop)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset parameters of projection matrices, using the same settings as in Graphormer."""
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(self, ndata, attn_bias=None, attn_mask=None):
        """Forward computation.

        Parameters
        ----------
        ndata : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
            N is the maximum number of nodes.
        attn_bias : torch.Tensor, optional
            The attention bias used for attention modification. Shape:
            (batch_size, N, N, :attr:`num_heads`).
        attn_mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions, where invalid positions are indicated by non-zero
            values. Shape: (batch_size, N, N).

        Returns
        -------
        y : torch.Tensor
            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
        """
        q_h = self.q_proj(ndata).transpose(0, 1)
        k_h = self.k_proj(ndata).transpose(0, 1)
        v_h = self.v_proj(ndata).transpose(0, 1)
        bsz, N, _ = ndata.shape
        # scale queries by 1 / sqrt(head_dim), matching softmax(QK^T / sqrt(d))
        q_h = (
            q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            * self.scaling
        )
        k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(
            1, 2, 0
        )
        v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(
            0, 1
        )

        attn_weights = (
            th.bmm(q_h, k_h)
            .transpose(0, 2)
            .reshape(N, N, bsz, self.num_heads)
            .transpose(0, 2)
        )

        if attn_bias is not None:
            if self.attn_bias_type == "add":
                attn_weights += attn_bias
            else:
                attn_weights *= attn_bias
        if attn_mask is not None:
            attn_weights[attn_mask.to(th.bool)] = float("-inf")
        attn_weights = F.softmax(
            attn_weights.transpose(0, 2)
            .reshape(N, N, bsz * self.num_heads)
            .transpose(0, 2),
            dim=2,
        )

        attn_weights = self.dropout(attn_weights)

        attn = th.bmm(attn_weights, v_h).transpose(0, 1)

        attn = self.out_proj(
            attn.reshape(N, bsz, self.feat_size).transpose(0, 1)
        )

        return attn
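
# A minimal usage sketch for BiasedMultiheadAttention with an attention mask
# (not part of the original module; shapes are illustrative assumptions).
# Non-zero mask entries mark invalid positions, so an all-zero mask leaves
# the attention unchanged:
#
# >>> import torch as th
# >>> from dgl.nn import BiasedMultiheadAttention
# >>> ndata = th.rand(2, 10, 64)
# >>> bias = th.rand(2, 10, 10, 8)
# >>> mask = th.zeros(2, 10, 10)               # every position is valid
# >>> net = BiasedMultiheadAttention(feat_size=64, num_heads=8)
# >>> net(ndata, attn_bias=bias, attn_mask=mask).shape
# torch.Size([2, 10, 64])
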

class GraphormerLayer(nn.Module):
    r"""Graphormer Layer with Dense Multi-Head Attention, as introduced
    in `Do Transformers Really Perform Bad for Graph Representation?
    <https://arxiv.org/pdf/2106.05234>`__

    Parameters
    ----------
    feat_size : int
        Feature size.
    hidden_size : int
        Hidden size of feedforward layers.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    attn_bias_type : str, optional
        The type of attention bias used for modifying attention. Selected from
        'add' or 'mul'. Default: 'add'.

        * 'add' is for additive attention bias.
        * 'mul' is for multiplicative attention bias.
    norm_first : bool, optional
        If True, it performs layer normalization before attention and
        feedforward operations. Otherwise, it applies layer normalization
        afterwards. Default: False.
    dropout : float, optional
        Dropout probability. Default: 0.1.
    activation : callable activation layer, optional
        Activation function. Default: nn.ReLU().

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import GraphormerLayer

    >>> batch_size = 16
    >>> num_nodes = 100
    >>> feat_size = 512
    >>> num_heads = 8
    >>> nfeat = th.rand(batch_size, num_nodes, feat_size)
    >>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    >>> net = GraphormerLayer(
    ...     feat_size=feat_size,
    ...     hidden_size=2048,
    ...     num_heads=num_heads
    ... )
    >>> out = net(nfeat, bias)
    """

    def __init__(
        self,
        feat_size,
        hidden_size,
        num_heads,
        attn_bias_type="add",
        norm_first=False,
        dropout=0.1,
        activation=nn.ReLU(),
    ):
        super().__init__()
        self.norm_first = norm_first

        self.attn = BiasedMultiheadAttention(
            feat_size=feat_size,
            num_heads=num_heads,
            attn_bias_type=attn_bias_type,
            attn_drop=dropout,
        )
        self.ffn = nn.Sequential(
            nn.Linear(feat_size, hidden_size),
            activation,
            nn.Dropout(p=dropout),
            nn.Linear(hidden_size, feat_size),
            nn.Dropout(p=dropout),
        )

        self.dropout = nn.Dropout(p=dropout)
        self.attn_layer_norm = nn.LayerNorm(feat_size)
        self.ffn_layer_norm = nn.LayerNorm(feat_size)

    def forward(self, nfeat, attn_bias=None, attn_mask=None):
        """Forward computation.

        Parameters
        ----------
        nfeat : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
            N is the maximum number of nodes.
        attn_bias : torch.Tensor, optional
            The attention bias used for attention modification. Shape:
            (batch_size, N, N, :attr:`num_heads`).
        attn_mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions. Shape: (batch_size, N, N).

        Returns
        -------
        y : torch.Tensor
            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
        """
        residual = nfeat
        if self.norm_first:
            nfeat = self.attn_layer_norm(nfeat)
        nfeat = self.attn(nfeat, attn_bias, attn_mask)
        nfeat = self.dropout(nfeat)
        nfeat = residual + nfeat
        if not self.norm_first:
            nfeat = self.attn_layer_norm(nfeat)

        residual = nfeat
        if self.norm_first:
            nfeat = self.ffn_layer_norm(nfeat)
        nfeat = self.ffn(nfeat)
        nfeat = residual + nfeat
        if not self.norm_first:
            nfeat = self.ffn_layer_norm(nfeat)

        return nfeat
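
# A minimal usage sketch for GraphormerLayer in pre-norm mode (not part of the
# original module; shapes and hyper-parameters are illustrative assumptions).
# Setting norm_first=True applies LayerNorm before attention and the FFN:
#
# >>> import torch as th
# >>> from dgl.nn import GraphormerLayer
# >>> nfeat = th.rand(4, 25, 128)
# >>> bias = th.rand(4, 25, 25, 8)
# >>> layer = GraphormerLayer(feat_size=128, hidden_size=256,
# ...                         num_heads=8, norm_first=True)
# >>> layer(nfeat, attn_bias=bias).shape
# torch.Size([4, 25, 128])
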

class SpatialEncoder(nn.Module):
    r"""Spatial Encoder, as introduced in
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable spatial embedding module which encodes
    the shortest distance between each node pair for attention bias.

    Parameters
    ----------
    max_dist : int
        Upper bound of the shortest path distance
        between each node pair to be encoded.
        All distances will be clamped into the range `[0, max_dist]`.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism
        is applied. Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import SpatialEncoder

    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
    >>> out = spatial_encoder(g)
    >>> print(out.shape)
    torch.Size([1, 4, 4, 8])
    """

    def __init__(self, max_dist, num_heads=1):
        super().__init__()
        self.max_dist = max_dist
        self.num_heads = num_heads
        # deactivate node pairs between which the distance is -1
        self.embedding_table = nn.Embedding(
            max_dist + 2, num_heads, padding_idx=0
        )

    def forward(self, g):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.

        Returns
        -------
        torch.Tensor
            Return attention bias as spatial encoding of shape
            :math:`(B, N, N, H)`, where :math:`N` is the maximum number of
            nodes, :math:`B` is the batch size of the input graph, and
            :math:`H` is :attr:`num_heads`.
        """
        device = g.device
        g_list = unbatch(g)
        max_num_nodes = th.max(g.batch_num_nodes())
        spatial_encoding = []

        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            # shift the clamped distance by one so that unreachable pairs
            # (distance -1) map to the padding index 0
            dist = (
                th.clamp(
                    shortest_dist(ubg, root=None, return_paths=False),
                    min=-1,
                    max=self.max_dist,
                )
                + 1
            )
            # shape: [n, n, h], n = num_nodes, h = num_heads
            dist_embedding = self.embedding_table(dist)
            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
            padded_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
            ).to(device)
            padded_encoding[0:num_nodes, 0:num_nodes] = dist_embedding
            spatial_encoding.append(padded_encoding)

        return th.stack(spatial_encoding, dim=0)
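
# A minimal usage sketch for SpatialEncoder on a batched graph (not part of
# the original module; the two toy graphs are illustrative assumptions).
# Smaller graphs are padded with -inf up to the largest graph in the batch:
#
# >>> import dgl
# >>> from dgl.nn import SpatialEncoder
# >>> g1 = dgl.graph(([0, 1], [1, 2]))         # 3 nodes
# >>> g2 = dgl.graph(([0, 1, 2], [1, 2, 3]))   # 4 nodes
# >>> bg = dgl.batch([g1, g2])
# >>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
# >>> spatial_encoder(bg).shape
# torch.Size([2, 4, 4, 8])
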

class SpatialEncoder3d(nn.Module):
    r"""3D Spatial Encoder, as introduced in
    `One Transformer Can Understand Both 2D & 3D Molecular Data
    <https://arxiv.org/pdf/2210.01765.pdf>`__

    This module encodes the pair-wise relation between each atom pair
    :math:`(i,j)` in the 3D geometric space, according to the Gaussian Basis
    Kernel function:

    :math:`\psi _{(i,j)} ^k = -\frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert}
    \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i -
    r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right)
    ^2 \right)},k=1,...,K,`

    where :math:`K` is the number of Gaussian Basis kernels.
    :math:`r_i` is the Cartesian coordinate of atom :math:`i`.
    :math:`\gamma_{(i,j)}, \beta_{(i,j)}` are learnable scaling factors of
    the Gaussian Basis kernels.

    Parameters
    ----------
    num_kernels : int
        Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis
        Kernel contains a learnable kernel center and a learnable scaling
        factor.
    num_heads : int, optional
        Number of attention heads if multi-head attention mechanism is applied.
        Default : 1.
    max_node_type : int, optional
        Maximum number of node types. Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import SpatialEncoder3d

    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> coordinate = th.rand(4, 3)
    >>> node_type = th.tensor([1, 0, 2, 1])
    >>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
    ...                                    num_heads=8,
    ...                                    max_node_type=3)
    >>> out = spatial_encoder(g, coordinate, node_type=node_type)
    >>> print(out.shape)
    torch.Size([1, 4, 4, 8])
    """

    def __init__(self, num_kernels, num_heads=1, max_node_type=1):
        super().__init__()
        self.num_kernels = num_kernels
        self.num_heads = num_heads
        self.max_node_type = max_node_type
        self.gaussian_means = nn.Embedding(1, num_kernels)
        self.gaussian_stds = nn.Embedding(1, num_kernels)
        self.linear_layer_1 = nn.Linear(num_kernels, num_kernels)
        self.linear_layer_2 = nn.Linear(num_kernels, num_heads)
        if max_node_type == 1:
            self.mul = nn.Embedding(1, 1)
            self.bias = nn.Embedding(1, 1)
        else:
            self.mul = nn.Embedding(max_node_type + 1, 2)
            self.bias = nn.Embedding(max_node_type + 1, 2)

        nn.init.uniform_(self.gaussian_means.weight, 0, 3)
        nn.init.uniform_(self.gaussian_stds.weight, 0, 3)
        nn.init.constant_(self.mul.weight, 0)
        nn.init.constant_(self.bias.weight, 1)

    def forward(self, g, coord, node_type=None):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.
        coord : torch.Tensor
            3D coordinates of nodes in :attr:`g`, of shape :math:`(N, 3)`,
            where :math:`N` is the number of nodes in :attr:`g`.
        node_type : torch.Tensor, optional
            Node types of :attr:`g`. Default : None.

            * If :attr:`max_node_type` is not 1, :attr:`node_type` needs to be
              a tensor in shape :math:`(N,)`. The scaling factors of each pair
              of nodes are determined by their node types.
            * Otherwise, :attr:`node_type` should be None.

        Returns
        -------
        torch.Tensor
            Return attention bias as 3D spatial encoding of shape
            :math:`(B, n, n, H)`, where :math:`B` is the batch size, :math:`n`
            is the maximum number of nodes in unbatched graphs from :attr:`g`,
            and :math:`H` is :attr:`num_heads`.
        """
        device = g.device
        g_list = unbatch(g)
        max_num_nodes = th.max(g.batch_num_nodes())
        spatial_encoding = []
        sum_num_nodes = 0
        if (self.max_node_type == 1) != (node_type is None):
            raise ValueError(
                "input node_type should be None if and only if "
                "max_node_type is 1."
            )

        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            sub_coord = coord[sum_num_nodes : sum_num_nodes + num_nodes]
            # shape: [n, n], n = num_nodes
            euc_dist = th.cdist(sub_coord, sub_coord, p=2)
            if node_type is None:
                # shape: [1]
                mul = self.mul.weight[0, 0]
                bias = self.bias.weight[0, 0]
            else:
                sub_node_type = node_type[
                    sum_num_nodes : sum_num_nodes + num_nodes
                ]
                mul_embedding = self.mul(sub_node_type)
                bias_embedding = self.bias(sub_node_type)
                # shape: [n, n]
                mul = mul_embedding[:, 0].unsqueeze(-1).repeat(
                    1, num_nodes
                ) + mul_embedding[:, 1].unsqueeze(0).repeat(num_nodes, 1)
                bias = bias_embedding[:, 0].unsqueeze(-1).repeat(
                    1, num_nodes
                ) + bias_embedding[:, 1].unsqueeze(0).repeat(num_nodes, 1)
            # shape: [n, n, k], k = num_kernels
            scaled_dist = (
                (mul * euc_dist + bias)
                .repeat(self.num_kernels, 1, 1)
                .permute((1, 2, 0))
            )
            # shape: [k]
            gaussian_mean = self.gaussian_means.weight.float().view(-1)
            gaussian_var = (
                self.gaussian_stds.weight.float().view(-1).abs() + 1e-2
            )
            # shape: [n, n, k]
            gaussian_kernel = (
                (
                    -0.5
                    * (
                        th.div(
                            scaled_dist - gaussian_mean, gaussian_var
                        ).square()
                    )
                )
                .exp()
                .div(-math.sqrt(2 * math.pi) * gaussian_var)
            )

            encoding = self.linear_layer_1(gaussian_kernel)
            encoding = F.gelu(encoding)
            # [n, n, k] -> [n, n, a], a = num_heads
            encoding = self.linear_layer_2(encoding)
            # [n, n, a] -> [N, N, a], N = max_num_nodes, padded with -inf
            padded_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads), float("-inf")
            ).to(device)
            padded_encoding[0:num_nodes, 0:num_nodes] = encoding
            spatial_encoding.append(padded_encoding)
            sum_num_nodes += num_nodes

        return th.stack(spatial_encoding, dim=0)
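
# A minimal usage sketch for SpatialEncoder3d with the default max_node_type=1
# (not part of the original module; the toy graph and random coordinates are
# illustrative assumptions). With a single node type, node_type is omitted and
# one shared pair of scaling factors is used:
#
# >>> import torch as th
# >>> import dgl
# >>> from dgl.nn import SpatialEncoder3d
# >>> g = dgl.graph(([0, 0, 1], [1, 2, 2]))
# >>> coord = th.rand(3, 3)                    # 3D coordinates of 3 nodes
# >>> spatial_encoder = SpatialEncoder3d(num_kernels=4, num_heads=8)
# >>> spatial_encoder(g, coord).shape
# torch.Size([1, 3, 3, 8])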