GlobalAttentionPooling

class dgl.nn.mxnet.glob.GlobalAttentionPooling(gate_nn, feat_nn=None)[source]

Bases: mxnet.gluon.block.Block

Global Attention Pooling layer from Gated Graph Sequence Neural Networks

\[r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate} \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)\]
Parameters
  • gate_nn (gluon.nn.Block) – A neural network that computes an attention score for each node feature; its output feature size must be 1 so that the scores can be normalized with a softmax over the nodes.

  • feat_nn (gluon.nn.Block, optional) – A neural network applied to each node feature before it is combined with the attention scores.
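The pooling equation above can be sketched in plain NumPy; this is an illustration of the math only, not DGL's implementation, and the linear gate passed in the toy example is an arbitrary stand-in for `gate_nn`:

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the node axis (axis 0).
    z = z - z.max(axis=0, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)

def global_attention_pool(x, gate_nn, feat_nn=None):
    """Pool node features x of shape (N, D) into one (D,) graph readout.

    gate_nn maps (N, D) -> (N, 1) attention scores; feat_nn optionally
    transforms the features before the attention-weighted sum.
    """
    scores = gate_nn(x)               # (N, 1) raw attention scores
    alpha = softmax(scores)           # normalize over the N nodes
    h = feat_nn(x) if feat_nn else x  # (N, D)
    return (alpha * h).sum(axis=0)    # (D,) weighted sum over nodes

# Toy example: one graph with 3 nodes and 4-dimensional features.
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))
w = rng.standard_normal((4, 1))   # fixed linear gate, for illustration
r = global_attention_pool(x, lambda h: h @ w)
print(r.shape)  # (4,)
```

Because the softmax normalizes the scores over all nodes of a graph, the result is a convex combination of the (transformed) node features, so the readout is invariant to node ordering.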

forward(graph, feat)[source]

Compute global attention pooling.

Parameters
  • graph (DGLGraph) – The graph.

  • feat (mxnet.NDArray) – The input node feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph and \(D\) is the feature size.

Returns

The output feature with shape \((B, D)\), where \(B\) is the batch size, i.e. the number of graphs in the (possibly batched) input graph.

Return type

mxnet.NDArray