Set2Set
- class dgl.nn.pytorch.glob.Set2Set(input_dim, n_iters, n_layers)
Bases: torch.nn.Module
Set2Set operator from the paper "Order Matters: Sequence to sequence for sets".
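For each individual graph in the batch, Set2Set computes

\[
\begin{aligned}
q_t &= \mathrm{LSTM}(q^*_{t-1}) \\
\alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t) \\
r_t &= \sum_{i=1}^N \alpha_{i,t} x_i \\
q^*_t &= q_t \Vert r_t
\end{aligned}
\]

for this graph, where \(x_i\) is the feature of node \(i\), \(N\) is the number of nodes in the graph, the softmax is normalized over those \(N\) nodes, and \(\Vert\) denotes concatenation.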
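To make the recurrence concrete, here is a minimal plain-PyTorch sketch of the equations for a single graph. It is not DGL's implementation: it assumes that the LSTM state and \(q^*_0\) both start at zero and that the LSTM maps inputs of size \(2D\) to hidden states of size \(D\). The function name `set2set_single_graph` is hypothetical.

```python
import torch
import torch.nn as nn


def set2set_single_graph(x: torch.Tensor, lstm: nn.LSTM, n_iters: int) -> torch.Tensor:
    """Set2Set readout for one graph (illustrative sketch, hypothetical helper).

    x:    node features, shape (N, D)
    lstm: nn.LSTM(input_size=2 * D, hidden_size=D, num_layers=n_layers)
    Returns q*_T with shape (1, 2 * D).
    """
    dim = x.shape[1]
    layers = lstm.num_layers
    # Assumption: the LSTM state (h, c) and q*_0 both start at zero.
    state = (x.new_zeros(layers, 1, dim), x.new_zeros(layers, 1, dim))
    q_star = x.new_zeros(1, 2 * dim)
    for _ in range(n_iters):
        # q_t = LSTM(q*_{t-1}); nn.LSTM expects (seq_len=1, batch=1, 2D)
        q, state = lstm(q_star.unsqueeze(0), state)
        q = q.view(1, dim)
        # alpha_{i,t} = softmax(x_i . q_t), normalized over the N nodes
        alpha = torch.softmax(x @ q.squeeze(0), dim=0)  # (N,)
        # r_t = sum_i alpha_{i,t} x_i
        r = (alpha.unsqueeze(-1) * x).sum(dim=0, keepdim=True)  # (1, D)
        # q*_t = q_t || r_t
        q_star = torch.cat([q, r], dim=-1)  # (1, 2D)
    return q_star


torch.manual_seed(0)
D = 5
lstm = nn.LSTM(input_size=2 * D, hidden_size=D, num_layers=1)
x = torch.rand(3, D)  # 3 nodes, feature size 5
out = set2set_single_graph(x, lstm, n_iters=2)
print(out.shape)  # torch.Size([1, 10])
```

Because \(r_t\) is a convex combination of node features, each of its entries lies between the per-feature minimum and maximum over the nodes, and the readout has width \(2D\) since \(q_t\) and \(r_t\) are concatenated.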
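- Parameters:
input_dim (int) – The size \(D\) of each input node feature.
n_iters (int) – The number of Set2Set iterations, i.e. the number of steps \(t\) in the equations above.
n_layers (int) – The number of recurrent layers in the LSTM.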
Examples
The following example uses the PyTorch backend.
>>> import dgl
>>> import torch as th
>>> from dgl.nn import Set2Set
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> s2s = Set2Set(5, 2, 1)  # create a Set2Set layer (n_iters=2, n_layers=1)
Case 1: Input a single graph
>>> s2s(g1, g1_node_feats)
tensor([[-0.0235, -0.2291,  0.2654,  0.0376,  0.1349,  0.7560,  0.5822,  0.8199,
          0.5960,  0.4760]], grad_fn=<CatBackward>)
Case 2: Input a batch of graphs
Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.
>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats], 0)
>>>
>>> s2s(batch_g, batch_f)
tensor([[-0.0235, -0.2291,  0.2654,  0.0376,  0.1349,  0.7560,  0.5822,  0.8199,
          0.5960,  0.4760],
        [-0.0483, -0.2010,  0.2324,  0.0145,  0.1361,  0.2703,  0.3078,  0.5529,
          0.6876,  0.6399]], grad_fn=<CatBackward>)
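The batched case pools each graph separately: one LSTM step runs for the whole batch, but the softmax and the weighted sum are restricted to each graph's own nodes. The sketch below illustrates this in plain PyTorch without DGL. It assumes the per-graph node counts are passed in as a hypothetical `num_nodes` list and, as before, that the LSTM state and \(q^*_0\) start at zero; `set2set_batched` is a hypothetical name, not a DGL function.

```python
import torch
import torch.nn as nn


def set2set_batched(feat, num_nodes, lstm, n_iters):
    """Batched Set2Set sketch (hypothetical helper, not DGL's code).

    feat:      node features of all graphs concatenated, shape (N_total, D)
    num_nodes: list of per-graph node counts summing to N_total (assumption)
    Returns a tensor of shape (B, 2 * D), one row per graph.
    """
    batch_size, dim = len(num_nodes), feat.shape[1]
    layers = lstm.num_layers
    # Assumption: zero initial LSTM state and zero q*_0 for every graph.
    state = (feat.new_zeros(layers, batch_size, dim),
             feat.new_zeros(layers, batch_size, dim))
    q_star = feat.new_zeros(batch_size, 2 * dim)
    for _ in range(n_iters):
        # One LSTM step for all graphs at once: input shape (1, B, 2D)
        q, state = lstm(q_star.unsqueeze(0), state)
        q = q.view(batch_size, dim)
        readouts = []
        # Softmax and sum run over each graph's own nodes only
        for b, x_b in enumerate(torch.split(feat, num_nodes)):
            alpha = torch.softmax(x_b @ q[b], dim=0)  # (n_b,)
            readouts.append((alpha.unsqueeze(-1) * x_b).sum(dim=0))  # (D,)
        q_star = torch.cat([q, torch.stack(readouts)], dim=-1)  # (B, 2D)
    return q_star


torch.manual_seed(0)
D = 5
lstm = nn.LSTM(input_size=2 * D, hidden_size=D, num_layers=1)
g1_feats, g2_feats = torch.rand(3, D), torch.rand(4, D)
batched = set2set_batched(torch.cat([g1_feats, g2_feats]), [3, 4], lstm, n_iters=2)
single = set2set_batched(g1_feats, [3], lstm, n_iters=2)
print(batched.shape)  # torch.Size([2, 10])
print(torch.allclose(batched[0], single[0], atol=1e-5))  # True
```

This mirrors the doctest above, where the first row of the batched output equals the single-graph output for g1: batching changes how the work is laid out, not the per-graph result.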
Notes
Set2Set is widely used in molecular property prediction; see dgl-lifesci's MPNN example for how to use DGL's Set2Set layer in graph property prediction applications.
- forward(graph, feat)
Compute set2set pooling.
- Parameters:
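graph (DGLGraph) – The input graph, either a single graph or a batch created with dgl.batch.
feat (torch.Tensor) – The input feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph (the total over all graphs for a batched graph) and \(D\) is the feature size, which must equal input_dim.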
- Returns:
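The output feature with shape \((B, 2D)\), where \(B\) is the batch size (the number of graphs) and \(D\) is the input feature size. Each row is \(q^*_T = q_T \Vert r_T\) for one graph, so its width is twice the input feature size.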
- Return type:
torch.Tensor