# SortPooling

class dgl.nn.pytorch.glob.SortPooling(k)

Bases: torch.nn.modules.module.Module

Sort Pooling from "An End-to-End Deep Learning Architecture for Graph Classification"

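It first sorts each node's features in ascending order along the feature dimension, then keeps the sorted features of the top-k nodes, ranked by each node's largest feature value. The kept features are concatenated into one vector of size k * D per graph, where D is the feature size.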
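The procedure above can be sketched without DGL or PyTorch. The snippet below is only an illustration of the described steps, not the library's implementation: `sort_pool_reference` is a hypothetical helper name, it works on plain Python lists, and it assumes the graph has at least k nodes. The input values are those of `g1_node_feats` from the Examples section.

```python
def sort_pool_reference(feats, k):
    """Illustrative sort pooling for a single graph (hypothetical helper).

    feats: list of per-node feature lists, i.e. shape (N, D).
    k: number of nodes to keep. This sketch assumes N >= k.
    """
    if len(feats) < k:
        raise ValueError("this sketch assumes the graph has at least k nodes")
    # Step 1: sort each node's features in ascending order.
    sorted_feats = [sorted(row) for row in feats]
    # Step 2: rank nodes by their largest value, which is the last
    # entry of each row after the ascending sort.
    ranked = sorted(sorted_feats, key=lambda row: row[-1], reverse=True)
    # Step 3: keep the top-k nodes and flatten them into one k * D vector.
    return [value for row in ranked[:k] for value in row]


g1_node_feats = [
    [0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
    [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
    [0.5197, 0.9030, 0.6825, 0.5725, 0.4755],
]
pooled = sort_pool_reference(g1_node_feats, k=2)
print(pooled)
```

With k=2, the nodes whose largest values are 0.9137 and 0.9030 are kept, in that order, which matches the Case 1 output in the Examples section.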

Parameters

k (int) – The number of nodes to hold for each graph.

Notes

Input: Either a single graph or a batch of graphs. For a batch, make sure the nodes in all graphs have the same feature size, and concatenate the node features of all graphs into one tensor as the input.

Examples

>>> import dgl
>>> import torch as th
>>> from dgl.nn import SortPooling
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> sortpool = SortPooling(k=2)  # create a sort pooling layer


Case 1: Input a single graph

>>> sortpool(g1, g1_node_feats)
tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,
         0.9030]])


Case 2: Input a batch of graphs

Build a batch of DGL graphs and concatenate the node features of all graphs into one tensor.

>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats])
>>>
>>> sortpool(batch_g, batch_f)
tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,
         0.9030],
        [0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658,
         0.9028]])

forward(graph, feat)

Compute sort pooling.

Parameters
• graph (DGLGraph) – A DGLGraph or a batch of DGLGraphs.

• feat (torch.Tensor) – The input node features with shape $$(N, D)$$, where $$N$$ is the number of nodes in the graph (the total over all graphs when the input is a batch) and $$D$$ is the feature size.

Returns

The output features with shape $$(B, k * D)$$, where $$B$$ is the batch size of the input graphs.

Return type

torch.Tensor
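To make the $$(N, D)$$ to $$(B, k * D)$$ shape relationship concrete, here is a dependency-free sketch that pools a concatenated feature list graph by graph. It is an illustration under stated assumptions, not the library's code: `pool_batch` and its `nodes_per_graph` argument are hypothetical names, the features are plain Python lists, and every graph is assumed to have at least k nodes. The values are those of `g1_node_feats` and `g2_node_feats` from the Examples section.

```python
def pool_batch(feats, nodes_per_graph, k):
    """Illustrative batched sort pooling (hypothetical helper).

    feats: concatenated node features of all graphs, shape (N, D).
    nodes_per_graph: node count of each graph, summing to N.
    Returns B rows, each holding k * D values.
    Assumes every graph has at least k nodes.
    """
    out, start = [], 0
    for n in nodes_per_graph:
        # Slice out the rows that belong to the current graph.
        segment = feats[start:start + n]
        start += n
        # Sort each node's features, then rank nodes by their largest value.
        rows = sorted((sorted(r) for r in segment),
                      key=lambda r: r[-1], reverse=True)
        # Keep the top-k nodes and flatten them into one row.
        out.append([v for r in rows[:k] for v in r])
    return out


batch_f = [
    # g1: 3 nodes
    [0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
    [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
    [0.5197, 0.9030, 0.6825, 0.5725, 0.4755],
    # g2: 4 nodes
    [0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
    [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
    [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
    [0.0142, 0.2709, 0.3330, 0.8521, 0.6925],
]
out = pool_batch(batch_f, nodes_per_graph=[3, 4], k=2)
print(len(out), len(out[0]))  # B and k * D
```

Here N = 7, D = 5, B = 2 and k = 2, so the result has 2 rows of 10 values each, and the rows agree with the Case 2 output above.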