SortPooling

class dgl.nn.pytorch.glob.SortPooling(k)

Bases: torch.nn.modules.module.Module

Sort Pooling from the paper "An End-to-End Deep Learning Architecture for Graph Classification".

It first sorts each node's features in ascending order along the feature dimension. It then selects the top-k nodes, ranked in descending order by each node's largest feature value, and concatenates their sorted features.
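The two steps can be sketched in plain Python for a single graph. This is an illustrative sketch, not DGL's implementation: the helper name `sort_pool_single` is hypothetical, and zero-padding graphs that have fewer than k nodes is an assumption made here for completeness. Applied to the features of `g1` from the example below with k=2, it reproduces the Case 1 output.

```python
def sort_pool_single(feats, k):
    """Sort pooling for one graph (illustrative sketch).

    feats: non-empty list of N rows, each a list of D floats (one row per node).
    Returns a flat list of length k * D.
    """
    # Step 1: sort each node's features in ascending order.
    sorted_rows = [sorted(row) for row in feats]
    # Step 2: rank nodes by their largest value (the last entry after sorting),
    # highest first, and keep the top k.
    ranked = sorted(sorted_rows, key=lambda row: row[-1], reverse=True)[:k]
    # Assumption of this sketch: pad with all-zero rows if there are fewer than k nodes.
    dim = len(feats[0])
    ranked += [[0.0] * dim for _ in range(k - len(ranked))]
    # Concatenate the k selected rows into one vector of length k * D.
    return [x for row in ranked for x in row]


g1_node_feats = [
    [0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
    [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
    [0.5197, 0.9030, 0.6825, 0.5725, 0.4755],
]
print(sort_pool_single(g1_node_feats, k=2))
# [0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825, 0.903]
```

Ranking on the last entry of each sorted row works because, after the ascending sort, that entry is the node's largest feature value.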

Parameters

k (int) – The number of nodes to keep for each graph.

Notes

Input: a single graph or a batch of graphs. For a batch, the nodes of all graphs must have the same feature size, and their features must be concatenated into one tensor, in the same order as the graphs in the batch, to form the input.
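The batching requirement can be illustrated with a plain-Python sketch that splits the concatenated feature rows by per-graph node counts (in DGL, a batched graph reports these through `batch_num_nodes()`) and pools each chunk on its own. The helpers `split_by_counts`, `sort_pool_single`, and `sort_pool_batch` are hypothetical, and zero-padding for graphs with fewer than k nodes is an assumption of this sketch. With the `g1` and `g2` features from the example below and k=2, it yields the two rows shown in Case 2.

```python
def sort_pool_single(feats, k):
    # Sort each node's features ascending, then keep the k nodes with the
    # largest maximum value (highest first); zero-pad if fewer than k nodes.
    ranked = sorted((sorted(row) for row in feats), key=lambda r: r[-1], reverse=True)[:k]
    ranked += [[0.0] * len(feats[0]) for _ in range(k - len(ranked))]
    return [x for row in ranked for x in row]


def split_by_counts(rows, counts):
    """Split concatenated node rows into per-graph chunks, in batch order."""
    assert sum(counts) == len(rows), "features must cover every node exactly once"
    chunks, start = [], 0
    for n in counts:
        chunks.append(rows[start:start + n])
        start += n
    return chunks


def sort_pool_batch(rows, counts, k):
    # One output row of length k * D per graph, i.e. shape (B, k * D).
    return [sort_pool_single(chunk, k) for chunk in split_by_counts(rows, counts)]


g1_node_feats = [
    [0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
    [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
    [0.5197, 0.9030, 0.6825, 0.5725, 0.4755],
]
g2_node_feats = [
    [0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
    [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
    [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
    [0.0142, 0.2709, 0.3330, 0.8521, 0.6925],
]
# Concatenate in batch order, as th.cat does along dim 0.
batch_f = g1_node_feats + g2_node_feats
out = sort_pool_batch(batch_f, counts=[3, 4], k=2)
print(len(out), len(out[0]))  # 2 10
```

Because the split relies only on node counts, concatenating the features in a different order than the graphs were batched would silently assign nodes to the wrong graphs.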

Examples

>>> import dgl
>>> import torch as th
>>> from dgl.nn import SortPooling
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> sortpool = SortPooling(k=2)  # create a sort pooling layer

Case 1: Input a single graph

>>> sortpool(g1, g1_node_feats)
tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,
         0.9030]])

Case 2: Input a batch of graphs

Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.

>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats])
>>>
>>> sortpool(batch_g, batch_f)
tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,
         0.9030],
        [0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658,
         0.9028]])

forward(graph, feat)

Compute sort pooling.

Parameters
  • graph (DGLGraph) – A DGLGraph or a batch of DGLGraphs.

  • feat (torch.Tensor) – The input node features with shape \((N, D)\), where \(N\) is the number of nodes in the graph (the total across all graphs for a batch) and \(D\) is the feature size.

Returns

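The output features with shape \((B, k * D)\), where \(B\) is the number of graphs in the input (1 for a single graph). Row \(i\) concatenates the sorted features of the top-k nodes of the \(i\)-th graph.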
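A returned row can be read as k consecutive blocks of D values, one block per selected node. The plain-Python sketch below (the helper name `unflatten_row` is hypothetical) does this for the second row of the Case 2 output, where k = 2 and D = 5, and checks the two ordering properties described above: each block is sorted in ascending order, and the blocks are ordered by their largest value, highest first.

```python
def unflatten_row(row, k, d):
    """View one output row of length k * d as k blocks of d values each."""
    assert len(row) == k * d, "row length must equal k * D"
    return [row[i * d:(i + 1) * d] for i in range(k)]


# Second row of the Case 2 output (the pooled features of g2), with k = 2, D = 5.
row = [0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658, 0.9028]
blocks = unflatten_row(row, k=2, d=5)

# Each block holds one node's features sorted in ascending order ...
assert all(block == sorted(block) for block in blocks)
# ... and blocks are ranked by their largest value, highest first.
maxima = [block[-1] for block in blocks]
assert maxima == sorted(maxima, reverse=True)
print(blocks)
```

On the returned tensor itself, the same per-node view is available with `out.reshape(B, k, D)`.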

Return type

torch.Tensor