"""Torch modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import torch as th
import torch.nn as nn
import numpy as np
from ...backend import pytorch as F
from ...base import dgl_warning
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
# Public names exported by ``from <this module> import *``.
# NOTE(review): 'SetTransformerEncoder' and 'SetTransformerDecoder' are not
# defined in the visible portion of this file -- confirm they exist elsewhere
# in the module, otherwise a star-import will raise AttributeError.
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
           'GlobalAttentionPooling', 'Set2Set',
           'SetTransformerEncoder', 'SetTransformerDecoder', 'WeightAndSum']
class SumPooling(nn.Module):
    r"""Sum-pool node features into one vector per graph.

    .. math::
        r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k

    Notes
    -----
    Accepts either a single graph or a batched graph. For a batch, every graph
    must use the same node-feature size, and the node features of all graphs
    are concatenated (in batch order) into one input tensor.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import SumPooling
    >>>
    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5
    >>> g1_node_feats
    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
    >>>
    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5
    >>> g2_node_feats
    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
    >>>
    >>> sumpool = SumPooling()  # create a sum pooling layer

    Case 1: Input a single graph

    >>> sumpool(g1, g1_node_feats)
    tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511]])

    Case 2: Input a batch of graphs

    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.

    >>> batch_g = dgl.batch([g1, g2])
    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
    >>>
    >>> sumpool(batch_g, batch_f)
    tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511],
            [1.0608, 1.2080, 2.1780, 2.7849, 2.5420]])
    """
    def __init__(self):
        super().__init__()

    def forward(self, graph, feat):
        r"""Sum the node features of every graph.

        Parameters
        ----------
        graph : DGLGraph
            A single DGLGraph or a batched DGLGraph.
        feat : torch.Tensor
            Node features of shape :math:`(N, D)`; :math:`N` is the total node
            count and :math:`D` the feature size.

        Returns
        -------
        torch.Tensor
            Pooled features of shape :math:`(B, D)`, one row per graph in the batch.
        """
        # local_scope keeps the temporary 'h' field from leaking into the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = feat
            return sum_nodes(graph, 'h')
class AvgPooling(nn.Module):
    r"""Average-pool node features into one vector per graph.

    .. math::
        r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k

    Notes
    -----
    Accepts either a single graph or a batched graph. For a batch, every graph
    must use the same node-feature size, and the node features of all graphs
    are concatenated (in batch order) into one input tensor.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import AvgPooling
    >>>
    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5
    >>> g1_node_feats
    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
    >>>
    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5
    >>> g2_node_feats
    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
    >>>
    >>> avgpool = AvgPooling()  # create an average pooling layer

    Case 1: Input a single graph

    >>> avgpool(g1, g1_node_feats)
    tensor([[0.7427, 0.6222, 0.8113, 0.5847, 0.4837]])

    Case 2: Input a batch of graphs

    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.

    >>> batch_g = dgl.batch([g1, g2])
    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
    >>>
    >>> avgpool(batch_g, batch_f)
    tensor([[0.7427, 0.6222, 0.8113, 0.5847, 0.4837],
            [0.2652, 0.3020, 0.5445, 0.6962, 0.6355]])
    """
    def __init__(self):
        super().__init__()

    def forward(self, graph, feat):
        r"""Average the node features of every graph.

        Parameters
        ----------
        graph : DGLGraph
            A single DGLGraph or a batched DGLGraph.
        feat : torch.Tensor
            Node features of shape :math:`(N, D)`; :math:`N` is the total node
            count and :math:`D` the feature size.

        Returns
        -------
        torch.Tensor
            Pooled features of shape :math:`(B, D)`, one row per graph in the batch.
        """
        # local_scope keeps the temporary 'h' field from leaking into the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = feat
            return mean_nodes(graph, 'h')
class MaxPooling(nn.Module):
    r"""Max-pool node features into one vector per graph.

    .. math::
        r^{(i)} = \max_{k=1}^{N_i}\left( x^{(i)}_k \right)

    Notes
    -----
    Accepts either a single graph or a batched graph. For a batch, every graph
    must use the same node-feature size, and the node features of all graphs
    are concatenated (in batch order) into one input tensor.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import MaxPooling
    >>>
    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5
    >>> g1_node_feats
    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
    >>>
    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5
    >>> g2_node_feats
    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
    >>>
    >>> maxpool = MaxPooling()  # create a max pooling layer

    Case 1: Input a single graph

    >>> maxpool(g1, g1_node_feats)
    tensor([[0.8948, 0.9030, 0.9137, 0.7567, 0.6118]])

    Case 2: Input a batch of graphs

    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.

    >>> batch_g = dgl.batch([g1, g2])
    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
    >>>
    >>> maxpool(batch_g, batch_f)
    tensor([[0.8948, 0.9030, 0.9137, 0.7567, 0.6118],
            [0.5278, 0.6365, 0.9990, 0.9028, 0.8945]])
    """
    def __init__(self):
        super().__init__()

    def forward(self, graph, feat):
        r"""Take the element-wise maximum of the node features of every graph.

        Parameters
        ----------
        graph : DGLGraph
            A single DGLGraph or a batched DGLGraph.
        feat : torch.Tensor
            Node features of shape :math:`(N, *)`; :math:`N` is the total node count.

        Returns
        -------
        torch.Tensor
            Pooled features of shape :math:`(B, *)`, one row per graph in the batch.
        """
        # local_scope keeps the temporary 'h' field from leaking into the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = feat
            return max_nodes(graph, 'h')
class SortPooling(nn.Module):
    r"""Sort Pooling from `An End-to-End Deep Learning Architecture for Graph Classification
    <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__

    Each node's feature vector is first sorted in ascending order. Nodes are then
    ranked by their largest (i.e. last, after sorting) feature value, and the
    sorted feature vectors of the top-k nodes of each graph are flattened into a
    single row.

    Parameters
    ----------
    k : int
        The number of nodes to keep for each graph.

    Notes
    -----
    Accepts either a single graph or a batched graph. For a batch, every graph
    must use the same node-feature size, and the node features of all graphs
    are concatenated (in batch order) into one input tensor.

    Examples
    --------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import SortPooling
    >>>
    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5
    >>> g1_node_feats
    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
    >>>
    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5
    >>> g2_node_feats
    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
    >>>
    >>> sortpool = SortPooling(k=2)  # create a sort pooling layer

    Case 1: Input a single graph

    >>> sortpool(g1, g1_node_feats)
    tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,
             0.9030]])

    Case 2: Input a batch of graphs

    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.

    >>> batch_g = dgl.batch([g1, g2])
    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
    >>>
    >>> sortpool(batch_g, batch_f)
    tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,
             0.9030],
            [0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658,
             0.9028]])
    """
    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, graph, feat):
        r"""Apply sort pooling.

        Parameters
        ----------
        graph : DGLGraph
            A single DGLGraph or a batched DGLGraph.
        feat : torch.Tensor
            Node features of shape :math:`(N, D)`; :math:`N` is the total node
            count and :math:`D` the feature size.

        Returns
        -------
        torch.Tensor
            Flattened top-k features of shape :math:`(B, k * D)`, one row per graph.
        """
        with graph.local_scope():
            # Ascending sort inside each node's feature vector; indices are not needed.
            sorted_feat = feat.sort(dim=-1)[0]
            graph.ndata['h'] = sorted_feat
            # Rank nodes by their last (= largest) sorted feature and keep the top k.
            topk_feat = topk_nodes(graph, 'h', self.k, sortby=-1)[0]
            return topk_feat.view(-1, self.k * sorted_feat.shape[-1])
class GlobalAttentionPooling(nn.Module):
    r"""Global Attention Pooling from `Gated Graph Sequence Neural Networks
    <https://arxiv.org/abs/1511.05493>`__

    .. math::
        r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
        \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)

    Parameters
    ----------
    gate_nn : torch.nn.Module
        A neural network that maps each node's feature to a scalar attention
        score (its output must have size 1 on the last axis).
    feat_nn : torch.nn.Module, optional
        A neural network applied to each node's feature before it is weighted by
        the attention scores. If ``None`` (default), features are used as-is.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import GlobalAttentionPooling
    >>>
    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5
    >>> g1_node_feats
    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
    >>>
    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5
    >>> g2_node_feats
    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
    >>>
    >>> gate_nn = th.nn.Linear(5, 1)  # the gate layer that maps node feature to scalar
    >>> gap = GlobalAttentionPooling(gate_nn)  # create a Global Attention Pooling layer

    Case 1: Input a single graph

    >>> gap(g1, g1_node_feats)
    tensor([[0.7410, 0.6032, 0.8111, 0.5942, 0.4762]],
           grad_fn=<SegmentReduceBackward>)

    Case 2: Input a batch of graphs

    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.

    >>> batch_g = dgl.batch([g1, g2])
    >>> batch_f = th.cat([g1_node_feats, g2_node_feats], 0)
    >>>
    >>> gap(batch_g, batch_f)
    tensor([[0.7410, 0.6032, 0.8111, 0.5942, 0.4762],
            [0.2417, 0.2743, 0.5054, 0.7356, 0.6146]],
           grad_fn=<SegmentReduceBackward>)

    Notes
    -----
    See our `GGNN example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/ggnn>`_
    on how to use GatedGraphConv and GlobalAttentionPooling layer to build a Graph Neural
    Network that can solve Sudoku.
    """
    def __init__(self, gate_nn, feat_nn=None):
        super(GlobalAttentionPooling, self).__init__()
        self.gate_nn = gate_nn
        self.feat_nn = feat_nn

    def forward(self, graph, feat, get_attention=False):
        r"""Compute global attention pooling.

        Parameters
        ----------
        graph : DGLGraph
            A DGLGraph or a batch of DGLGraphs.
        feat : torch.Tensor
            The input node feature with shape :math:`(N, D)` where :math:`N` is the
            number of nodes in the graph, and :math:`D` means the size of features.
        get_attention : bool, optional
            Whether to return the attention values from gate_nn. Default to False.

        Returns
        -------
        torch.Tensor
            The output feature with shape :math:`(B, D)`, where :math:`B` refers
            to the batch size.
        torch.Tensor, optional
            The attention values of shape :math:`(N, 1)`, where :math:`N` is the number of
            nodes in the graph. This is returned only when :attr:`get_attention` is ``True``.
        """
        with graph.local_scope():
            gate = self.gate_nn(feat)
            assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
            # Compare against None explicitly: a Module defining __len__ (e.g. a
            # container) can be falsy and would otherwise be silently skipped.
            if self.feat_nn is not None:
                feat = self.feat_nn(feat)

            # Normalize gate scores across the nodes of each graph separately.
            graph.ndata['gate'] = gate
            gate = softmax_nodes(graph, 'gate')
            graph.ndata.pop('gate')

            # Attention-weighted sum of (optionally transformed) node features.
            graph.ndata['r'] = feat * gate
            readout = sum_nodes(graph, 'r')
            graph.ndata.pop('r')

            if get_attention:
                return readout, gate
            return readout
class Set2Set(nn.Module):
    r"""Set2Set operator from `Order Matters: Sequence to sequence for sets
    <https://arxiv.org/pdf/1511.06391.pdf>`__

    For each individual graph in the batch, set2set computes

    .. math::
        q_t &= \mathrm{LSTM} (q^*_{t-1})

        \alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t)

        r_t &= \sum_{i=1}^N \alpha_{i,t} x_i

        q^*_t &= q_t \Vert r_t

    for this graph.

    Parameters
    ----------
    input_dim : int
        The size of each input sample.
    n_iters : int
        The number of iterations.
    n_layers : int
        The number of recurrent layers.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import Set2Set
    >>>
    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5
    >>> g1_node_feats
    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
    >>>
    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5
    >>> g2_node_feats
    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
    >>>
    >>> s2s = Set2Set(5, 2, 1)  # create a Set2Set layer(n_iters=2, n_layers=1)

    Case 1: Input a single graph

    >>> s2s(g1, g1_node_feats)
    tensor([[-0.0235, -0.2291,  0.2654,  0.0376,  0.1349,  0.7560,  0.5822,  0.8199,
              0.5960,  0.4760]], grad_fn=<CatBackward>)

    Case 2: Input a batch of graphs

    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.

    >>> batch_g = dgl.batch([g1, g2])
    >>> batch_f = th.cat([g1_node_feats, g2_node_feats], 0)
    >>>
    >>> s2s(batch_g, batch_f)
    tensor([[-0.0235, -0.2291,  0.2654,  0.0376,  0.1349,  0.7560,  0.5822,  0.8199,
              0.5960,  0.4760],
            [-0.0483, -0.2010,  0.2324,  0.0145,  0.1361,  0.2703,  0.3078,  0.5529,
              0.6876,  0.6399]], grad_fn=<CatBackward>)

    Notes
    -----
    Set2Set is widely used in molecular property predictions, see
    `dgl-lifesci's MPNN example <https://github.com/awslabs/dgl-lifesci/blob/
    ecd95c905479ec048097777039cf9a19cfdcf223/python/dgllife/model/model_zoo/
    mpnn_predictor.py>`__
    on how to use DGL's Set2Set layer in graph property prediction applications.
    """
    def __init__(self, input_dim, n_iters, n_layers):
        super(Set2Set, self).__init__()
        self.input_dim = input_dim
        # q*_t = q_t || r_t, so the readout (and the LSTM input) is twice as wide.
        self.output_dim = 2 * input_dim
        self.n_iters = n_iters
        self.n_layers = n_layers
        self.lstm = th.nn.LSTM(self.output_dim, self.input_dim, n_layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        self.lstm.reset_parameters()

    def forward(self, graph, feat):
        r"""Compute set2set pooling.

        Parameters
        ----------
        graph : DGLGraph
            A DGLGraph or a batch of DGLGraphs.
        feat : torch.Tensor
            The input feature with shape :math:`(N, D)` where :math:`N` is the
            number of nodes in the graph, and :math:`D` means the size of features
            (it must equal ``input_dim``).

        Returns
        -------
        torch.Tensor
            The output feature with shape :math:`(B, 2D)`, where :math:`B` refers to
            the batch size and :math:`D` is the input feature size. Each row is the
            concatenation of the final LSTM query and its attention readout.
        """
        with graph.local_scope():
            batch_size = graph.batch_size

            # (h_0, c_0) for the LSTM, each of shape (n_layers, B, input_dim).
            h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
                 feat.new_zeros((self.n_layers, batch_size, self.input_dim)))

            q_star = feat.new_zeros(batch_size, self.output_dim)

            for _ in range(self.n_iters):
                # One LSTM step per iteration: sequence length 1 via unsqueeze(0).
                q, h = self.lstm(q_star.unsqueeze(0), h)
                q = q.view(batch_size, self.input_dim)
                # Dot-product score of every node against its own graph's query.
                e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)
                graph.ndata['e'] = e
                alpha = softmax_nodes(graph, 'e')
                graph.ndata['r'] = feat * alpha
                readout = sum_nodes(graph, 'r')
                q_star = th.cat([q, readout], dim=-1)

            return q_star

    def extra_repr(self):
        """Set the extra representation of the module, which is shown
        when printing the model.
        """
        summary = 'n_iters={n_iters}'
        return summary.format(**self.__dict__)
def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y):
""" Generate binary mask array for given x and y input pairs.
Parameters
----------
lengths_x : Tensor
The int tensor indicates the segment information of x.
lengths_y : Tensor
The int tensor indicates the segment information of y.
max_len_x : int
The maximum element in lengths_x.
max_len_y : int
The maximum element in lengths_y.
Returns
-------
Tensor
the mask tensor with shape (batch_size, 1, max_len_x, max_len_y)
"""
device = lengths_x.device
# x_mask: (batch_size, max_len_x)
x_mask = th.arange(max_len_x, device=device).unsqueeze(0) < lengths_x.unsqueeze(1)
# y_mask: (batch_size, max_len_y)
y_mask = th.arange(max_len_y, device=device).unsqueeze(0) < lengths_y.unsqueeze(1)
# mask: (batch_size, 1, max_len_x, max_len_y)
mask = (x_mask.unsqueeze(-1) & y_mask.unsqueeze(-2)).unsqueeze(1)
return mask
class MultiHeadAttention(nn.Module):
    r"""Multi-Head Attention block, used in Transformer, Set Transformer and so on.

    Parameters
    ----------
    d_model : int
        The feature size (input and output) in Multi-Head Attention layer.
    num_heads : int
        The number of heads.
    d_head : int
        The hidden size per head.
    d_ff : int
        The inner hidden size in the Feed-Forward Neural Network.
    dropouth : float
        The dropout rate of each sublayer (applied to the attention output and
        to the feed-forward output before each residual connection, and inside
        the feed-forward network).
    dropouta : float
        The dropout rate of attention heads (applied to the attention weights).

    Notes
    -----
    This module was used in SetTransformer layer.
    """
    def __init__(self, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_head
        self.d_ff = d_ff
        self.proj_q = nn.Linear(d_model, num_heads * d_head, bias=False)
        self.proj_k = nn.Linear(d_model, num_heads * d_head, bias=False)
        self.proj_v = nn.Linear(d_model, num_heads * d_head, bias=False)
        self.proj_o = nn.Linear(num_heads * d_head, d_model, bias=False)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropouth),
            nn.Linear(d_ff, d_model)
        )
        self.droph = nn.Dropout(dropouth)
        self.dropa = nn.Dropout(dropouta)
        self.norm_in = nn.LayerNorm(d_model)
        self.norm_inter = nn.LayerNorm(d_model)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters (Xavier-uniform for every weight matrix)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x, mem, lengths_x, lengths_mem):
        """
        Compute multi-head attention of ``x`` over ``mem``.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor used to compute queries; rows of all sets are
            concatenated along the first dimension, last dimension ``d_model``.
        mem : torch.Tensor
            The memory tensor used to compute keys and values, packed the same way.
        lengths_x : list
            The array of node numbers, used to segment x.
        lengths_mem : list
            The array of node numbers, used to segment mem.

        Returns
        -------
        torch.Tensor
            Updated features with the same packed layout and shape as ``x``.
        """
        batch_size = len(lengths_x)
        max_len_x = max(lengths_x)
        max_len_mem = max(lengths_mem)
        device = x.device
        lengths_x = th.tensor(lengths_x, dtype=th.int64, device=device)
        lengths_mem = th.tensor(lengths_mem, dtype=th.int64, device=device)

        queries = self.proj_q(x).view(-1, self.num_heads, self.d_head)
        keys = self.proj_k(mem).view(-1, self.num_heads, self.d_head)
        values = self.proj_v(mem).view(-1, self.num_heads, self.d_head)

        # padding to (B, max_len_x/mem, num_heads, d_head)
        queries = F.pad_packed_tensor(queries, lengths_x, 0)
        keys = F.pad_packed_tensor(keys, lengths_mem, 0)
        values = F.pad_packed_tensor(values, lengths_mem, 0)

        # attention score with shape (B, num_heads, max_len_x, max_len_mem)
        e = th.einsum('bxhd,byhd->bhxy', queries, keys)
        # scaled dot-product
        e = e / np.sqrt(self.d_head)

        # mask out padded positions so they receive no attention
        mask = _gen_mask(lengths_x, lengths_mem, max_len_x, max_len_mem)
        e = e.masked_fill(mask == 0, -float('inf'))

        # apply softmax
        alpha = th.softmax(e, dim=-1)
        # the following line addresses the NaN issue, see
        # https://github.com/dmlc/dgl/issues/2657
        alpha = alpha.masked_fill(mask == 0, 0.)
        # attention dropout (rate ``dropouta``); identity in eval mode or when the rate is 0
        alpha = self.dropa(alpha)

        # sum of value weighted by alpha
        out = th.einsum('bhxy,byhd->bxhd', alpha, values)
        # project to output
        out = self.proj_o(
            out.contiguous().view(batch_size, max_len_x, self.num_heads * self.d_head))
        # pack tensor back to the layout of x
        out = F.pack_padded_tensor(out, lengths_x)

        # intra norm: residual over the (dropped-out) attention sublayer
        x = self.norm_in(x + self.droph(out))

        # inter norm: residual over the (dropped-out) feed-forward sublayer
        x = self.norm_inter(x + self.droph(self.ffn(x)))
        return x
class SetAttentionBlock(nn.Module):
    r"""SAB block from `Set Transformer: A Framework for Attention-based
    Permutation-Invariant Neural Networks <https://arxiv.org/abs/1810.00825>`__

    A Set Attention Block is multi-head self-attention: every element of a set
    attends to all elements of the same set.

    Parameters
    ----------
    d_model : int
        The feature size (input and output) in Multi-Head Attention layer.
    num_heads : int
        The number of heads.
    d_head : int
        The hidden size per head.
    d_ff : int
        The inner hidden size in the Feed-Forward Neural Network.
    dropouth : float
        The dropout rate of each sublayer.
    dropouta : float
        The dropout rate of attention heads.

    Notes
    -----
    This module was used in SetTransformer layer.
    """
    def __init__(self, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
        super().__init__()
        self.mha = MultiHeadAttention(
            d_model, num_heads, d_head, d_ff,
            dropouth=dropouth, dropouta=dropouta,
        )

    def forward(self, feat, lengths):
        """
        Apply the Set Attention Block.

        Parameters
        ----------
        feat : torch.Tensor
            The input feature.
        lengths : list
            The array of node numbers, used to segment feat tensor.
        """
        # Queries, keys and values all come from the same set.
        queries = memory = feat
        return self.mha(queries, memory, lengths, lengths)
class InducedSetAttentionBlock(nn.Module):
    r"""ISAB block from `Set Transformer: A Framework for Attention-based
    Permutation-Invariant Neural Networks <https://arxiv.org/abs/1810.00825>`__

    A learnable set of ``m`` inducing points first attends to the input set,
    producing a per-sample summary of size ``m``; the input set then attends to
    that summary.

    Parameters
    ----------
    m : int
        The number of induced vectors.
    d_model : int
        The feature size (input and output) in Multi-Head Attention layer.
    num_heads : int
        The number of heads.
    d_head : int
        The hidden size per head.
    d_ff : int
        The inner hidden size in the Feed-Forward Neural Network.
    dropouth : float
        The dropout rate of each sublayer.
    dropouta : float
        The dropout rate of attention heads.

    Notes
    -----
    This module was used in SetTransformer layer.
    """
    def __init__(self, m, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
        super().__init__()
        self.m = m
        if m == 1:
            # With a single key per sample the second attention's softmax is constant.
            dgl_warning("if m is set to 1, the parameters corresponding to query and key "
                        "projections would not get updated during training.")
        self.d_model = d_model
        # Allocated uninitialized here; filled by reset_parameters() below.
        self.inducing_points = nn.Parameter(th.FloatTensor(m, d_model))
        attention_layers = [
            MultiHeadAttention(d_model, num_heads, d_head, d_ff,
                               dropouth=dropouth, dropouta=dropouta)
            for _ in range(2)
        ]
        self.mha = nn.ModuleList(attention_layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize the inducing points (Xavier-uniform)."""
        nn.init.xavier_uniform_(self.inducing_points)

    def forward(self, feat, lengths):
        """
        Apply the Induced Set Attention Block.

        Parameters
        ----------
        feat : torch.Tensor
            The input feature.
        lengths : list
            The array of node numbers, used to segment feat tensor.

        Returns
        -------
        torch.Tensor
            The output feature
        """
        num_sets = len(lengths)
        induced_lengths = [self.m] * num_sets
        # One copy of the inducing points per sample in the batch.
        query = self.inducing_points.repeat(num_sets, 1)
        # Step 1: inducing points attend to the input set.
        summary = self.mha[0](query, feat, induced_lengths, lengths)
        # Step 2: the input set attends to the induced summary.
        return self.mha[1](feat, summary, lengths, induced_lengths)

    def extra_repr(self):
        """Return the extra representation shown when printing the module."""
        rows, cols = self.inducing_points.shape
        return 'InducedVector: ' + '({}, {})'.format(rows, cols)
