# Source code for the BiasedMHA (biased multi-head attention) module.

"""Biased Multi-head Attention"""

import torch as th
import torch.nn as nn
import torch.nn.functional as F

class BiasedMHA(nn.Module):
    r"""Dense Multi-Head Attention Module with Graph Attention Bias.

    Compute attention between nodes with attention bias obtained from graph
    structures, as introduced in `Do Transformers Really Perform Bad for
    Graph Representation? <https://arxiv.org/abs/2106.05234>`__

    .. math::

        \text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)

    :math:`Q` and :math:`K` are feature representations of nodes. :math:`d`
    is the per-head feature size, i.e. :attr:`feat_size` divided by
    :attr:`num_heads`. :math:`b` is attention bias, which can be additive or
    multiplicative according to the operator :math:`\circ`.

    Parameters
    ----------
    feat_size : int
        Feature size.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    bias : bool, optional
        If True, it uses bias for linear projection. Default: True.
    attn_bias_type : str, optional
        The type of attention bias used for modifying attention. Selected from
        'add' or 'mul'. Default: 'add'.

        * 'add' is for additive attention bias.
        * 'mul' is for multiplicative attention bias.
    attn_drop : float, optional
        Dropout probability on attention weights. Default: 0.1.

    Raises
    ------
    ValueError
        If :attr:`attn_bias_type` is neither 'add' nor 'mul'.
    AssertionError
        If :attr:`feat_size` is not divisible by :attr:`num_heads`.

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import BiasedMHA

    >>> ndata = th.rand(16, 100, 512)
    >>> bias = th.rand(16, 100, 100, 8)
    >>> net = BiasedMHA(feat_size=512, num_heads=8)
    >>> out = net(ndata, bias)
    """

    def __init__(
        self,
        feat_size,
        num_heads,
        bias=True,
        attn_bias_type="add",
        attn_drop=0.1,
    ):
        super().__init__()
        # Reject unknown modes so a typo does not silently fall back to
        # multiplicative bias in forward().
        if attn_bias_type not in ("add", "mul"):
            raise ValueError(
                f"attn_bias_type must be 'add' or 'mul', got {attn_bias_type!r}"
            )
        self.feat_size = feat_size
        self.num_heads = num_heads
        self.head_dim = feat_size // num_heads
        assert (
            self.head_dim * num_heads == feat_size
        ), "feat_size must be divisible by num_heads"
        # 1/sqrt(d) scaling where d is the per-head dimension.
        self.scaling = self.head_dim**-0.5
        self.attn_bias_type = attn_bias_type

        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)

        self.dropout = nn.Dropout(p=attn_drop)

        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters of projection matrices, the same settings as in
        the original implementation of the paper.

        Q/K/V weights use Xavier-uniform with gain 2**-0.5; the output
        projection uses the default gain and, if present, a zero bias. The
        Q/K/V biases are left at their ``nn.Linear`` defaults.
        """
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(self, ndata, attn_bias=None, attn_mask=None):
        """Forward computation.

        Parameters
        ----------
        ndata : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
            N is the maximum number of nodes.
        attn_bias : torch.Tensor, optional
            The attention bias used for attention modification. Shape:
            (batch_size, N, N, :attr:`num_heads`).
        attn_mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions, where invalid positions are indicated by `True` values.
            Shape: (batch_size, N, N); a per-head mask of shape
            (batch_size, N, N, :attr:`num_heads`) is also accepted.
            Note: For rows corresponding to nonexistent (padding) nodes, make
            sure at least one entry is set to `False` to prevent obtaining
            NaNs with softmax.

        Returns
        -------
        y : torch.Tensor
            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
        """
        bsz, N, _ = ndata.shape

        def _split_heads(x):
            # (bsz, N, feat_size) -> (bsz, num_heads, N, head_dim); head h owns
            # features [h * head_dim, (h + 1) * head_dim).
            return x.reshape(bsz, N, self.num_heads, self.head_dim).transpose(
                1, 2
            )

        q_h = _split_heads(self.q_proj(ndata)) * self.scaling
        k_h = _split_heads(self.k_proj(ndata))
        v_h = _split_heads(self.v_proj(ndata))

        # Scores are (bsz, num_heads, N, N); move heads last so they line up
        # with attn_bias / attn_mask, which are laid out (bsz, N, N, heads).
        attn_weights = th.matmul(q_h, k_h.transpose(-2, -1)).permute(0, 2, 3, 1)

        # Out-of-place ops avoid in-place writes on tensors autograd may need
        # and let attn_bias broadcast / type-promote normally.
        if attn_bias is not None:
            if self.attn_bias_type == "add":
                attn_weights = attn_weights + attn_bias
            else:
                attn_weights = attn_weights * attn_bias

        if attn_mask is not None:
            mask = attn_mask.to(th.bool)
            if mask.dim() == attn_weights.dim() - 1:
                # A (bsz, N, N) mask applies to every head.
                mask = mask.unsqueeze(-1)
            attn_weights = attn_weights.masked_fill(mask, float("-inf"))

        # Back to (bsz, num_heads, N, N); softmax over the key axis.
        attn_weights = F.softmax(attn_weights.permute(0, 3, 1, 2), dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (bsz, num_heads, N, head_dim) -> (bsz, N, feat_size): concat heads.
        attn = th.matmul(attn_weights, v_h).transpose(1, 2)
        attn = attn.reshape(bsz, N, self.feat_size)
        return self.out_proj(attn)