GlobalAttentionPooling
class dgl.nn.mxnet.glob.GlobalAttentionPooling(gate_nn, feat_nn=None)
Bases: mxnet.gluon.block.Block
Global attention pooling layer from the paper Gated Graph Sequence Neural Networks.
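\[r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate} \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)\]
where \(x^{(i)}_k\) is the feature of the \(k\)-th node of graph \(i\), \(N_i\) is the number of nodes in graph \(i\), and the softmax normalizes the gate scores over those \(N_i\) nodes.
- Parameters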
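gate_nn (gluon.nn.Block) – A neural network that computes an attention score for each node feature. Its output must have size 1 along the last axis, i.e. one score per node.
feat_nn (gluon.nn.Block, optional) – A neural network applied to each node feature before it is combined with the attention scores. If None, the input features are used directly.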
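The following is a minimal pure-Python sketch of the pooling equation for a single graph, not DGL's implementation. It assumes plain Python callables stand in for the gluon blocks: `gate_nn` maps one node feature to a scalar score and the optional `feat_nn` maps it to a new vector. The function name `attention_pool` is hypothetical.

```python
import math
from typing import Callable, List, Optional, Sequence

Vector = List[float]


def attention_pool(
    xs: Sequence[Vector],
    gate_nn: Callable[[Vector], float],
    feat_nn: Optional[Callable[[Vector], Vector]] = None,
) -> Vector:
    """Pool the node features xs of ONE graph into a single vector r (hypothetical helper)."""
    # f_gate(x_k): one scalar score per node.
    scores = [gate_nn(x) for x in xs]
    # Softmax over the nodes of this graph; subtracting the max keeps exp() stable.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # f_feat(x_k), or the raw feature when feat_nn is None.
    feats = [feat_nn(x) if feat_nn is not None else list(x) for x in xs]
    # Attention-weighted sum over the nodes, one output coordinate at a time.
    dim = len(feats[0])
    return [sum(w * f[j] for w, f in zip(weights, feats)) for j in range(dim)]


# Usage with a hypothetical constant gate: equal scores reduce to the plain mean.
print(attention_pool([[2.0], [4.0]], gate_nn=lambda x: 0.0))  # [3.0]
```

With a non-constant gate, nodes with higher scores dominate the sum; if `feat_nn` changes the feature size, the pooled vector takes that new size.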
forward(graph, feat)
Compute global attention pooling.
- Parameters
graph (DGLGraph) – A DGLGraph or a batch of DGLGraphs.
feat (mxnet.NDArray) – The input node features with shape \((N, D)\), where \(N\) is the total number of nodes in the graph (or batch) and \(D\) is the feature size.
- Returns
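The output feature with shape \((B, D)\), where \(B\) is the batch size (the number of graphs). If feat_nn changes the feature size, the last dimension is the output size of feat_nn instead of \(D\).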
- Return type
mxnet.NDArray
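To make the shapes concrete, here is a hedged pure-Python sketch of the batched case, not DGL's code. It assumes the \(N\) node features of all graphs are concatenated row-wise, that a hypothetical `batch_num_nodes` list gives each graph's node count, and that every graph has at least one node; `feat_nn` is omitted for brevity. The point it illustrates is that the softmax is normalized within each graph, so \(N\) input rows become \(B\) output rows.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def batched_attention_pool(
    feat: Sequence[Vector],
    batch_num_nodes: Sequence[int],
    gate_nn: Callable[[Vector], float],
) -> List[Vector]:
    """feat holds N rows (all graphs' nodes concatenated); returns B rows (hypothetical helper)."""
    if sum(batch_num_nodes) != len(feat):
        raise ValueError("batch_num_nodes must sum to the number of rows in feat")
    readout: List[Vector] = []
    start = 0
    for n in batch_num_nodes:
        # Slice out the nodes of the current graph (assumed non-empty).
        rows = feat[start:start + n]
        start += n
        # Softmax over this graph's nodes only, never across the whole batch.
        scores = [gate_nn(x) for x in rows]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        dim = len(rows[0])
        readout.append([sum(e / total * x[j] for e, x in zip(exps, rows)) for j in range(dim)])
    return readout


# Usage: two graphs with 2 and 1 nodes (N = 3, D = 2) give B = 2 pooled rows.
out = batched_attention_pool([[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]], [2, 1], lambda x: 0.0)
print(out)  # [[0.5, 0.5], [5.0, 5.0]]
```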