Sequential

class dgl.nn.pytorch.utils.Sequential(*args)

Bases: torch.nn.modules.container.Sequential

A sequential container for stacking graph neural network (GNN) modules.

Two modes are supported. The GNN modules are applied sequentially either 1) on the same graph, or 2) on a list of graphs, with one graph per module. In the second case, the number of graphs should equal the number of modules in the container.
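The pairing of graphs with modules in the two modes can be sketched in plain Python. This is an illustrative sketch, not DGL's implementation. `graphs_for_modules` is a hypothetical helper, strings stand in for `DGLGraph` objects, and raising an error on a length mismatch is this sketch's own assumption.

```python
from typing import Any, List


def graphs_for_modules(graph: Any, num_modules: int) -> List[Any]:
    """Return the graph each module receives (hypothetical helper, not DGL API)."""
    if isinstance(graph, list):
        # Mode 2: one graph per module. Raising on a length mismatch is
        # an assumption of this sketch.
        if len(graph) != num_modules:
            raise ValueError(f"expected {num_modules} graphs, got {len(graph)}")
        return list(graph)
    # Mode 1: every module reuses the same graph.
    return [graph] * num_modules


# Strings stand in for DGLGraph objects.
print(graphs_for_modules("g", 3))                 # ['g', 'g', 'g']
print(graphs_for_modules(["g1", "g2", "g3"], 3))  # ['g1', 'g2', 'g3']
```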

Parameters

*args – Sub-modules of type torch.nn.Module, added to the container in the order in which they are passed to the constructor.

Examples

The following examples use the PyTorch backend.

Mode 1: sequentially apply GNN modules on the same graph

>>> import torch
>>> import dgl
>>> import torch.nn as nn
>>> import dgl.function as fn
>>> from dgl.nn.pytorch import Sequential
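>>> class ExampleLayer(nn.Module):
...     def __init__(self):
...         super().__init__()
...     def forward(self, graph, n_feat, e_feat):
...         # Features set inside local_scope() are discarded when the block exits
...         with graph.local_scope():
...             graph.ndata['h'] = n_feat
...             # Each node sums the features of its in-neighbors
...             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             # Out-of-place add, so the caller's tensor is not modified
...             n_feat = n_feat + graph.ndata['h']
...             # Edge feature = sum of the aggregated features at both endpoints
...             graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
...             e_feat = e_feat + graph.edata['e']
...             # Returning a tuple passes two features to the next module
...             return n_feat, e_feat
...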
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
>>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
>>> n_feat = torch.rand(3, 4)
>>> e_feat = torch.rand(9, 4)
>>> net(g, n_feat, e_feat)
(tensor([[39.8597, 45.4542, 25.1877, 30.8086],
         [40.7095, 45.3985, 25.4590, 30.0134],
         [40.7894, 45.2556, 25.5221, 30.4220]]),
 tensor([[80.3772, 89.7752, 50.7762, 60.5520],
         [80.5671, 89.3736, 50.6558, 60.6418],
         [80.4620, 89.5142, 50.3643, 60.3126],
         [80.4817, 89.8549, 50.9430, 59.9108],
         [80.2284, 89.6954, 50.0448, 60.1139],
         [79.7846, 89.6882, 50.5097, 60.6213],
         [80.2654, 90.2330, 50.2787, 60.6937],
         [80.3468, 90.0341, 50.2062, 60.2659],
         [80.0556, 90.2789, 50.2882, 60.5845]]))
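The large magnitudes in this output can be traced by hand, assuming the ExampleLayer code shown above. In this graph every node has an in-edge from each of the three nodes (itself included), so after update_all every node holds the same column-wise sum. Writing S^(k) for the column-wise sum of the node features entering layer k (k = 0, 1, 2), a short derivation gives:

```latex
\begin{aligned}
S^{(k)} &= \textstyle\sum_{u} n_u^{(k)}, \qquad h_v^{(k)} = S^{(k)} \quad \text{for every node } v\\
n_v^{(k+1)} &= n_v^{(k)} + S^{(k)}, \qquad S^{(k+1)} = S^{(k)} + 3S^{(k)} = 4S^{(k)}\\
e_{uv}^{(k+1)} &= e_{uv}^{(k)} + h_u^{(k)} + h_v^{(k)} = e_{uv}^{(k)} + 2S^{(k)}\\
n_v^{(3)} &= n_v^{(0)} + (1 + 4 + 16)\,S^{(0)} = n_v^{(0)} + 21\,S^{(0)}\\
e_{uv}^{(3)} &= e_{uv}^{(0)} + 2\,(1 + 4 + 16)\,S^{(0)} = e_{uv}^{(0)} + 42\,S^{(0)}
\end{aligned}
```

Since torch.rand draws from [0, 1), every row of the node output lies within 1 of 21 S^(0) and every row of the edge output lies within 1 of 42 S^(0), which matches the printed tensors.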

Mode 2: sequentially apply GNN modules on different graphs

>>> import torch
>>> import dgl
>>> import torch.nn as nn
>>> import dgl.function as fn
>>> import networkx as nx
>>> from dgl.nn.pytorch import Sequential
>>> class ExampleLayer(nn.Module):
...     def __init__(self):
...         super().__init__()
...     def forward(self, graph, n_feat):
...         with graph.local_scope():
...             graph.ndata['h'] = n_feat
...             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             n_feat = n_feat + graph.ndata['h']
...             # Sum consecutive pairs of node rows, halving the node count
...             # so the output fits the next (half-sized) graph
...             return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
...
>>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
>>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
>>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
>>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
>>> n_feat = torch.rand(32, 4)
>>> net([g1, g2, g3], n_feat)
tensor([[209.6221, 225.5312, 193.8920, 220.1002],
        [250.0169, 271.9156, 240.2467, 267.7766],
        [220.4007, 239.7365, 213.8648, 234.9637],
        [196.4630, 207.6319, 184.2927, 208.7465]])
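Each ExampleLayer in this example ends with a reshape-and-sum that adds node rows 2i and 2i+1, halving the node count. The features must match the node count of the graph each layer receives, so g1, g2 and g3 have 32, 16 and 8 nodes and the final output has 4 rows. The shape bookkeeping alone can be sketched with plain Python lists. `pool_pairs` is a hypothetical helper that skips message passing and only mirrors the final `view(...).sum(1)` step.

```python
from typing import List


def pool_pairs(rows: List[List[float]]) -> List[List[float]]:
    """Sum rows 2i and 2i+1, like n_feat.view(N // 2, 2, -1).sum(1) (plain-list sketch)."""
    if len(rows) % 2 != 0:
        # Rows are paired, so the node count must be even
        raise ValueError("the number of rows must be even")
    return [[a + b for a, b in zip(rows[i], rows[i + 1])]
            for i in range(0, len(rows), 2)]


feats = [[1.0] * 4 for _ in range(32)]  # 32 nodes, 4 features each
sizes = [len(feats)]
for _ in range(3):  # one pooling step per ExampleLayer
    feats = pool_pairs(feats)
    sizes.append(len(feats))
print(sizes)     # [32, 16, 8, 4]
print(feats[0])  # [8.0, 8.0, 8.0, 8.0]
```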

forward(graph, *feats)

Sequentially apply modules to the input.

Parameters
  • graph (DGLGraph or list of DGLGraphs) – The graph(s) to apply the modules on. If a list is given, the i-th module is applied on the i-th graph.

  • *feats – Input features. The output of the i-th module should match the input of the (i+1)-th module in the container. If a module returns a tuple, its elements are passed to the next module as separate feature arguments.
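How features move from one module to the next can be sketched in plain Python. In this hypothetical `run_sequential`, a dict stands in for a `DGLGraph` and plain functions stand in for modules. The tuple handling is an assumption based on the Mode 1 example, whose ExampleLayer returns `n_feat, e_feat` and accepts both as the next layer's input: a tuple result is unpacked into the next module's feature arguments, and any other result is passed on as a single argument.

```python
from typing import Any, Callable, Sequence


def run_sequential(graph: Any, modules: Sequence[Callable[..., Any]], *feats: Any) -> Any:
    """Thread features through modules (hypothetical sketch, not DGL's code)."""
    for module in modules:
        # A non-tuple result becomes a single feature argument.
        if not isinstance(feats, tuple):
            feats = (feats,)
        feats = module(graph, *feats)
    # The last module's return value is returned unchanged.
    return feats


def add_degree(graph, n, e):
    # Returns a tuple, so the next module receives two features.
    return n + graph["deg"], e + 1


def combine(graph, n, e):
    # Returns a single value, the final output here.
    return n * e


g = {"deg": 2}  # a dict stands in for a DGLGraph
print(run_sequential(g, [add_degree, add_degree, combine], 1, 10))  # 60
```

Because the sketch returns the last module's result unchanged, a container whose last module returns a tuple also returns a tuple, as in the Mode 1 example.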