SortPooling
class dgl.nn.mxnet.glob.SortPooling(k)
Bases: mxnet.gluon.block.Block
Pooling layer from "An End-to-End Deep Learning Architecture for Graph Classification".
- Parameters
k (int) – The number of nodes to keep for each graph.
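To make the role of k concrete, here is a minimal pure-Python sketch of sort pooling on a single graph. It makes several assumptions. Nodes are ranked by their last feature channel in descending order, with ties broken by the preceding channels, which is the rule described in the cited paper. Graphs with fewer than k nodes are zero-padded, because the fixed output width needs some form of padding. Plain lists stand in for `mxnet.NDArray`. `sort_pool_single` is a hypothetical helper, not part of DGL, and DGL's own implementation may differ in details.

```python
from typing import List

Row = List[float]


def sort_pool_single(feats: List[Row], k: int) -> List[float]:
    """Hypothetical helper: sort-pool one graph's (N, D) features to length k * D.

    Assumption: nodes are ranked by their last feature channel in descending
    order, ties broken by earlier channels (the paper's rule), and graphs with
    fewer than k nodes are padded with zero rows. Not DGL's actual code.
    """
    if not feats:
        raise ValueError("cannot infer D from a graph with no nodes")
    d = len(feats[0])
    # Comparing reversed rows compares the last channel first.
    ranked = sorted(feats, key=lambda row: tuple(reversed(row)), reverse=True)
    kept = ranked[:k]  # keep at most k nodes
    # Pad with all-zero rows so exactly k rows remain.
    kept += [[0.0] * d for _ in range(k - len(kept))]
    # Flatten the (k, D) block row by row into a vector of length k * D.
    return [x for row in kept for x in row]


feats = [[1.0, 0.2], [3.0, 0.9], [2.0, 0.5]]
print(sort_pool_single(feats, k=2))  # [3.0, 0.9, 2.0, 0.5]
print(sort_pool_single(feats, k=4))  # [3.0, 0.9, 2.0, 0.5, 1.0, 0.2, 0.0, 0.0]
```

With k=2 the lowest-ranked node is dropped. With k=4 one zero row is appended, so the result always has length k * D.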
forward(graph, feat)
Compute sort pooling.
- Parameters
graph (DGLGraph) – The graph, which may be a batch of graphs.
feat (mxnet.NDArray) – The input node feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph and \(D\) is the feature size.
- Returns
The output feature with shape \((B, k * D)\), where \(B\) is the batch size, i.e. the number of graphs in graph.
- Return type
mxnet.NDArray
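The shape contract of forward, \((N, D)\) in and \((B, k * D)\) out, can be sketched for a batch in plain Python. This is an illustrative sketch, not DGL's implementation, and it rests on several assumptions. `num_nodes` is a hypothetical stand-in for the per-graph node counts of a batched DGLGraph, and the rows of `feat` are assumed to be grouped graph by graph in that order. Within each graph, nodes are ranked by their last feature channel in descending order, with ties broken by earlier channels. Zero rows pad any graph that has fewer than k nodes.

```python
from typing import List

Row = List[float]


def sort_pool_batch(feat: List[Row], num_nodes: List[int], k: int) -> List[List[float]]:
    """Hypothetical sketch: map (N, D) node features of B graphs to (B, k * D).

    `num_nodes[i]` is the node count of graph i (an illustrative stand-in for
    a batched graph's per-graph counts); rows of `feat` are assumed grouped
    graph by graph in the same order.
    """
    if sum(num_nodes) != len(feat):
        raise ValueError("per-graph node counts must sum to N")
    d = len(feat[0]) if feat else 0
    out: List[List[float]] = []
    start = 0
    for n in num_nodes:
        rows = feat[start:start + n]  # this graph's node features
        start += n
        # Assumed rule: rank by last channel (ties by earlier ones), keep top k.
        rows = sorted(rows, key=lambda r: tuple(reversed(r)), reverse=True)[:k]
        # Assumed zero padding for graphs with fewer than k nodes.
        rows += [[0.0] * d for _ in range(k - len(rows))]
        out.append([x for r in rows for x in r])  # one row of length k * D
    return out


feat = [[1.0, 0.2], [3.0, 0.9], [2.0, 0.5], [5.0, 0.1]]  # N = 4, D = 2
out = sort_pool_batch(feat, num_nodes=[3, 1], k=2)        # B = 2
print(len(out), len(out[0]))  # 2 4
```

Every output row has length k * D regardless of how many nodes its graph has. This is why a batch of differently sized graphs yields a regular \((B, k * D)\) result.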