SetTransformerDecoder
class dgl.nn.pytorch.glob.SetTransformerDecoder(d_model, num_heads, d_head, d_ff, n_layers, k, dropouth=0.0, dropouta=0.0)
Bases: torch.nn.modules.module.Module
The decoder module from "Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks".
- Parameters
d_model (int) – Hidden size of the model.
num_heads (int) – The number of heads.
d_head (int) – Hidden size of each head.
d_ff (int) – Hidden (inner-layer) size of the FFN (Positionwise Feed-Forward Network) layer.
n_layers (int) – The number of layers.
k (int) – The number of seed vectors in PMA (Pooling by Multihead Attention) layer.
dropouth (float) – Dropout rate of each sublayer.
dropouta (float) – Dropout rate of attention heads.
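To make the role of k concrete, the following is a minimal sketch of the pooling idea behind PMA, written in plain Python so it runs without DGL or PyTorch. It is not DGL's implementation: it uses a single attention head and omits the learned projections, FFN, layer normalization, and dropout. The function name pma_pool and the randomly generated seeds and node features are hypothetical and used only for illustration.

```python
import math
import random


def pma_pool(node_feats, seeds):
    """Single-head attention pooling: each seed vector attends over all nodes.

    Simplified illustration only: no learned projections, multiple heads,
    FFN, layer normalization, or dropout.
    """
    d_model = len(seeds[0])
    pooled = []
    for seed in seeds:
        # Scaled dot-product score between this seed (query) and every node (key).
        scores = [
            sum(s * x for s, x in zip(seed, node)) / math.sqrt(d_model)
            for node in node_feats
        ]
        # Numerically stable softmax over the nodes.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # The weighted sum of node features (values) gives one pooled vector.
        pooled.append(
            [sum(w * node[j] for w, node in zip(weights, node_feats))
             for j in range(d_model)]
        )
    # Flatten the k pooled vectors into one row of length k * d_model.
    return [value for vec in pooled for value in vec]


random.seed(0)
d_model, k = 5, 3
# Hypothetical seed vectors and node features (a graph with 4 nodes).
seeds = [[random.uniform(-1, 1) for _ in range(d_model)] for _ in range(k)]
nodes = [[random.random() for _ in range(d_model)] for _ in range(4)]

row = pma_pool(nodes, seeds)
row_shuffled = pma_pool(list(reversed(nodes)), seeds)
print(len(row))  # 15, i.e. k * d_model values for one graph
```

Because each seed takes a softmax-weighted sum over all nodes, reordering the nodes leaves the pooled row unchanged up to floating-point rounding, which is the permutation invariance the Set Transformer is designed for.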
Examples
>>> import dgl
>>> import torch as th
>>> from dgl.nn import SetTransformerDecoder
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> set_trans_dec = SetTransformerDecoder(5, 4, 4, 20, 1, 3)  # define the layer
Case 1: Input a single graph
Case 2: Input a batch of graphs
Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.
>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats])
>>>
>>> set_trans_dec(batch_g, batch_f)
tensor([[-0.5538,  1.8726, -1.0470,  0.0276, -0.2994, -0.6317,  1.6754, -1.3189,
          0.2291,  0.0461, -0.4042,  0.8387, -1.7091,  1.0845,  0.1902],
        [-0.5511,  1.8869, -1.0156,  0.0028, -0.3231, -0.6305,  1.6845, -1.3105,
          0.2136,  0.0428, -0.3820,  0.8043, -1.7138,  1.1126,  0.1789]],
       grad_fn=<ViewBackward>)
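The concatenated feature tensor carries no graph boundaries of its own; the batched graph records how many nodes belong to each graph (in DGL, via batch_num_nodes()). The sketch below illustrates that bookkeeping in plain Python, using nested lists in place of tensors. The helper split_by_graph and the feature values are hypothetical and only show how rows map back to graphs; they are not part of DGL.

```python
def split_by_graph(batch_feats, num_nodes_per_graph):
    """Slice a concatenated node-feature list back into one list per graph."""
    if sum(num_nodes_per_graph) != len(batch_feats):
        raise ValueError("node counts do not match the number of feature rows")
    segments, start = [], 0
    for n in num_nodes_per_graph:
        segments.append(batch_feats[start:start + n])
        start += n
    return segments


# Hypothetical features: 3 nodes for the first graph, 4 for the second,
# each with feature size 5.
g1_feats = [[float(i)] * 5 for i in range(3)]
g2_feats = [[float(10 + i)] * 5 for i in range(4)]
batch_feats = g1_feats + g2_feats  # analogous to th.cat([...]) along dim 0

segments = split_by_graph(batch_feats, [3, 4])
print([len(s) for s in segments])  # [3, 4]
```

Each segment is pooled independently, which is why the batched call returns one output row per graph.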
forward(graph, feat)
Compute the decoder part of Set Transformer.
- Parameters
graph (DGLGraph) – The input graph.
feat (torch.Tensor) – The input feature with shape (N, D), where N is the total number of nodes in the graph (or batch of graphs) and D is the feature size, which matches d_model in the example above.
- Returns
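The output feature with shape (B, k·D), where B is the batch size, k is the number of seed vectors, and D is the feature size. In the batched example above, B = 2, k = 3, and D = 5, so the output has shape (2, 15).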
- Return type
torch.Tensor