SumPooling

class dgl.nn.pytorch.glob.SumPooling

Bases: torch.nn.modules.module.Module

Apply sum pooling over the nodes in a graph.

\[r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k\]

Notes

Input can be a single graph or a batch of graphs. When using a batch of graphs, make sure that nodes in all graphs have the same feature size, and concatenate the node features of all graphs into one tensor as the input.
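For a single graph, the formula above is simply a row-wise sum of the node feature matrix. A minimal plain-PyTorch sketch of that equivalence (the shapes here are illustrative):

>>> import torch as th
>>> feat = th.rand(3, 5)  # one graph with 3 nodes and feature size 5
>>> feat.sum(dim=0, keepdim=True).shape  # same result as applying SumPooling to that graph
torch.Size([1, 5])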

Examples

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch as th
>>> from dgl.nn import SumPooling
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> sumpool = SumPooling()  # create a sum pooling layer

Case 1: Input a single graph

>>> sumpool(g1, g1_node_feats)
tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511]])

Case 2: Input a batch of graphs

Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.

>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats])
>>>
>>> sumpool(batch_g, batch_f)
tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511],
        [1.0608, 1.2080, 2.1780, 2.7849, 2.5420]])
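As a sanity check (a sketch that assumes batch_g and batch_f were built as above), the batched output can be reproduced by splitting the concatenated features back into per-graph chunks and summing each chunk:

>>> sizes = batch_g.batch_num_nodes().tolist()  # [3, 4], the node counts of g1 and g2
>>> th.stack([chunk.sum(0) for chunk in th.split(batch_f, sizes)])  # matches sumpool(batch_g, batch_f)
tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511],
        [1.0608, 1.2080, 2.1780, 2.7849, 2.5420]])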
forward(graph, feat)

Compute sum pooling.

Parameters
  • graph (DGLGraph) – A DGLGraph or a batch of DGLGraphs.

  • feat (torch.Tensor) – The input feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph and \(D\) is the feature size.

Returns

The output feature with shape \((B, D)\), where \(B\) refers to the batch size of input graphs.

Return type

torch.Tensor
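
In practice, SumPooling is typically used as the readout step of a graph-level model. A minimal sketch of such a model (the GraphConv layer, hidden size, and class count are illustrative choices, not part of this class's API; allow_zero_in_degree=True is set because random graphs may contain nodes without incoming edges):

>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, SumPooling
>>>
>>> class Classifier(nn.Module):
...     def __init__(self, in_feats, hidden_feats, n_classes):
...         super().__init__()
...         self.conv = GraphConv(in_feats, hidden_feats, allow_zero_in_degree=True)
...         self.pool = SumPooling()                  # graph-level readout
...         self.fc = nn.Linear(hidden_feats, n_classes)
...     def forward(self, g, feat):
...         h = th.relu(self.conv(g, feat))           # (N, hidden_feats) node embeddings
...         hg = self.pool(g, h)                      # (B, hidden_feats) graph embeddings
...         return self.fc(hg)                        # (B, n_classes) logits
...
>>> model = Classifier(5, 16, 2)
>>> model(batch_g, batch_f).shape
torch.Size([2, 2])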