# Source code for dgl.nn.pytorch.conv.grouprevres

"""Torch module for grouped reversible residual connections for GNNs"""
# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn


class InvertibleCheckpoint(torch.autograd.Function):
    r"""Memory-saving autograd function for invertible computations.

    The forward pass runs ``fn`` without recording an autograd graph and then
    releases the storage of ``inputs[1]`` (the input node features). The
    backward pass reconstructs that input from the saved output with
    ``fn_inverse``, re-runs ``fn`` with autograd enabled and differentiates
    the recomputed graph.

    Calling convention::

        InvertibleCheckpoint.apply(fn, fn_inverse, num_inputs, *inputs, *weights)

    ``inputs`` are the first ``num_inputs`` values forwarded to ``fn``;
    ``weights`` are the tensors ``fn`` uses that should receive gradients.
    """

    @staticmethod
    def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
        """Run ``fn`` without building a graph and free ``inputs[1]``'s storage.

        Parameters
        ----------
        ctx :
            Autograd context object.
        fn : callable
            Forward computation, called as ``fn(*inputs)``.
        fn_inverse : callable
            Inverse of ``fn`` with respect to ``inputs[1]``; called during the
            backward pass as ``fn_inverse(inputs[0], output, *inputs[2:])``.
        num_inputs : int
            Number of leading entries of ``inputs_and_weights`` that are
            inputs of ``fn``; the remaining entries are weights.
        *inputs_and_weights
            The inputs of ``fn`` followed by the weights to differentiate.

        Returns
        -------
        torch.Tensor
            The detached output of ``fn``.
        """
        ctx.fn = fn
        ctx.fn_inverse = fn_inverse
        ctx.weights = inputs_and_weights[num_inputs:]
        inputs = inputs_and_weights[:num_inputs]
        # One entry per input: the tensor's requires_grad flag, or None for
        # non-tensor inputs (e.g. the graph).
        ctx.input_requires_grad = []

        with torch.no_grad():
            # Make a detached copy, which shares the storage
            x = []
            for element in inputs:
                if isinstance(element, torch.Tensor):
                    x.append(element.detach())
                    ctx.input_requires_grad.append(element.requires_grad)
                else:
                    x.append(element)
                    ctx.input_requires_grad.append(None)
            # Detach the output, which then allows discarding the intermediary results
            outputs = ctx.fn(*x).detach_()

        # Clear memory of input node features; they are reconstructed from
        # ``outputs`` in the backward pass.
        inputs[1].storage().resize_(0)

        # Stored inside lists so that backward can pop them and detect a
        # second backward pass.
        ctx.inputs = [inputs]
        ctx.outputs = [outputs]

        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Reconstruct the freed input, recompute ``fn`` and return gradients.

        Returns
        -------
        tuple
            ``None`` for ``fn``, ``fn_inverse`` and ``num_inputs``, then one
            entry per input (``None`` where no gradient is required), then one
            entry per weight (``None`` for weights ``fn`` does not use).

        Raises
        ------
        RuntimeError
            If called through ``torch.autograd.grad`` or more than once.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "InvertibleCheckpoint is not compatible with .grad(), "
                "please use .backward() if possible"
            )
        # retrieve input and output tensor nodes
        if len(ctx.outputs) == 0:
            raise RuntimeError(
                "Trying to perform backward on the InvertibleCheckpoint "
                "for more than once."
            )
        inputs = ctx.inputs.pop()
        outputs = ctx.outputs.pop()

        # reconstruct input node features
        with torch.no_grad():
            # inputs[0] is DGLGraph and inputs[1] is input node features
            inputs_inverted = ctx.fn_inverse(
                *((inputs[0], outputs) + inputs[2:])
            )
            # clear memory of outputs
            outputs.storage().resize_(0)

            x = inputs[1]
            x.storage().resize_(int(np.prod(x.size())))
            x.set_(inputs_inverted)

        # Recompute the forward pass with autograd enabled.
        with torch.set_grad_enabled(True):
            detached_inputs = []
            for i, element in enumerate(inputs):
                if isinstance(element, torch.Tensor):
                    element = element.detach()
                    element.requires_grad = ctx.input_requires_grad[i]
                detached_inputs.append(element)

            detached_inputs = tuple(detached_inputs)
            temp_output = ctx.fn(*detached_inputs)

        # Select the inputs to differentiate from the recorded flags, so the
        # selection matches the redistribution loop below by construction.
        filtered_detached_inputs = tuple(
            element
            for element, rg in zip(detached_inputs, ctx.input_requires_grad)
            if rg
        )
        # allow_unused: a weight that fn does not touch gets a None gradient
        # (as in regular autograd) instead of raising.
        gradients = torch.autograd.grad(
            outputs=(temp_output,),
            inputs=filtered_detached_inputs + ctx.weights,
            grad_outputs=grad_outputs,
            allow_unused=True,
        )

        # Spread the input gradients back over all inputs, using None for
        # inputs that do not require grad (including non-tensors).
        input_gradients = []
        grad_index = 0
        for rg in ctx.input_requires_grad:
            if rg:
                input_gradients.append(gradients[grad_index])
                grad_index += 1
            else:
                input_gradients.append(None)

        # Weight gradients follow the input gradients. Slicing from the front
        # is required: with no weights, gradients[-len(ctx.weights):] would be
        # gradients[-0:], i.e. every gradient, and autograd would reject the
        # length of the returned tuple.
        weight_gradients = gradients[grad_index:]

        # fn, fn_inverse and num_inputs receive no gradient.
        return (None, None, None) + tuple(input_gradients) + weight_gradients