class PMALayer(nn.Module):
    r"""Pooling by Multihead Attention from `Set Transformer: A Framework for Attention-based
    Permutation-Invariant Neural Networks <https://arxiv.org/abs/1810.00825>`__

    ``k`` learnable seed vectors attend to the (feed-forward transformed) input
    set, yielding ``k`` output vectors per sample.

    Parameters
    ----------
    k : int
        The number of seed vectors.
    d_model : int
        The feature size (input and output) in Multi-Head Attention layer.
    num_heads : int
        The number of heads.
    d_head : int
        The hidden size per head.
    d_ff : int
        The inner hidden size of the position-wise feed-forward networks.
    dropouth : float
        The dropout rate of each sublayer.
    dropouta : float
        The dropout rate of attention heads.

    Notes
    -----
    This module was used in SetTransformer layer.
    """
    def __init__(self, k, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
        super().__init__()
        self.k = k
        if k == 1:
            dgl_warning("if k is set to 1, the parameters corresponding to query and key "
                        "projections would not get updated during training.")
        self.d_model = d_model
        # Allocated uninitialized here; filled by reset_parameters() below.
        self.seed_vectors = nn.Parameter(th.FloatTensor(k, d_model))
        self.mha = MultiHeadAttention(d_model, num_heads, d_head, d_ff,
                                      dropouth=dropouth, dropouta=dropouta)
        layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropouth),
            nn.Linear(d_ff, d_model),
        ]
        self.ffn = nn.Sequential(*layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize the seed vectors (Xavier-uniform)."""
        nn.init.xavier_uniform_(self.seed_vectors)

    def forward(self, feat, lengths):
        """
        Apply Pooling by Multihead Attention.

        Parameters
        ----------
        feat : torch.Tensor
            The input feature.
        lengths : list
            The array of node numbers, used to segment feat tensor.

        Returns
        -------
        torch.Tensor
            The output feature
        """
        num_sets = len(lengths)
        # One copy of the seed vectors per sample in the batch.
        seeds = self.seed_vectors.repeat(num_sets, 1)
        memory = self.ffn(feat)
        return self.mha(seeds, memory, [self.k] * num_sets, lengths)

    def extra_repr(self):
        """Return the extra representation shown when printing the module."""
        rows, cols = self.seed_vectors.shape
        return 'SeedVector: ' + '({}, {})'.format(rows, cols)
class WeightAndSum(nn.Module):
    """Score every atom (node) with a learned sigmoid weight and return the
    weighted sum of atom features for each molecule (graph).

    Parameters
    ----------
    in_feats : int
        Input atom feature size

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import WeightAndSum
    >>>
    >>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
    >>> g1_node_feats = th.rand(3, 5)  # feature size is 5
    >>> g1_node_feats
    tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
            [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
            [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
    >>>
    >>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
    >>> g2_node_feats = th.rand(4, 5)  # feature size is 5
    >>> g2_node_feats
    tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
            [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
            [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
            [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
    >>>
    >>> weight_and_sum = WeightAndSum(5)  # create a weight and sum layer (in_feats=5)

    Case 1: Input a single graph

    >>> weight_and_sum(g1, g1_node_feats)
    tensor([[1.2194, 0.9490, 1.3235, 0.9609, 0.7710]],
           grad_fn=<SegmentReduceBackward>)

    Case 2: Input a batch of graphs

    Build a batch of DGL graphs and concatenate all graphs' node features into one tensor.

    >>> batch_g = dgl.batch([g1, g2])
    >>> batch_f = th.cat([g1_node_feats, g2_node_feats])
    >>>
    >>> weight_and_sum(batch_g, batch_f)
    tensor([[1.2194, 0.9490, 1.3235, 0.9609, 0.7710],
            [0.5322, 0.5840, 1.0729, 1.3665, 1.2360]],
           grad_fn=<SegmentReduceBackward>)

    Notes
    -----
    WeightAndSum module was commonly used in molecular property prediction networks,
    see the GCN predictor in `dgl-lifesci <https://github.com/awslabs/dgl-lifesci/blob/
    ae0491431804611ba466ff413f69d435789dbfd5/python/dgllife/model/model_zoo/
    gcn_predictor.py>`__
    to understand how to use WeightAndSum layer to get the graph readout output.
    """
    def __init__(self, in_feats):
        super().__init__()
        self.in_feats = in_feats
        # Linear projection to one scalar per atom, squashed into (0, 1).
        scorer = [nn.Linear(in_feats, 1), nn.Sigmoid()]
        self.atom_weighting = nn.Sequential(*scorer)

    def forward(self, g, feats):
        """Pool atom representations into molecule representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph with batch size B for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, self.in_feats)
            Representations for all atoms in the molecules
            * N is the total number of atoms in all molecules

        Returns
        -------
        FloatTensor of shape (B, self.in_feats)
            Representations for B molecules
        """
        # local_scope keeps the temporary 'h'/'w' fields off the caller's graph.
        with g.local_scope():
            g.ndata['h'] = feats
            atom_weights = self.atom_weighting(g.ndata['h'])
            g.ndata['w'] = atom_weights
            return sum_nodes(g, 'h', 'w')