"""Torch modules for graph attention networks v2 (GATv2)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair
# pylint: enable=W0235
class GATv2Conv(nn.Module):
    r"""GATv2 from `How Attentive are Graph Attention Networks?
    <https://arxiv.org/pdf/2105.14491.pdf>`__

    .. math::
        h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij}^{(l)} W^{(l)}_{right} h_j^{(l)}

    where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
    node :math:`j`:

    .. math::
        \alpha_{ij}^{(l)} &= \mathrm{softmax_i} (e_{ij}^{(l)})

        e_{ij}^{(l)} &= {\vec{a}^T}^{(l)}\mathrm{LeakyReLU}\left(
            W^{(l)}_{left} h_{i} + W^{(l)}_{right} h_{j}\right)

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
        If the layer is to be applied to a unidirectional bipartite graph, `in_feats`
        specifies the input feature size on both the source and destination nodes.
        If a scalar is given, the source and destination node feature size
        would take the same value.
    out_feats : int
        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
    num_heads : int
        Number of heads in Multi-Head Attention.
    feat_drop : float, optional
        Dropout rate on feature. Defaults: ``0``.
    attn_drop : float, optional
        Dropout rate on attention weight. Defaults: ``0``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Defaults: ``0.2``.
    residual : bool, optional
        If True, use residual connection. Defaults: ``False``.
    activation : callable activation function/layer or None, optional.
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
        since no message will be passed to those nodes. This is harmful for some applications
        causing silent performance regression. This module will raise a DGLError if it detects
        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
        and let the users handle it by themselves. Defaults: ``False``.
    bias : bool, optional
        If set to :obj:`False`, the layer will not learn
        an additive bias. (default: :obj:`True`)
    share_weights : bool, optional
        If set to :obj:`True`, the same matrix for :math:`W_{left}` and :math:`W_{right}` in
        the above equations, will be applied to the source and the target node of every edge.
        The projection layer is only shared when ``in_feats`` is a scalar; when a pair of
        sizes is given, two separate projections are always created.
        (default: :obj:`False`)

    Note
    ----
    Zero in-degree nodes will lead to invalid output value. This is because no message
    will be passed to those nodes, the aggregation function will be applied on empty input.
    A common practice to avoid this is to add a self-loop for each node in the graph if
    it is homogeneous, which can be achieved by:

    >>> g = ... # a DGLGraph
    >>> g = dgl.add_self_loop(g)

    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``
    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.
    A common practice to handle this is to filter out the nodes with zero-in-degree when use
    after conv.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import GATv2Conv

    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> gatv2conv = GATv2Conv(10, 2, num_heads=3)
    >>> res = gatv2conv(g, feat)
    >>> res
    tensor([[[ 1.9599,  1.0239],
             [ 3.2015, -0.5512],
             [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
             [ 3.2015, -0.5512],
             [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
             [ 3.2015, -0.5512],
             [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
             [ 3.2015, -0.5512],
             [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
             [ 3.2015, -0.5512],
             [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
             [ 3.2015, -0.5512],
             [ 2.3700, -2.2182]]], grad_fn=<GSpMMBackward>)

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
    >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
    >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
    >>> gatv2conv = GATv2Conv((5,10), 2, 3)
    >>> res = gatv2conv(g, (u_feat, v_feat))
    >>> res
    tensor([[[-0.0935, -0.4273],
             [-1.1850,  0.1123],
             [-0.2002,  0.1155]],
            [[ 0.1908, -1.2095],
             [-0.0129,  0.6408],
             [-0.8135,  0.1157]],
            [[ 0.0596, -0.8487],
             [-0.5421,  0.4022],
             [-0.4805,  0.1156]],
            [[-0.0935, -0.4273],
             [-1.1850,  0.1123],
             [-0.2002,  0.1155]]], grad_fn=<GSpMMBackward>)
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 residual=False,
                 activation=None,
                 allow_zero_in_degree=False,
                 bias=True,
                 share_weights=False):
        super().__init__()
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree

        # fc_src projects source-node features (W_right in the class docstring),
        # fc_dst projects destination-node features (W_left).
        self.fc_src = nn.Linear(
            self._in_src_feats, out_feats * num_heads, bias=bias)
        if isinstance(in_feats, tuple):
            # Distinct source/destination sizes: always use a separate layer.
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=bias)
        elif share_weights:
            # Alias the very same module so both sides share parameters.
            self.fc_dst = self.fc_src
        else:
            self.fc_dst = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=bias)

        # Attention vector; allocated uninitialized and filled in reset_parameters().
        self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)

        if residual:
            if self._in_dst_feats != out_feats * num_heads:
                # Sizes differ: learn a projection for the skip connection.
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=bias)
            else:
                self.res_fc = Identity()
        else:
            self.register_buffer('res_fc', None)

        self.activation = activation
        self.share_weights = share_weights
        self.bias = bias
        self.reset_parameters()

    def reset_parameters(self):
        """
        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The fc weights :math:`W^{(l)}`, the attention weights and the weight of the
        residual projection (when it is a linear layer) are initialized using Xavier
        (Glorot) normal initialization with the gain recommended for ReLU.
        All biases, when present, are set to zero.
        """
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
        if self.bias:
            nn.init.constant_(self.fc_src.bias, 0)
        # Decide by module identity rather than by the ``share_weights`` flag:
        # with a pair of input sizes ``fc_dst`` is a separate layer even when
        # ``share_weights`` is True, and it must be initialized as well.
        if self.fc_dst is not self.fc_src:
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
            if self.bias:
                nn.init.constant_(self.fc_dst.bias, 0)
        nn.init.xavier_normal_(self.attn, gain=gain)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
            if self.bias:
                nn.init.constant_(self.res_fc.bias, 0)

    def set_allow_zero_in_degree(self, set_value):
        r"""
        Description
        -----------
        Set allow_zero_in_degree flag.

        Parameters
        ----------
        set_value : bool
            The value to be set to the flag.
        """
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, get_attention=False):
        r"""
        Description
        -----------
        Compute graph attention network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
            If a single tensor is given and ``graph.is_block`` is true, its first
            ``graph.number_of_dst_nodes()`` rows are used for the destination nodes.
        get_attention : bool, optional
            Whether to return the attention values. Default to False.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
            is the number of heads, and :math:`D_{out}` is size of output feature.
        torch.Tensor, optional
            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
            edges. This is returned only when :attr:`get_attention` is ``True``.

        Raises
        ------
        DGLError
            If there are 0-in-degree nodes in the input graph, it will raise DGLError
            since no message will be passed to those nodes. This will cause invalid output.
            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = self.fc_src(h_src).view(
                    -1, self._num_heads, self._out_feats)
                if self.share_weights:
                    feat_dst = feat_src
                else:
                    feat_dst = self.fc_dst(h_src).view(
                        -1, self._num_heads, self._out_feats)
                if graph.is_block:
                    # Keep only the leading rows that belong to destination nodes.
                    # Slice feat_dst (not feat_src) so the fc_dst projection is
                    # preserved when the weights are not shared.
                    num_dst = graph.number_of_dst_nodes()
                    feat_dst = feat_dst[:num_dst]
                    h_dst = h_dst[:num_dst]

            # Projected features are viewed as (rows, num_heads, out_feats).
            graph.srcdata.update({'el': feat_src})
            graph.dstdata.update({'er': feat_dst})
            # Per-edge sum of source and destination projections.
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            # Dot product with the attention vector over the feature axis,
            # keeping a trailing singleton dimension: (..., num_heads, 1).
            e = (e * self.attn).sum(dim=-1).unsqueeze(dim=2)
            # compute softmax
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            # message passing: weight source projections by attention and sum
            graph.update_all(fn.u_mul_e('el', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
                rst = rst + resval
            # activation
            if self.activation is not None:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata['a']
            return rst