"""Torch Module for Topology Adaptive Graph Convolutional layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn

from .... import function as fn


class TAGConv(nn.Module):
    r"""Topology Adaptive Graph Convolutional layer from paper `Topology
    Adaptive Graph Convolutional Networks <https://arxiv.org/pdf/1710.10370.pdf>`__.

    .. math::
        \mathbf{X}^{\prime} = \sum_{k=0}^K \left(\mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2}\right)^{k} \mathbf{X} \mathbf{\Theta}_{k},

    where :math:`\mathbf{A}` denotes the adjacency matrix and
    :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    out_feats : int
        Output feature size.
    k : int, optional
        Number of hops :math:`k`. (default: 2)
    bias : bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Attributes
    ----------
    lin : torch.Module
        The learnable linear module.
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 k=2,
                 bias=True,
                 activation=None):
        super(TAGConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._k = k
        self._activation = activation
        self.lin = nn.Linear(in_feats * (self._k + 1), out_feats, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.lin.weight, gain=gain)
    def forward(self, graph, feat):
        r"""Compute topology adaptive graph convolution.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is size of input feature, :math:`N` is the number of nodes.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """
        assert graph.is_homograph(), 'Graph is not homogeneous'
        graph = graph.local_var()

        # Symmetric normalization term D^{-1/2}, computed from in-degrees.
        norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
        shp = norm.shape + (1,) * (feat.dim() - 1)
        norm = th.reshape(norm, shp).to(feat.device)

        # Repeatedly apply D^{-1/2} A D^{-1/2} X and keep every hop's result.
        fstack = [feat]
        for _ in range(self._k):
            rst = fstack[-1] * norm
            graph.ndata['h'] = rst

            graph.update_all(fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'))
            rst = graph.ndata['h']
            rst = rst * norm
            fstack.append(rst)

        # Concatenate the 0..k hop features and project them jointly.
        rst = self.lin(th.cat(fstack, dim=-1))

        if self._activation is not None:
            rst = self._activation(rst)

        return rst
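

# A minimal usage sketch (not part of the original module): it assumes a DGL
# release contemporary with this source, i.e. one where ``dgl.graph`` builds a
# homogeneous graph and ``fn.copy_src`` / ``is_homograph`` are still available.
# The graph size, edge list, and feature dimensions below are arbitrary
# illustration values.
if __name__ == '__main__':
    import dgl

    # A small directed 4-cycle with 8-dimensional node features.
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    feat = th.randn(4, 8)

    # Aggregate up to 2-hop neighborhoods and project to 2 output features.
    conv = TAGConv(in_feats=8, out_feats=2, k=2)
    out = conv(g, feat)
    print(out.shape)  # expected: torch.Size([4, 2])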