Set2Set

class dgl.nn.mxnet.glob.Set2Set(input_dim, n_iters, n_layers)

Bases: mxnet.gluon.block.Block

Set2Set operator from "Order Matters: Sequence to sequence for sets".

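For each graph in the batch, Set2Set repeats the following update for n_iters iterations:

\[ \begin{aligned} q_t &= \mathrm{LSTM}(q^*_{t-1}) \\ \alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t) \\ r_t &= \sum_{i=1}^N \alpha_{i,t} x_i \\ q^*_t &= q_t \Vert r_t \end{aligned} \]

where \(x_i\) is the feature of node \(i\), \(N\) is the number of nodes in that graph, the softmax is taken over the nodes of the graph, and \(\Vert\) denotes concatenation.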
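The attention readout in the last three equations can be sketched in plain Python for a single graph. This is a minimal illustration, not the library's implementation: the LSTM is left out and the query `q` is simply passed in, and the helper names `softmax` and `readout_step` are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def readout_step(x, q):
    """One attention readout for a single graph.

    x: list of N node feature vectors, each of length D.
    q: query vector of length D (assumed to stand in for the LSTM output q_t).
    Returns q_star = q || r, a list of length 2 * D.
    """
    # alpha_i: softmax over the nodes of the dot products x_i . q
    scores = [sum(a * b for a, b in zip(xi, q)) for xi in x]
    alpha = softmax(scores)
    # r = sum_i alpha_i * x_i
    dim = len(q)
    r = [sum(a * xi[k] for a, xi in zip(alpha, x)) for k in range(dim)]
    # q_star = q || r (concatenation)
    return list(q) + r

x = [[1.0, 0.0], [0.0, 1.0]]   # N = 2 nodes, D = 2
q = [0.0, 0.0]                 # zero query, chosen only for illustration
print(readout_step(x, q))
```

With a zero query every score is equal, so the attention weights are uniform and `r` is the mean of the node features.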

Parameters
  • input_dim (int) – Size of each input sample, i.e. the node feature dimension \(D\).

  • n_iters (int) – Number of iterations.

  • n_layers (int) – Number of recurrent layers.

forward(graph, feat)

Compute set2set pooling.

Parameters
  • graph (DGLGraph) – The graph.

  • feat (mxnet.NDArray) – The input node features with shape \((N, D)\), where \(N\) is the number of nodes in the graph and \(D\) is the feature size.

Returns

The output feature with shape \((B, 2D)\), where \(B\) is the batch size and \(D\) is the input feature size; the factor of 2 comes from concatenating \(q_t\) and \(r_t\).

Return type

mxnet.NDArray
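The shape bookkeeping of forward can be illustrated without MXNet. The sketch below assumes that the node features of all graphs in the batch are stacked into one \((N, D)\) list and that the per-graph node counts are known. The query is fixed at zero as a stand-in for the LSTM output, and the helper names `split_by_graph` and `zero_query_readout` are hypothetical. Under the recurrence above, each pooled row concatenates two length-\(D\) vectors, so it has \(2D\) entries.

```python
def split_by_graph(feat, num_nodes_per_graph):
    """Split stacked (N, D) node features into one segment per graph."""
    segments, start = [], 0
    for n in num_nodes_per_graph:
        segments.append(feat[start:start + n])
        start += n
    return segments

def zero_query_readout(feat, num_nodes_per_graph):
    """Pool node features graph by graph with a zero query.

    With q = 0 every score x_i . q is 0, so the softmax weights are
    uniform and r is the per-graph mean. Each output row is q || r.
    Assumes every graph has at least one node.
    """
    dim = len(feat[0])
    out = []
    for nodes in split_by_graph(feat, num_nodes_per_graph):
        r = [sum(x[k] for x in nodes) / len(nodes) for k in range(dim)]
        out.append([0.0] * dim + r)
    return out

feat = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # N = 3 nodes in total, D = 2
out = zero_query_readout(feat, [2, 1])        # B = 2 graphs with 2 and 1 nodes
print(len(out), len(out[0]))                  # number of rows, row length
```

Here three nodes from two graphs are reduced to two rows, one per graph, each of length \(2D = 4\).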