# GroupRevRes

class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)

Bases: torch.nn.modules.module.Module

Grouped reversible residual connections for GNNs, as introduced in *Training Graph Neural Networks with 1000 Layers*.

It uniformly partitions an input node feature $$X$$ into $$C$$ groups $$X_1, X_2, \cdots, X_C$$ along the channel dimension. It also makes $$C$$ copies of the input GNN module, $$f_{w1}, \cdots, f_{wC}$$. In the forward pass, each GNN module processes only one group's worth of channels: as the equations below show, $$f_{wi}$$ takes the previous group's output $$X_{i-1}^{'}$$, and its result is added to $$X_i$$.
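The partitioning and cloning steps can be sketched with standard PyTorch calls. This is an illustration under stated assumptions, not DGL's code: `nn.Linear` stands in for `gnn_module`, the sizes are arbitrary, and cloning with `copy.deepcopy` is one plausible way to obtain independent copies.

```python
import copy

import torch
import torch.nn as nn

# Hypothetical sizes for illustration: N nodes, D input channels, C groups.
N, D, C = 5, 32, 2
x = torch.randn(N, D)

# Uniformly partition X into C groups along the channel dimension (dim=1).
# D should be divisible by C so that every group has the same width.
groups = torch.chunk(x, C, dim=1)

# Stand-in for gnn_module (assumption): input and output sizes are both D // C.
base = nn.Linear(D // C, D // C)

# Clone the module C - 1 times, giving C copies f_w1, ..., f_wC.
# deepcopy gives each copy its own parameters (identical at the start).
modules = nn.ModuleList([base] + [copy.deepcopy(base) for _ in range(C - 1)])

print([tuple(t.shape) for t in groups])  # [(5, 16), (5, 16)]
print(len(modules))  # 2
```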

The output node representations $$X^{'}$$ are computed as follows.

$$
\begin{aligned}
X_0^{'} &= \sum_{i=2}^{C} X_i \\
X_i^{'} &= f_{wi}(X_{i-1}^{'}, g, U) + X_i, \quad i \in \{1, \cdots, C\} \\
X^{'} &= X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}
\end{aligned}
$$

where $$g$$ is the input graph, $$U$$ denotes arbitrary additional input arguments such as edge features, and $$\, \Vert \,$$ denotes concatenation.
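The recurrence can be sketched in plain PyTorch to make the data flow concrete. This is a minimal illustration, not DGL's implementation: the graph $$g$$ is replaced by a dense $$(N, N)$$ adjacency matrix `adj`, the additional arguments $$U$$ are omitted, and `DenseGCNLayer`, `group_rev_forward`, and `group_rev_inverse` are hypothetical names. The inverse shows why the connection is called reversible: given the outputs, $$X_C, \ldots, X_2$$ are recovered first, after which $$X_0^{'}$$ can be rebuilt and $$X_1$$ recovered.

```python
import torch
import torch.nn as nn


class DenseGCNLayer(nn.Module):
    """Hypothetical stand-in for gnn_module: mean aggregation over a dense
    adjacency matrix followed by a linear map. Input size == output size."""

    def __init__(self, feats):
        super().__init__()
        self.linear = nn.Linear(feats, feats)

    def forward(self, adj, x):
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)  # avoid division by 0
        return torch.tanh(self.linear(adj @ x / deg))


def group_rev_forward(fs, adj, x):
    # Split X into C groups X_1..X_C along channels (assumes C >= 2).
    xs = torch.chunk(x, len(fs), dim=1)
    y_prev = sum(xs[1:])  # X_0' = X_2 + ... + X_C
    ys = []
    for f, x_i in zip(fs, xs):
        y_prev = f(adj, y_prev) + x_i  # X_i' = f_wi(X_{i-1}', g) + X_i
        ys.append(y_prev)
    return torch.cat(ys, dim=1)  # X' = X_1' || ... || X_C'


def group_rev_inverse(fs, adj, y):
    ys = torch.chunk(y, len(fs), dim=1)
    xs = [None] * len(fs)
    # X_C, ..., X_2 depend only on outputs: X_i = X_i' - f_wi(X_{i-1}', g).
    for i in range(len(fs) - 1, 0, -1):
        xs[i] = ys[i] - fs[i](adj, ys[i - 1])
    # Rebuild X_0' from the recovered groups, then recover X_1.
    y0 = sum(xs[1:])
    xs[0] = ys[0] - fs[0](adj, y0)
    return torch.cat(xs, dim=1)


torch.manual_seed(0)
N, D, C = 5, 32, 2
adj = (torch.rand(N, N) < 0.5).float()
fs = nn.ModuleList([DenseGCNLayer(D // C) for _ in range(C)])
x = torch.randn(N, D)

with torch.no_grad():
    y = group_rev_forward(fs, adj, x)
    x_rec = group_rev_inverse(fs, adj, y)

print(y.shape)  # torch.Size([5, 32])
print(torch.allclose(x, x_rec, atol=1e-5))  # True
```

Recomputing inputs from outputs is what allows reversible networks to avoid storing intermediate activations during training; this sketch only checks the algebra and does not implement that memory-saving backward pass.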

Parameters
• gnn_module (nn.Module) – GNN module for message passing. GroupRevRes clones the module groups - 1 times, yielding groups copies in total. Its input and output node representation sizes must be the same; because each copy operates on a single group, these sizes should equal the input feature size divided by groups (as with feats // groups in the example below). Its forward function must take a DGLGraph and the associated input node features, in that order, optionally followed by additional arguments such as edge features.

• groups (int, optional) – The number of groups. Default: 2.

Examples

>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes

>>> class GNNLayer(nn.Module):
...     def __init__(self, feats, dropout=0.2):
...         super(GNNLayer, self).__init__()
...         # Use BatchNorm and dropout to prevent vanishing gradients,
...         # especially when stacking a large number of GNN layers
...         self.norm = nn.BatchNorm1d(feats)
...         self.conv = GraphConv(feats, feats)
...         self.dropout = nn.Dropout(dropout)
...
...     def forward(self, g, x):
...         x = self.norm(x)
...         x = self.dropout(x)
...         return self.conv(g, x)

>>> num_nodes = 5
>>> num_edges = 20
>>> feats = 32
>>> groups = 2
>>> g = dgl.rand_graph(num_nodes, num_edges)
>>> x = torch.randn(num_nodes, feats)
>>> conv = GNNLayer(feats // groups)
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
>>> out.shape
torch.Size([5, 32])

forward(g, x, *args)

Apply the GNN module with grouped reversible residual connections.

Parameters
• g (DGLGraph) – The graph.

• x (torch.Tensor) – The input feature of shape $$(N, D_{in})$$, where $$N$$ is the number of nodes and $$D_{in}$$ is the input feature size.

• args – Additional arguments to pass to gnn_module.

Returns

The output feature of shape $$(N, D_{in})$$.

Return type

torch.Tensor