# Source code for dgl.nn.tensorflow.glob

"""Tensorflow modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import tensorflow as tf
from tensorflow.keras import layers

from ...readout import sum_nodes, mean_nodes, max_nodes, \
    softmax_nodes, topk_nodes

__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling',
           'SortPooling', 'WeightAndSum', 'GlobalAttentionPooling']

class SumPooling(layers.Layer):
    r"""Apply sum pooling over the nodes in the graph.

    .. math::
        r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
    """

    def __init__(self):
        super(SumPooling, self).__init__()
    def call(self, graph, feat):
        r"""Compute sum pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
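
        Example
        -------
        A minimal usage sketch, added for illustration and not part of the
        original module; it assumes the TensorFlow backend is selected
        (e.g. ``DGLBACKEND=tensorflow``) and builds toy inputs with
        :func:`dgl.rand_graph`:

        >>> import dgl
        >>> import tensorflow as tf
        >>> g = dgl.batch([dgl.rand_graph(3, 2), dgl.rand_graph(4, 2)])
        >>> feat = tf.ones((g.num_nodes(), 5))
        >>> out = SumPooling()(g, feat)  # shape (2, 5): one row per graph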
"""
with graph.local_scope():
graph.ndata['h'] = feat

class AvgPooling(layers.Layer):
    r"""Apply average pooling over the nodes in the graph.

    .. math::
        r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
    """

    def __init__(self):
        super(AvgPooling, self).__init__()

    def call(self, graph, feat):
        r"""Compute average pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
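
        Example
        -------
        A minimal sketch, added for illustration (same assumptions as the
        :class:`SumPooling` example: TensorFlow backend active, toy graphs
        from :func:`dgl.rand_graph`):

        >>> import dgl
        >>> import tensorflow as tf
        >>> g = dgl.batch([dgl.rand_graph(3, 2), dgl.rand_graph(4, 2)])
        >>> feat = tf.ones((g.num_nodes(), 5))
        >>> out = AvgPooling()(g, feat)  # shape (2, 5): per-graph means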
"""
with graph.local_scope():
graph.ndata['h'] = feat

class MaxPooling(layers.Layer):
    r"""Apply max pooling over the nodes in the graph.

    .. math::
        r^{(i)} = \max_{k=1}^{N_i}\left( x^{(i)}_k \right)
    """

    def __init__(self):
        super(MaxPooling, self).__init__()

    def call(self, graph, feat):
        r"""Compute max pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
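
        Example
        -------
        A minimal sketch, added for illustration (same assumptions as the
        :class:`SumPooling` example):

        >>> import dgl
        >>> import tensorflow as tf
        >>> g = dgl.batch([dgl.rand_graph(3, 2), dgl.rand_graph(4, 2)])
        >>> feat = tf.random.uniform((g.num_nodes(), 5))
        >>> out = MaxPooling()(g, feat)  # shape (2, 5): per-graph maxima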
"""
with graph.local_scope():
graph.ndata['h'] = feat

class SortPooling(layers.Layer):
    r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph
    Classification <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__)
    over the nodes in the graph.

    Parameters
    ----------
    k : int
        The number of nodes to hold for each graph.
    """

    def __init__(self, k):
        super(SortPooling, self).__init__()
        self.k = k

    def call(self, graph, feat):
        r"""Compute sort pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, D)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, k * D)`, where
            :math:`B` refers to the batch size.
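
        Example
        -------
        A minimal sketch, added for illustration (TensorFlow backend active;
        the toy graphs below each have at least ``k`` nodes):

        >>> import dgl
        >>> import tensorflow as tf
        >>> g = dgl.batch([dgl.rand_graph(4, 3), dgl.rand_graph(5, 4)])
        >>> feat = tf.random.uniform((g.num_nodes(), 6))
        >>> out = SortPooling(k=2)(g, feat)  # shape (2, 2 * 6) = (2, 12)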
"""
with graph.local_scope():
# Sort the feature of each node in ascending order.
feat = tf.sort(feat, -1)
graph.ndata['h'] = feat
# Sort nodes according to their last features.
ret = tf.reshape(topk_nodes(graph, 'h', self.k, sortby=-1), (
-1, self.k * feat.shape[-1]))
return ret

class GlobalAttentionPooling(layers.Layer):
    r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
    <https://arxiv.org/abs/1511.05493.pdf>`__) over the nodes in the graph.

    .. math::
        r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
        \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)

    Parameters
    ----------
    gate_nn : tf.keras.layers.Layer
        A neural network that computes attention scores for each feature.
    feat_nn : tf.keras.layers.Layer, optional
        A neural network applied to each feature before combining them
        with attention scores.
    """

    def __init__(self, gate_nn, feat_nn=None):
        super(GlobalAttentionPooling, self).__init__()
        self.gate_nn = gate_nn
        self.feat_nn = feat_nn

    def call(self, graph, feat):
        r"""Compute global attention pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, D)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
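
        Example
        -------
        A minimal sketch, added for illustration (TensorFlow backend active;
        the ``Dense(1)`` gate network is an arbitrary illustrative choice,
        chosen only because ``gate_nn`` must emit one score per node):

        >>> import dgl
        >>> import tensorflow as tf
        >>> from tensorflow.keras import layers
        >>> g = dgl.batch([dgl.rand_graph(3, 2), dgl.rand_graph(4, 2)])
        >>> feat = tf.ones((g.num_nodes(), 5))
        >>> gap = GlobalAttentionPooling(gate_nn=layers.Dense(1))
        >>> out = gap(g, feat)  # shape (2, 5)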
"""
with graph.local_scope():
gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat

graph.ndata['gate'] = gate
gate = softmax_nodes(graph, 'gate')
graph.ndata.pop('gate')

graph.ndata['r'] = feat * gate
graph.ndata.pop('r')

class WeightAndSum(layers.Layer):
    """Compute importance weights for atoms and perform a weighted sum.

    Parameters
    ----------
    in_feats : int
        Input atom feature size
    """

    def __init__(self, in_feats):
        super(WeightAndSum, self).__init__()
        self.in_feats = in_feats
        # Sequential expects a list of layers.
        self.atom_weighting = tf.keras.Sequential([
            layers.Dense(1),
            layers.Activation(tf.nn.sigmoid)
        ])

    def call(self, g, feats):
        """Compute molecule representations out of atom representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph with batch size B for processing multiple molecules in parallel.
        feats : FloatTensor of shape (N, self.in_feats)
            Representations for all atoms in the molecules

            * N is the total number of atoms in all molecules

        Returns
        -------
        FloatTensor of shape (B, self.in_feats)
            Representations for B molecules
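
        Example
        -------
        A minimal sketch, added for illustration (TensorFlow backend active;
        the "molecules" below are random toy graphs rather than real
        chemistry data):

        >>> import dgl
        >>> import tensorflow as tf
        >>> g = dgl.batch([dgl.rand_graph(3, 2), dgl.rand_graph(4, 2)])
        >>> feats = tf.ones((g.num_nodes(), 8))
        >>> out = WeightAndSum(in_feats=8)(g, feats)  # shape (2, 8)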
"""
with g.local_scope():
g.ndata['h'] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(g, 'h', 'w')

return h_g_sum