GroupRevRes
class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)
Bases: torch.nn.modules.module.Module
Grouped reversible residual connections for GNNs, as introduced in Training Graph Neural Networks with 1000 Layers.
It uniformly partitions an input node feature \(X\) into \(C\) groups \(X_1, X_2, \cdots, X_C\) across the channel dimension. It also makes \(C\) copies of the input GNN module, \(f_{w1}, \cdots, f_{wC}\). In the forward pass, each GNN module takes only its corresponding group of node features.
The output node representations \(X^{'}\) are computed as follows.
\[
\begin{aligned}
X_0^{'} &= \sum_{i=2}^{C} X_i \\
X_i^{'} &= f_{wi}(X_{i-1}^{'}, g, U) + X_i, \quad i \in \{1, \cdots, C\} \\
X^{'} &= X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}
\end{aligned}
\]
where \(g\) is the input graph, \(U\) denotes arbitrary additional input arguments such as edge features, and \(\, \Vert \,\) is concatenation.
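The equations above can be sketched in plain PyTorch. This is only an illustration under stated assumptions, not DGL's implementation: the helper names `grouped_rev_forward` and `grouped_rev_inverse` are hypothetical, the graph \(g\) and extra arguments \(U\) are omitted, `nn.Linear` layers stand in for the per-group GNN modules, and at least two groups are assumed. The inverse function shows why the connection is called reversible: the inputs can be recomputed from the outputs by rearranging the same equations.

```python
import torch
import torch.nn as nn

def grouped_rev_forward(fns, x):
    """Forward equations; `fns` holds one callable per group (hypothetical helper)."""
    xs = torch.chunk(x, len(fns), dim=-1)    # X_1, ..., X_C
    y = sum(xs[1:])                          # X_0' = sum_{i=2}^{C} X_i
    ys = []
    for f, x_i in zip(fns, xs):
        y = f(y) + x_i                       # X_i' = f_i(X_{i-1}') + X_i
        ys.append(y)
    return torch.cat(ys, dim=-1)             # X' = X_1' || ... || X_C'

def grouped_rev_inverse(fns, y):
    """Recover X from X' by solving the forward equations for X_i."""
    ys = torch.chunk(y, len(fns), dim=-1)
    xs = [None] * len(fns)
    for i in range(len(fns) - 1, 0, -1):
        xs[i] = ys[i] - fns[i](ys[i - 1])    # X_i = X_i' - f_i(X_{i-1}'), i >= 2
    y0 = sum(xs[1:])                         # rebuild X_0' from the recovered groups
    xs[0] = ys[0] - fns[0](y0)               # X_1 = X_1' - f_1(X_0')
    return torch.cat(xs, dim=-1)

torch.manual_seed(0)
groups, feats = 4, 32
# Stand-ins for the per-group GNN modules; each maps feats // groups -> feats // groups.
fns = [nn.Linear(feats // groups, feats // groups) for _ in range(groups)]
x = torch.randn(5, feats)
with torch.no_grad():
    out = grouped_rev_forward(fns, x)
    x_rec = grouped_rev_inverse(fns, out)
```

Here `out` has the same shape as `x`, and `x_rec` matches `x` up to floating-point error.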
- Parameters
gnn_module (nn.Module) – GNN module for message passing. GroupRevRes will clone the module groups - 1 times, yielding groups copies in total. The input and output node representation sizes need to be the same. Its forward function needs to take a DGLGraph and the associated input node features, in that order, optionally followed by additional arguments like edge features.
groups (int, optional) – The number of groups. Default: 2.
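The cloning behaviour described for gnn_module can be pictured with the following sketch. It assumes the copies are made with `copy.deepcopy`, which the text above does not confirm, and `make_group_modules` is a hypothetical helper rather than part of DGL. An `nn.Linear` stands in for the GNN module, sized to `feats // groups` because each copy sees only one group of channels.

```python
import copy
import torch.nn as nn

def make_group_modules(gnn_module, groups=2):
    """Keep the given module and add groups - 1 clones (hypothetical helper)."""
    modules = nn.ModuleList([gnn_module])
    for _ in range(groups - 1):
        # Assumption: each clone is a deep copy, so it owns its own parameters.
        modules.append(copy.deepcopy(gnn_module))
    return modules

feats, groups = 32, 4
# Each copy operates on feats // groups channels, with equal input and output size.
base = nn.Linear(feats // groups, feats // groups)
copies = make_group_modules(base, groups)
num_params = sum(p.numel() for p in copies.parameters())
```

Under this assumption the parameter count grows linearly with the number of groups, since no weights are shared between copies.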
Examples
>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes

>>> class GNNLayer(nn.Module):
...     def __init__(self, feats, dropout=0.2):
...         super(GNNLayer, self).__init__()
...         # Use BatchNorm and dropout to prevent gradient vanishing
...         # In particular if you use a large number of GNN layers
...         self.norm = nn.BatchNorm1d(feats)
...         self.conv = GraphConv(feats, feats)
...         self.dropout = nn.Dropout(dropout)
...
...     def forward(self, g, x):
...         x = self.norm(x)
...         x = self.dropout(x)
...         return self.conv(g, x)

>>> num_nodes = 5
>>> num_edges = 20
>>> feats = 32
>>> groups = 2
>>> g = dgl.rand_graph(num_nodes, num_edges)
>>> x = torch.randn(num_nodes, feats)
>>> conv = GNNLayer(feats // groups)
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
forward(g, x, *args)
Apply the GNN module with grouped reversible residual connection.
- Parameters
g (DGLGraph) – The graph.
x (torch.Tensor) – The input feature of shape \((N, D_{in})\), where \(D_{in}\) is the size of the input feature and \(N\) is the number of nodes.
args – Additional arguments to pass to gnn_module.
- Returns
The output feature of shape \((N, D_{in})\).
- Return type
torch.Tensor