Source code for

"""Graphormer Layer"""

import torch.nn as nn

from .biased_mha import BiasedMHA

[docs]class GraphormerLayer(nn.Module): r"""Graphormer Layer with Dense Multi-Head Attention, as introduced in `Do Transformers Really Perform Bad for Graph Representation? <>`__ Parameters ---------- feat_size : int Feature size. hidden_size : int Hidden size of feedforward layers. num_heads : int Number of attention heads, by which :attr:`feat_size` is divisible. attn_bias_type : str, optional The type of attention bias used for modifying attention. Selected from 'add' or 'mul'. Default: 'add'. * 'add' is for additive attention bias. * 'mul' is for multiplicative attention bias. norm_first : bool, optional If True, it performs layer normalization before attention and feedforward operations. Otherwise, it applies layer normalization afterwards. Default: False. dropout : float, optional Dropout probability. Default: 0.1. attn_dropout : float, optional Attention dropout probability. Default: 0.1. activation : callable activation layer, optional Activation function. Default: nn.ReLU(). Examples -------- >>> import torch as th >>> from dgl.nn import GraphormerLayer >>> batch_size = 16 >>> num_nodes = 100 >>> feat_size = 512 >>> num_heads = 8 >>> nfeat = th.rand(batch_size, num_nodes, feat_size) >>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads) >>> net = GraphormerLayer( feat_size=feat_size, hidden_size=2048, num_heads=num_heads ) >>> out = net(nfeat, bias) """ def __init__( self, feat_size, hidden_size, num_heads, attn_bias_type="add", norm_first=False, dropout=0.1, attn_dropout=0.1, activation=nn.ReLU(), ): super().__init__() self.norm_first = norm_first self.attn = BiasedMHA( feat_size=feat_size, num_heads=num_heads, attn_bias_type=attn_bias_type, attn_drop=attn_dropout, ) self.ffn = nn.Sequential( nn.Linear(feat_size, hidden_size), activation, nn.Dropout(p=dropout), nn.Linear(hidden_size, feat_size), nn.Dropout(p=dropout), ) self.dropout = nn.Dropout(p=dropout) self.attn_layer_norm = nn.LayerNorm(feat_size) self.ffn_layer_norm = nn.LayerNorm(feat_size)
[docs] def forward(self, nfeat, attn_bias=None, attn_mask=None): """Forward computation. Parameters ---------- nfeat : torch.Tensor A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where N is the maximum number of nodes. attn_bias : torch.Tensor, optional The attention bias used for attention modification. Shape: (batch_size, N, N, :attr:`num_heads`). attn_mask : torch.Tensor, optional The attention mask used for avoiding computation on invalid positions, where invalid positions are indicated by `True` values. Shape: (batch_size, N, N). Note: For rows corresponding to unexisting nodes, make sure at least one entry is set to `False` to prevent obtaining NaNs with softmax. Returns ------- y : torch.Tensor The output tensor. Shape: (batch_size, N, :attr:`feat_size`) """ residual = nfeat if self.norm_first: nfeat = self.attn_layer_norm(nfeat) nfeat = self.attn(nfeat, attn_bias, attn_mask) nfeat = self.dropout(nfeat) nfeat = residual + nfeat if not self.norm_first: nfeat = self.attn_layer_norm(nfeat) residual = nfeat if self.norm_first: nfeat = self.ffn_layer_norm(nfeat) nfeat = self.ffn(nfeat) nfeat = residual + nfeat if not self.norm_first: nfeat = self.ffn_layer_norm(nfeat) return nfeat