# Source code for dgl.nn.pytorch.conv.gatconv

"""Torch modules for graph attention networks(GAT)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn

from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair

# pylint: enable=W0235
class GATConv(nn.Module):
    r"""Graph attention layer from `Graph Attention Network
    <https://arxiv.org/pdf/1710.10903.pdf>`__

    .. math::
        h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)}

    where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
    node :math:`j`:

    .. math::
        \alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})

        e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.
        GATConv can be applied on homogeneous graph and unidirectional
        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
        specifies the input feature size on both the source and destination nodes.
        If a scalar is given, the source and destination node feature size would take
        the same value.
    out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    num_heads : int
        Number of heads in Multi-Head Attention.
    feat_drop : float, optional
        Dropout rate on feature. Defaults: ``0``.
    attn_drop : float, optional
        Dropout rate on attention weight. Defaults: ``0``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Defaults: ``0.2``.
    residual : bool, optional
        If True, use residual connection. Defaults: ``False``.
    activation : callable activation function/layer or None, optional.
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will be
        invalid since no message will be passed to those nodes. This is harmful for
        some applications, causing silent performance regression. This module will
        raise a DGLError if it detects 0-in-degree nodes in the input graph. By
        setting ``True``, it will suppress the check and let the users handle it by
        themselves. Defaults: ``False``.
    bias : bool, optional
        If True, learns a bias term. Defaults: ``True``.

    Note
    ----
    Zero in-degree nodes will lead to invalid output value. This is because no message
    will be passed to those nodes, and the aggregation function will be applied on
    empty input. A common practice to avoid this is to add a self-loop for each node
    in the graph if it is homogeneous, which can be achieved by:

    >>> g = ... # a DGLGraph
    >>> g = dgl.add_self_loop(g)

    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous
    graphs, since the edge type cannot be decided for self-loop edges. Set
    ``allow_zero_in_degree`` to ``True`` for those cases to unblock the code and handle
    zero-in-degree nodes manually. A common practice to handle this is to filter out
    the nodes with zero in-degree when using the output after the conv.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import GATConv

    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> gatconv = GATConv(10, 2, num_heads=3)
    >>> res = gatconv(g, feat)
    >>> res
    tensor([[[ 3.4570,  1.8634],
             [ 1.3805, -0.0762],
             [ 1.0390, -1.1479]],
            [[ 3.4570,  1.8634],
             [ 1.3805, -0.0762],
             [ 1.0390, -1.1479]],
            [[ 3.4570,  1.8634],
             [ 1.3805, -0.0762],
             [ 1.0390, -1.1479]],
            [[ 3.4570,  1.8634],
             [ 1.3805, -0.0762],
             [ 1.0390, -1.1479]],
            [[ 3.4570,  1.8634],
             [ 1.3805, -0.0762],
             [ 1.0390, -1.1479]],
            [[ 3.4570,  1.8634],
             [ 1.3805, -0.0762],
             [ 1.0390, -1.1479]]], grad_fn=<BinaryReduceBackward>)

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
    >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
    >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
    >>> gatconv = GATConv((5,10), 2, 3)
    >>> res = gatconv(g, (u_feat, v_feat))
    >>> res
    tensor([[[-0.6066,  1.0268],
             [-0.5945, -0.4801],
             [ 0.1594,  0.3825]],
            [[ 0.0268,  1.0783],
             [ 0.5041, -1.3025],
             [ 0.6568,  0.7048]],
            [[-0.2688,  1.0543],
             [-0.0315, -0.9016],
             [ 0.3943,  0.5347]],
            [[-0.6066,  1.0268],
             [-0.5945, -0.4801],
             [ 0.1594,  0.3825]]], grad_fn=<BinaryReduceBackward>)
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 residual=False,
                 activation=None,
                 allow_zero_in_degree=False,
                 bias=True):
        super().__init__()
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        # A tuple ``in_feats`` gets separate source/destination projections;
        # a scalar shares one projection ``fc`` for both sides.
        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=False)
        else:
            self.fc = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
        # Attention vector split as [a_l || a_r], one slice per head; shaped
        # (1, H, D_out) so it broadcasts against projected (N, *, H, D_out) features.
        self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if bias:
            self.bias = nn.Parameter(th.FloatTensor(size=(num_heads * out_feats,)))
        else:
            self.register_buffer('bias', None)
        if residual:
            # Project the destination input only when its width differs from
            # the concatenated multi-head output width.
            if self._in_dst_feats != out_feats * num_heads:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = Identity()
        else:
            self.register_buffer('res_fc', None)
        self.reset_parameters()
        self.activation = activation

    def reset_parameters(self):
        """

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The fc weights :math:`W^{(l)}`, the attention weights and (when it is an
        ``nn.Linear``) the residual projection are initialized with Xavier (Glorot)
        normal initialization, using the gain recommended for ReLU. The bias, if
        present, is set to zero.
        """
        gain = nn.init.calculate_gain('relu')
        if hasattr(self, 'fc'):
            nn.init.xavier_normal_(self.fc.weight, gain=gain)
        else:
            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

    def set_allow_zero_in_degree(self, set_value):
        r"""

        Description
        -----------
        Set allow_zero_in_degree flag.

        Parameters
        ----------
        set_value : bool
            The value to be set to the flag.
        """
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, get_attention=False):
        r"""

        Description
        -----------
        Compute graph attention network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, *, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
        get_attention : bool, optional
            Whether to return the attention values. Default to False.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
            is the number of heads, and :math:`D_{out}` is size of output feature.
        torch.Tensor, optional
            The attention values (after attention dropout) of shape :math:`(E, *, H, 1)`,
            where :math:`E` is the number of edges. This is returned only when
            :attr:`get_attention` is ``True``.

        Raises
        ------
        DGLError
            If there are 0-in-degree nodes in the input graph, it will raise DGLError
            since no message will be passed to those nodes. This will cause invalid output.
            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')

            if isinstance(feat, tuple):
                src_prefix_shape = feat[0].shape[:-1]
                dst_prefix_shape = feat[1].shape[:-1]
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                # A module built with a scalar in_feats only has the shared ``fc``.
                if not hasattr(self, 'fc_src'):
                    feat_src = self.fc(h_src).view(
                        *src_prefix_shape, self._num_heads, self._out_feats)
                    feat_dst = self.fc(h_dst).view(
                        *dst_prefix_shape, self._num_heads, self._out_feats)
                else:
                    feat_src = self.fc_src(h_src).view(
                        *src_prefix_shape, self._num_heads, self._out_feats)
                    feat_dst = self.fc_dst(h_dst).view(
                        *dst_prefix_shape, self._num_heads, self._out_feats)
            else:
                src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    *src_prefix_shape, self._num_heads, self._out_feats)
                if graph.is_block:
                    # For a block, destination nodes are the first
                    # number_of_dst_nodes() entries of the source features.
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
                    h_dst = h_dst[:graph.number_of_dst_nodes()]
                    dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
            # NOTE: GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition"; the two approaches are mathematically equivalent:
            # we decompose the weight vector a mentioned in the paper into
            # [a_l || a_r], then
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is much more efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
            # addition could be optimized with DGL's built-in function u_add_v,
            # which further speeds up computation and saves memory footprint.
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})
            # Compute edge attention: el is a_l . W h on each source node and
            # er is a_r . W h on each destination node; u_add_v sums them per edge.
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            # compute softmax
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            # message passing
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            # residual
            if self.res_fc is not None:
                # Use -1 rather than self._num_heads to handle broadcasting
                resval = self.res_fc(h_dst).view(*dst_prefix_shape, -1, self._out_feats)
                rst = rst + resval
            # bias
            if self.bias is not None:
                rst = rst + self.bias.view(
                    *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata['a']
            else:
                return rst