Source code for dgl.nn.pytorch.conv.twirlsconv

"""Torch modules for TWIRLS"""
# pylint: disable=invalid-name, useless-super-delegation, no-member

import torch as tc
import torch.nn as nn
import torch.nn.functional as F

from .... import function as fn


class TWIRLSConv(nn.Module):
    r"""Convolution together with iteratively reweighted least squares from
    `Graph Neural Networks Inspired by Classical Iterative Algorithms
    <https://arxiv.org/pdf/2103.06064.pdf>`__

    Parameters
    ----------
    input_d : int
        Number of input features.
    output_d : int
        Number of output features.
    hidden_d : int
        Size of hidden layers.
    prop_step : int
        Number of propagation steps.
    num_mlp_before : int
        Number of mlp layers before propagation. Default: ``1``.
    num_mlp_after : int
        Number of mlp layers after propagation. Default: ``1``.
    norm : str
        The type of norm layers inside mlp layers. Can be ``'batch'``, ``'layer'``
        or ``'none'``. Default: ``'none'``.
    precond : bool
        If ``True``, use preconditioning and the unnormalized laplacian; otherwise
        use the normalized laplacian without preconditioning. Default: ``True``.
    alp : float
        The :math:`\alpha` in the paper. If equal to :math:`0`, it will be decided
        automatically based on the other hyperparameters. Default: ``0``.
    lam : float
        The :math:`\lambda` in the paper. Default: ``1``.
    attention : bool
        If ``True``, add an attention layer inside propagations. Default: ``False``.
    tau : float
        The :math:`\tau` in the paper. Default: ``0.2``.
    T : float
        The :math:`T` in the paper. If :math:`T < 0`, :math:`T` will be set to
        :math:`\infty`. Default: ``-1``.
    p : float
        The :math:`p` in the paper. Default: ``1``.
    use_eta : bool
        If ``True``, add a learnable weight on each dimension in attention.
        Default: ``False``.
    attn_bef : bool
        If ``True``, add another attention layer before propagation. Default: ``False``.
    dropout : float
        The dropout rate in mlp layers. Default: ``0.0``.
    attn_dropout : float
        The dropout rate of attention values. Default: ``0.0``.
    inp_dropout : float
        The dropout rate on input features. Default: ``0.0``.

    Note
    ----
    ``add_self_loop`` will be automatically called before propagation.

    Example
    -------
    >>> import dgl
    >>> from dgl.nn import TWIRLSConv
    >>> import torch as th
    >>> g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    >>> feat = th.ones(6, 10)
    >>> conv = TWIRLSConv(10, 2, 128, prop_step=64)
    >>> res = conv(g, feat)
    >>> res.size()
    torch.Size([6, 2])
    """

    def __init__(
        self,
        input_d,
        output_d,
        hidden_d,
        prop_step,
        num_mlp_before=1,
        num_mlp_after=1,
        norm="none",
        precond=True,
        alp=0,
        lam=1,
        attention=False,
        tau=0.2,
        T=-1,
        p=1,
        use_eta=False,
        attn_bef=False,
        dropout=0.0,
        attn_dropout=0.0,
        inp_dropout=0.0,
    ):
        super().__init__()
        self.input_d = input_d
        self.output_d = output_d
        self.hidden_d = hidden_d
        self.prop_step = prop_step
        self.num_mlp_before = num_mlp_before
        self.num_mlp_after = num_mlp_after
        self.norm = norm
        self.precond = precond
        self.attention = attention
        self.alp = alp
        self.lam = lam
        self.tau = tau
        self.T = T
        self.p = p
        self.use_eta = use_eta
        self.init_att = attn_bef
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.inp_dropout = inp_dropout

        # ----- initialization of some variables -----
        # where to put attention
        self.attn_aft = prop_step // 2 if attention else -1

        # whether we can cache the unfolding result
        self.cacheable = (
            (not self.attention)
            and self.num_mlp_before == 0
            and self.inp_dropout <= 0
        )
        if self.cacheable:
            self.cached_unfolding = None

        # if only one layer, then no hidden size
        self.size_bef_unf = self.hidden_d
        self.size_aft_unf = self.hidden_d
        if self.num_mlp_before == 0:
            self.size_aft_unf = self.input_d  # as the input of mlp_aft
        if self.num_mlp_after == 0:
            self.size_bef_unf = self.output_d  # as the output of mlp_bef

        # ----- computational modules -----
        self.mlp_bef = MLP(
            self.input_d,
            self.hidden_d,
            self.size_bef_unf,
            self.num_mlp_before,
            self.dropout,
            self.norm,
            init_activate=False,
        )

        self.unfolding = TWIRLSUnfoldingAndAttention(
            self.hidden_d,
            self.alp,
            self.lam,
            self.prop_step,
            self.attn_aft,
            self.tau,
            self.T,
            self.p,
            self.use_eta,
            self.init_att,
            self.attn_dropout,
            self.precond,
        )

        # if there really are transformations before unfolding, then do init_activate in mlp_aft
        self.mlp_aft = MLP(
            self.size_aft_unf,
            self.hidden_d,
            self.output_d,
            self.num_mlp_after,
            self.dropout,
            self.norm,
            init_activate=(self.num_mlp_before > 0) and (self.num_mlp_after > 0),
        )
    def forward(self, graph, feat):
        r"""

        Description
        -----------
        Run TWIRLS forward.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The initial node features.

        Returns
        -------
        torch.Tensor
            The output feature.

        Note
        ----
        * Input shape: :math:`(N, \text{input_d})` where :math:`N` is the number of nodes.
        * Output shape: :math:`(N, \text{output_d})`.
        """
        # ensure self loops
        graph = graph.remove_self_loop()
        graph = graph.add_self_loop()
        x = feat

        if self.cacheable:
            # cache the unfolding result because there are no parameters before it
            if self.cached_unfolding is None:
                self.cached_unfolding = self.unfolding(graph, x)
            x = self.cached_unfolding
        else:
            if self.inp_dropout > 0:
                x = F.dropout(x, self.inp_dropout, training=self.training)
            x = self.mlp_bef(x)
            x = self.unfolding(graph, x)

        x = self.mlp_aft(x)
        return x
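# --- Usage sketch (illustrative; not part of the original module) ---
# A hedged end-to-end example of training TWIRLSConv as a node classifier on a toy
# graph. The graph, labels, hyperparameters and the helper name
# `_example_twirls_training` are hypothetical and chosen for illustration only.
def _example_twirls_training():
    import dgl  # local import; the module itself only imports dgl.function

    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    feat = tc.randn(6, 10)
    labels = tc.tensor([0, 1, 0, 1, 0, 1])

    model = TWIRLSConv(input_d=10, output_d=2, hidden_d=64, prop_step=16)
    opt = tc.optim.Adam(model.parameters(), lr=1e-2)

    for _ in range(50):
        logits = model(g, feat)  # shape: (6, 2)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model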
class Propagate(nn.Module):
    r"""

    Description
    -----------
    The propagation step with preconditioning and reparameterization.
    Corresponds to eq. 28 in the paper.
    """

    def __init__(self):
        super().__init__()

    def _prop(self, graph, Y, lam):
        """The propagation part."""
        Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam)
        Y = AX(graph, Y)
        Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam)
        return Y

    def forward(self, graph, Y, X, alp, lam):
        r"""

        Description
        -----------
        Propagation forward.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        Y : torch.Tensor
            The feature under propagation. Corresponds to :math:`Z^{(k)}` in eq. 28 in the paper.
        X : torch.Tensor
            The original feature. Corresponds to :math:`Z^{(0)}` in eq. 28 in the paper.
        alp : float
            The step size. Corresponds to :math:`\alpha` in the paper.
        lam : float
            The coefficient of the smoothing term. Corresponds to :math:`\lambda` in the paper.

        Returns
        -------
        torch.Tensor
            The propagated feature, :math:`Z^{(k+1)}` in eq. 28 in the paper.
        """
        return (
            (1 - alp) * Y
            + alp * lam * self._prop(graph, Y, lam)
            + alp * D_power_bias_X(graph, X, -1, lam, 1 - lam)
        )


class PropagateNoPrecond(nn.Module):
    r"""

    Description
    -----------
    The propagation step without preconditioning and reparameterization, using the
    normalized laplacian. Corresponds to eq. 30 in the paper.
    """

    def __init__(self):
        super().__init__()

    def forward(self, graph, Y, X, alp, lam):
        r"""

        Description
        -----------
        Propagation forward.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        Y : torch.Tensor
            The feature under propagation. Corresponds to :math:`Y^{(k)}` in eq. 30 in the paper.
        X : torch.Tensor
            The original feature. Corresponds to :math:`Y^{(0)}` in eq. 30 in the paper.
        alp : float
            The step size. Corresponds to :math:`\alpha` in the paper.
        lam : float
            The coefficient of the smoothing term. Corresponds to :math:`\lambda` in the paper.

        Returns
        -------
        torch.Tensor
            The propagated feature, :math:`Y^{(k+1)}` in eq. 30 in the paper.
        """
        return (
            (1 - alp * lam - alp) * Y
            + alp * lam * normalized_AX(graph, Y)
            + alp * X
        )


class Attention(nn.Module):
    r"""

    Description
    -----------
    The attention function. Corresponds to :math:`s` in eq. 27 in the paper.

    Parameters
    ----------
    tau : float
        The lower thresholding parameter. Corresponds to :math:`\tau` in the paper.
    T : float
        The upper thresholding parameter. Corresponds to :math:`T` in the paper.
    p : float
        Corresponds to :math:`\rho` in the paper.
    attn_dropout : float
        The dropout rate of attention values. Default: ``0.0``.
    """

    def __init__(self, tau, T, p, attn_dropout=0.0):
        super().__init__()
        self.tau = tau
        self.T = T
        self.p = p
        self.attn_dropout = attn_dropout

    def reweighting(self, graph):
        """Compute graph edge weights. The result is stored in ``graph.edata['w']``."""
        w = graph.edata["w"]

        # This is not an activation; it only ensures w > 0. w can be < 0 here because
        # of precision issues in dgl, which would cause NaNs afterwards.
        w = F.relu(w) + 1e-7
        w = tc.pow(w, 1 - 0.5 * self.p)

        w[(w < self.tau)] = self.tau
        if self.T > 0:
            w[(w > self.T)] = float("inf")

        w = 1 / w

        graph.edata["w"] = w + 1e-9  # avoid 0 degree

    def forward(self, graph, Y, etas=None):
        r"""

        Description
        -----------
        Attention forward. Updates ``graph.edata['w']`` and ``graph.ndata['deg']``.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        Y : torch.Tensor
            The feature used to compute attention.
        etas : torch.Tensor
            The weight of each dimension. If ``None``, the weight of each dimension is 1.
            Default: ``None``.

        Returns
        -------
        DGLGraph
            The graph.
        """
        if etas is not None:
            Y = Y * etas.view(-1)

        # compute edge distances
        graph.srcdata["h"] = Y
        graph.srcdata["h_norm"] = (Y**2).sum(-1)
        graph.apply_edges(fn.u_dot_v("h", "h", "dot_"))
        graph.apply_edges(fn.u_add_v("h_norm", "h_norm", "norm_"))
        graph.edata["dot_"] = graph.edata["dot_"].view(-1)
        graph.edata["norm_"] = graph.edata["norm_"].view(-1)
        graph.edata["w"] = graph.edata["norm_"] - 2 * graph.edata["dot_"]

        # apply edge distances to get edge weights
        self.reweighting(graph)

        # update node degrees
        graph.update_all(fn.copy_e("w", "m"), fn.sum("m", "deg"))
        graph.ndata["deg"] = graph.ndata["deg"].view(-1)

        # attention dropout. this implementation keeps the degrees unchanged in expectation.
        # FIXME: consider if there is a better way
        if self.attn_dropout > 0:
            graph.edata["w"] = F.dropout(
                graph.edata["w"], self.attn_dropout, training=self.training
            )

        return graph


def normalized_AX(graph, X):
    """Y = D^{-1/2}AD^{-1/2}X"""
    Y = D_power_X(graph, X, -0.5)  # Y = D^{-1/2}X
    Y = AX(graph, Y)  # Y = AD^{-1/2}X
    Y = D_power_X(graph, Y, -0.5)  # Y = D^{-1/2}AD^{-1/2}X
    return Y


def AX(graph, X):
    """Y = AX"""
    graph.srcdata["h"] = X
    graph.update_all(
        fn.u_mul_e("h", "w", "m"),
        fn.sum("m", "h"),
    )
    Y = graph.dstdata["h"]
    return Y


def D_power_X(graph, X, power):
    """Y = D^{power}X"""
    degs = graph.ndata["deg"]
    norm = tc.pow(degs, power)
    Y = X * norm.view(X.size(0), 1)
    return Y


def D_power_bias_X(graph, X, power, coeff, bias):
    """Y = (coeff*D + bias*I)^{power} X"""
    degs = graph.ndata["deg"]
    degs = coeff * degs + bias
    norm = tc.pow(degs, power)
    Y = X * norm.view(X.size(0), 1)
    return Y
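# --- Sanity-check sketch (illustrative; not part of the original module) ---
# A hedged example showing what the graph operators above compute: on a tiny graph
# with unit edge weights, ``normalized_AX`` should match a dense
# D^{-1/2} A D^{-1/2} X built by hand. The helper name `_example_graph_operators`
# and the toy graph are hypothetical.
def _example_graph_operators():
    import dgl  # local import; the module itself only imports dgl.function

    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])).add_self_loop()
    X = tc.randn(6, 4)

    # the unfolding code expects unit edge weights and node degrees to be present
    g.edata["w"] = tc.ones(g.num_edges(), 1)
    g.ndata["deg"] = g.in_degrees().to(X)

    # dense reference: A from the edge list (row = destination), D from in-degrees
    src, dst = g.edges()
    A = tc.zeros(6, 6)
    A[dst, src] = 1.0
    D_inv_sqrt = tc.diag(g.in_degrees().to(X).pow(-0.5))
    dense = D_inv_sqrt @ A @ D_inv_sqrt @ X

    assert tc.allclose(normalized_AX(g, X), dense, atol=1e-5)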
class TWIRLSUnfoldingAndAttention(nn.Module):
    r"""

    Description
    -----------
    Combine propagation and attention together.

    Parameters
    ----------
    d : int
        Size of the graph feature.
    alp : float
        Step size. :math:`\alpha` in the paper.
    lam : float
        Coefficient of the graph smoothing term. :math:`\lambda` in the paper.
    prop_step : int
        Number of propagation steps.
    attn_aft : int
        Where to put the attention layer, i.e. the number of propagation steps before
        attention. If set to ``-1``, then no attention.
    tau : float
        The lower thresholding parameter. Corresponds to :math:`\tau` in the paper.
    T : float
        The upper thresholding parameter. Corresponds to :math:`T` in the paper.
    p : float
        Corresponds to :math:`\rho` in the paper.
    use_eta : bool
        If ``True``, learn a weight vector for each dimension when doing attention.
    init_att : bool
        If ``True``, add an extra attention layer before propagation.
    attn_dropout : float
        The dropout rate of attention values. Default: ``0.0``.
    precond : bool
        If ``True``, use the preconditioned and reparameterized version of propagation
        (eq. 28); otherwise use the normalized laplacian version (eq. 30).

    Example
    -------
    >>> import dgl
    >>> from dgl.nn import TWIRLSUnfoldingAndAttention
    >>> import torch as th
    >>> g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])).add_self_loop()
    >>> feat = th.ones(6, 5)
    >>> prop = TWIRLSUnfoldingAndAttention(10, 1, 1, prop_step=3)
    >>> res = prop(g, feat)
    >>> res
    tensor([[2.5000, 2.5000, 2.5000, 2.5000, 2.5000],
            [2.5000, 2.5000, 2.5000, 2.5000, 2.5000],
            [2.5000, 2.5000, 2.5000, 2.5000, 2.5000],
            [3.7656, 3.7656, 3.7656, 3.7656, 3.7656],
            [2.5217, 2.5217, 2.5217, 2.5217, 2.5217],
            [4.0000, 4.0000, 4.0000, 4.0000, 4.0000]])
    """

    def __init__(
        self,
        d,
        alp,
        lam,
        prop_step,
        attn_aft=-1,
        tau=0.2,
        T=-1,
        p=1,
        use_eta=False,
        init_att=False,
        attn_dropout=0,
        precond=True,
    ):
        super().__init__()

        self.d = d
        self.alp = alp if alp > 0 else 1 / (lam + 1)  # set alpha automatically
        self.lam = lam
        self.tau = tau
        self.p = p
        self.prop_step = prop_step
        self.attn_aft = attn_aft
        self.use_eta = use_eta
        self.init_att = init_att

        prop_method = Propagate if precond else PropagateNoPrecond
        self.prop_layers = nn.ModuleList(
            [prop_method() for _ in range(prop_step)]
        )
        self.init_attn = (
            Attention(tau, T, p, attn_dropout) if self.init_att else None
        )
        self.attn_layer = (
            Attention(tau, T, p, attn_dropout) if self.attn_aft >= 0 else None
        )
        self.etas = nn.Parameter(tc.ones(d)) if self.use_eta else None
    def forward(self, g, X):
        r"""

        Description
        -----------
        Compute the forward pass of propagation & attention.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        X : torch.Tensor
            The initial features.

        Returns
        -------
        torch.Tensor
            The propagated feature.
        """
        Y = X
        g.edata["w"] = tc.ones(g.num_edges(), 1, device=g.device)
        g.ndata["deg"] = g.in_degrees().to(X)

        if self.init_att:
            g = self.init_attn(g, Y, self.etas)

        for k, layer in enumerate(self.prop_layers):
            # do unfolding
            Y = layer(g, Y, X, self.alp, self.lam)

            # do attention at a certain layer
            if k == self.attn_aft - 1:
                g = self.attn_layer(g, Y, self.etas)

        return Y
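# --- Sanity-check sketch (illustrative; not part of the original module) ---
# A hedged example of the unfolding behaviour: with lam = 0 and alp = 1, the update
# implemented in ``Propagate.forward`` reduces to returning the original feature, so
# the unfolding should leave its input unchanged. The helper name
# `_example_unfolding_identity` and the toy graph are hypothetical.
def _example_unfolding_identity():
    import dgl  # local import; the module itself only imports dgl.function

    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])).add_self_loop()
    X = tc.randn(6, 4)

    prop = TWIRLSUnfoldingAndAttention(d=4, alp=1, lam=0, prop_step=8)
    Y = prop(g, X)
    assert tc.allclose(Y, X, atol=1e-6)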
class MLP(nn.Module):
    r"""

    Description
    -----------
    An MLP module.

    Parameters
    ----------
    input_d : int
        Number of input features.
    hidden_d : int
        Size of hidden layers.
    output_d : int
        Number of output features.
    num_layers : int
        Number of mlp layers.
    dropout : float
        The dropout rate in mlp layers.
    norm : str
        The type of norm layers inside mlp layers. Can be ``'batch'``, ``'layer'`` or ``'none'``.
    init_activate : bool
        If ``True``, apply normalization, ReLU and dropout to the input before the first layer.
    """

    def __init__(
        self,
        input_d,
        hidden_d,
        output_d,
        num_layers,
        dropout,
        norm,
        init_activate,
    ):
        super().__init__()

        self.init_activate = init_activate
        self.norm = norm
        self.dropout = dropout

        self.layers = nn.ModuleList([])

        if num_layers == 1:
            self.layers.append(nn.Linear(input_d, output_d))
        elif num_layers > 1:
            self.layers.append(nn.Linear(input_d, hidden_d))
            for _ in range(num_layers - 2):
                self.layers.append(nn.Linear(hidden_d, hidden_d))
            self.layers.append(nn.Linear(hidden_d, output_d))

        # how many norm layers we have
        self.norm_cnt = num_layers - 1 + int(init_activate)
        if norm == "batch":
            self.norms = nn.ModuleList(
                [nn.BatchNorm1d(hidden_d) for _ in range(self.norm_cnt)]
            )
        elif norm == "layer":
            self.norms = nn.ModuleList(
                [nn.LayerNorm(hidden_d) for _ in range(self.norm_cnt)]
            )

        self.reset_params()

    def reset_params(self):
        """Reset MLP parameters with Xavier normal initialization."""
        for layer in self.layers:
            nn.init.xavier_normal_(layer.weight.data)
            nn.init.constant_(layer.bias.data, 0)

    def activate(self, x):
        """Apply normalization, activation and dropout."""
        if self.norm != "none":
            x = self.norms[self.cur_norm_idx](x)  # use the next norm layer in order
            self.cur_norm_idx += 1
        x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        return x

    def forward(self, x):
        """The forward pass of the mlp."""
        self.cur_norm_idx = 0

        if self.init_activate:
            x = self.activate(x)

        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i != len(self.layers) - 1:  # do not activate after the last layer
                x = self.activate(x)

        return x
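# --- Usage sketch (illustrative; not part of the original module) ---
# A hedged example of the MLP helper: three layers with layer norm, checking the
# output shape. The helper name `_example_mlp_usage` and the sizes are hypothetical.
# Note that the norm layers are sized ``hidden_d``, so ``init_activate=True`` with a
# norm only makes sense when the input size equals ``hidden_d`` (as in ``mlp_aft``).
def _example_mlp_usage():
    mlp = MLP(
        input_d=16,
        hidden_d=32,
        output_d=4,
        num_layers=3,
        dropout=0.1,
        norm="layer",
        init_activate=False,
    )
    x = tc.randn(8, 16)  # a batch of 8 feature vectors
    out = mlp(x)
    assert out.shape == (8, 4)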