class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)

Bases: torch.nn.modules.module.Module

Grouped reversible residual connections for GNNs, as introduced in "Training Graph Neural Networks with 1000 Layers".

It uniformly partitions an input node feature \(X\) into \(C\) groups \(X_1, X_2, \cdots, X_C\) along the channel dimension. It also makes \(C\) copies of the input GNN module, \(f_{w1}, \cdots, f_{wC}\). In the forward pass, each GNN module takes only its corresponding group of node features.
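The partitioning and copying steps can be sketched in plain PyTorch. This is a minimal illustration, assuming `torch.chunk` for the channel split and `copy.deepcopy` for the module copies; `make_groups` is a hypothetical helper, not part of DGL, and DGL's own implementation may differ.

```python
import copy

import torch
import torch.nn as nn


def make_groups(gnn_module: nn.Module, x: torch.Tensor, groups: int):
    """Hypothetical helper (not a DGL API): split features, clone the module."""
    # A uniform split requires the channel dimension to be divisible by C.
    if x.shape[-1] % groups != 0:
        raise ValueError("feature size must be divisible by groups")
    # X_1, ..., X_C: equal-width slices along the channel (last) dimension.
    xs = torch.chunk(x, groups, dim=-1)
    # f_{w1}, ..., f_{wC}: the original module plus groups - 1 deep copies,
    # so every group owns an independent set of parameters.
    modules = nn.ModuleList(
        [gnn_module] + [copy.deepcopy(gnn_module) for _ in range(groups - 1)]
    )
    return xs, modules


x = torch.randn(5, 32)  # N = 5 nodes, 32 channels
xs, modules = make_groups(nn.Linear(16, 16), x, groups=2)
print([tuple(t.shape) for t in xs])  # [(5, 16), (5, 16)]
print(len(modules))  # 2
```

Deep copies start from the same weights as the original module; this sketch does not re-initialize them.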

The output node representations \(X^{'}\) are computed as follows.

\[ \begin{align}\begin{aligned}X_0^{'} = \sum_{i=2}^{C}X_i\\X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\in\{1,\cdots,C\}\\X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}\end{aligned}\end{align} \]

where \(g\) is the input graph, \(U\) denotes any additional input arguments such as edge features, and \(\, \Vert \,\) denotes concatenation.
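Because each \(X_i^{'}\) depends only on \(X_{i-1}^{'}\) and \(X_i\), the map can be inverted group by group: \(X_i = X_i^{'} - f_{wi}(X_{i-1}^{'}, g, U)\) for \(i = C, \ldots, 2\), after which \(X_0^{'}\) is recomputed from the recovered \(X_2, \ldots, X_C\) and \(X_1 = X_1^{'} - f_{w1}(X_0^{'}, g, U)\). This invertibility is the property that lets reversible networks recompute inputs during backpropagation instead of storing them. The sketch below checks the round trip in plain PyTorch; the dense adjacency, the per-group function `f`, and the helpers `rev_forward`/`rev_inverse` are illustrative assumptions rather than DGL APIs, and \(U\) is omitted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, D, C = 6, 12, 3  # nodes, channels, groups
d = D // C          # width of each group

# Assumption: a dense row-normalized adjacency stands in for the graph g.
adj = (torch.rand(N, N) < 0.5).float() + torch.eye(N)
adj = adj / adj.sum(dim=1, keepdim=True)

# Hypothetical per-group GNN functions: aggregate neighbors, then a
# size-preserving linear map followed by tanh.
layers = nn.ModuleList([nn.Linear(d, d) for _ in range(C)])


def f(i, h):
    # Python index i corresponds to group i + 1 in the formulas.
    return torch.tanh(layers[i](adj @ h))


def rev_forward(x):
    xs = torch.chunk(x, C, dim=-1)
    y = sum(xs[1:])          # X'_0 = X_2 + ... + X_C
    ys = []
    for i in range(C):
        y = f(i, y) + xs[i]  # X'_i = f_{wi}(X'_{i-1}) + X_i
        ys.append(y)
    return torch.cat(ys, dim=-1)  # X' = X'_1 || ... || X'_C


def rev_inverse(out):
    ys = torch.chunk(out, C, dim=-1)
    xs = [None] * C
    # Undo groups C, ..., 2: X_i = X'_i - f_{wi}(X'_{i-1}).
    for i in range(C - 1, 0, -1):
        xs[i] = ys[i] - f(i, ys[i - 1])
    # X'_0 is recomputed from the recovered X_2, ..., X_C.
    y0 = sum(xs[1:])
    xs[0] = ys[0] - f(0, y0)
    return torch.cat(xs, dim=-1)


x = torch.randn(N, D)
with torch.no_grad():
    out = rev_forward(x)
    x_rec = rev_inverse(out)
print(torch.allclose(x, x_rec, atol=1e-5))  # True
```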

Parameters

  • gnn_module (nn.Module) – GNN module for message passing. GroupRevRes clones the module groups - 1 times, yielding groups copies in total. Its input and output node representation sizes must be equal; since each copy receives one group, both equal the input feature size divided by groups. Its forward function must take a DGLGraph and the associated input node features, in that order, optionally followed by additional arguments such as edge features.

  • groups (int, optional) – The number of groups \(C\). Default: 2.
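The requirements on gnn_module can be illustrated with a minimal layer. In this sketch, `DenseGNNLayer` is hypothetical, a dense adjacency matrix stands in for the DGLGraph, and `edge_weight` plays the role of an optional extra argument; the point is only the calling convention (graph, then features, then extras) and that the output width equals the input width, with each copy built for the input feature size divided by groups.

```python
import torch
import torch.nn as nn


class DenseGNNLayer(nn.Module):
    """Hypothetical size-preserving layer; a dense adjacency stands in for a DGLGraph."""

    def __init__(self, feats):
        super().__init__()
        self.linear = nn.Linear(feats, feats)  # output size == input size

    def forward(self, adj, x, edge_weight=None):
        # Graph first, node features second, optional extra arguments after.
        if edge_weight is not None:
            adj = adj * edge_weight  # hypothetical per-edge scaling
        return self.linear(adj @ x)


in_feats, groups = 32, 2
assert in_feats % groups == 0  # features must split evenly into groups
layer = DenseGNNLayer(in_feats // groups)  # each copy sees one group

adj = torch.ones(5, 5)
x_group = torch.randn(5, in_feats // groups)
out = layer(adj, x_group, edge_weight=torch.rand(5, 5))
print(out.shape)  # torch.Size([5, 16])
```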

Example

>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes
>>> class GNNLayer(nn.Module):
...     def __init__(self, feats, dropout=0.2):
...         super().__init__()
...         # Use BatchNorm and dropout to prevent vanishing gradients,
...         # in particular when stacking a large number of GNN layers
...         self.norm = nn.BatchNorm1d(feats)
...         self.conv = GraphConv(feats, feats)
...         self.dropout = nn.Dropout(dropout)
...     def forward(self, g, x):
...         x = self.norm(x)
...         x = self.dropout(x)
...         return self.conv(g, x)
>>> num_nodes = 5
>>> num_edges = 20
>>> feats = 32
>>> groups = 2
>>> g = dgl.rand_graph(num_nodes, num_edges)
>>> x = torch.randn(num_nodes, feats)
>>> conv = GNNLayer(feats // groups)
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
>>> out.shape
torch.Size([5, 32])
forward(g, x, *args)

Apply the GNN module with grouped reversible residual connection.

Parameters

  • g (DGLGraph) – The graph.

  • x (torch.Tensor) – The input feature of shape \((N, D_{in})\), where \(N\) is the number of nodes and \(D_{in}\) is the input feature size, which should be divisible by groups.

  • args – Additional arguments to pass to gnn_module.

Returns

The output feature of shape \((N, D_{in})\).

Return type

torch.Tensor