class GroupRevRes(nn.Module):
    r"""Grouped reversible residual connections for GNNs, as introduced in
    `Training Graph Neural Networks with 1000 Layers
    <https://arxiv.org/abs/2106.07476>`__

    It uniformly partitions an input node feature :math:`X` into :math:`C` groups
    :math:`X_1, X_2, \cdots, X_C` across the channel dimension. Besides, it makes
    :math:`C` copies of the input GNN module :math:`f_{w1}, \cdots, f_{wC}`. In the
    forward pass, each GNN module only takes the corresponding group of node features.

    The output node representations :math:`X^{'}` are computed as follows.

    .. math::

        X_0^{'} = \sum_{i=2}^{C}X_i

        X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\in\{1,\cdots,C\}

        X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}

    where :math:`g` is the input graph, :math:`U` is arbitrary additional input
    arguments like edge features, and :math:`\, \Vert \,` is concatenation.

    When gradients are tracked, the input node features are not kept in memory:
    their storage is released during the forward pass and reconstructed by
    inverting the computation above during the backward pass, so ``x`` must not
    be read between the forward and the backward pass. During the backward pass
    the GNN module copies are run again (once to invert, once to recompute
    gradients), so random layers such as dropout draw new random numbers there.
    Without gradient tracking (e.g. under :func:`torch.no_grad`) the computation
    runs directly and ``x`` is left untouched.

    Parameters
    ----------
    gnn_module : nn.Module
        GNN module for message passing. :attr:`GroupRevRes` will clone the module for
        :attr:`groups`-1 number of times, yielding :attr:`groups` copies in total.
        The input and output node representation size need to be the same. Its forward
        function needs to take a DGLGraph and the associated input node features in order,
        optionally followed by additional arguments like edge features.
    groups : int, optional
        The number of groups; must be at least 2. The last dimension of the node
        features must be divisible by it. Additional arguments are split along
        their last dimension into the same number of chunks.

    Examples
    --------
    >>> import dgl
    >>> import torch
    >>> import torch.nn as nn
    >>> from dgl.nn import GraphConv, GroupRevRes

    >>> class GNNLayer(nn.Module):
    ...     def __init__(self, feats, dropout=0.2):
    ...         super(GNNLayer, self).__init__()
    ...         # Use BatchNorm and dropout to prevent gradient vanishing
    ...         # In particular if you use a large number of GNN layers
    ...         self.norm = nn.BatchNorm1d(feats)
    ...         self.conv = GraphConv(feats, feats)
    ...         self.dropout = nn.Dropout(dropout)
    ...
    ...     def forward(self, g, x):
    ...         x = self.norm(x)
    ...         x = self.dropout(x)
    ...         return self.conv(g, x)

    >>> num_nodes = 5
    >>> num_edges = 20
    >>> feats = 32
    >>> groups = 2
    >>> g = dgl.rand_graph(num_nodes, num_edges)
    >>> x = torch.randn(num_nodes, feats)
    >>> conv = GNNLayer(feats // groups)
    >>> model = GroupRevRes(conv, groups)
    >>> out = model(g, x)
    """

    def __init__(self, gnn_module, groups=2):
        super().__init__()
        if groups < 2:
            # groups=1 would feed the module the integer 0 (an empty sum)
            # and fail later with an obscure error.
            raise ValueError(f"GroupRevRes needs at least 2 groups, got {groups}")
        # The first group uses gnn_module itself; every other group gets an
        # independent deep copy with its own parameters.
        self.gnn_modules = nn.ModuleList(
            [gnn_module] + [deepcopy(gnn_module) for _ in range(groups - 1)]
        )
        self.groups = groups

    def _chunk_args(self, args):
        """Split every extra argument along its last dimension, one chunk per group.

        Returns a list whose ``i``-th entry is a tuple holding the ``i``-th
        chunk of each argument (in argument order), or empty tuples when there
        are no extra arguments.
        """
        if len(args) == 0:
            return [()] * self.groups
        chunked_args = [torch.chunk(arg, self.groups, dim=-1) for arg in args]
        return list(zip(*chunked_args))

    def _forward(self, g, x, *args):
        """Compute :math:`X^{'}` from :math:`X` (see the class docstring)."""
        xs = torch.chunk(x, self.groups, dim=-1)
        args_chunks = self._chunk_args(args)
        # X_0' is the sum of every group except the first.
        y_in = sum(xs[1:])
        ys = []
        for i in range(self.groups):
            y_in = xs[i] + self.gnn_modules[i](g, y_in, *args_chunks[i])
            ys.append(y_in)
        return torch.cat(ys, dim=-1)

    def _inverse(self, g, y, *args):
        """Recover :math:`X` from :math:`X^{'}` by undoing ``_forward`` group by group."""
        ys = torch.chunk(y, self.groups, dim=-1)
        args_chunks = self._chunk_args(args)
        # Walk the groups backwards; xs collects X_C, ..., X_1 in that order.
        xs = []
        for i in range(self.groups - 1, -1, -1):
            if i != 0:
                y_in = ys[i - 1]
            else:
                # X_0' = X_2 + ... + X_C, all of which are recovered by now.
                y_in = sum(xs)
            xs.append(ys[i] - self.gnn_modules[i](g, y_in, *args_chunks[i]))
        return torch.cat(xs[::-1], dim=-1)

    def forward(self, g, x, *args):
        r"""Apply the GNN module with grouped reversible residual connection.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        x : torch.Tensor
            The input feature of shape :math:`(N, D_{in})`, where :math:`D_{in}`
            is size of input feature, :math:`N` is the number of nodes.
        args
            Additional arguments to pass to :attr:`gnn_module`.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{in})`.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` is not divisible by :attr:`groups`.
        """
        if x.size(-1) % self.groups != 0:
            raise ValueError(
                f"The last dimension of x ({x.size(-1)}) must be divisible "
                f"by groups ({self.groups})"
            )
        inputs = (g, x) + args
        params = tuple(p for p in self.parameters() if p.requires_grad)
        tracks_grad = torch.is_grad_enabled() and (
            len(params) > 0
            or any(
                isinstance(value, torch.Tensor) and value.requires_grad
                for value in inputs
            )
        )
        if not tracks_grad:
            # The output cannot require grad, so no backward pass would ever
            # restore x; skip the checkpoint, which frees x's storage.
            return self._forward(g, x, *args)
        return InvertibleCheckpoint.apply(
            self._forward, self._inverse, len(inputs), *(inputs + params)
        )