# Source code for dgl.nn.pytorch.conv.pnaconv

"""Torch Module for Principal Neighbourhood Aggregation Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch
import torch.nn as nn

def aggregate_mean(h):
    """Mean aggregation.

    Parameters
    ----------
    h : torch.Tensor
        Incoming messages. The reduction runs over ``dim=1``, the same axis
        the other aggregators in this module reduce and the one
        ``PNAConvTower.reduce_func`` reads the degree from.

    Returns
    -------
    torch.Tensor
        ``h`` averaged over ``dim=1`` (that dimension is removed).
    """
    return torch.mean(h, dim=1)

def aggregate_max(h):
    """Max aggregation.

    Parameters
    ----------
    h : torch.Tensor
        Incoming messages. The reduction runs over ``dim=1``, the same axis
        the other aggregators in this module reduce.

    Returns
    -------
    torch.Tensor
        Element-wise maximum of ``h`` over ``dim=1``.
    """
    # With ``dim`` given, torch.max returns (values, indices); only the
    # values are an aggregated message.
    return torch.max(h, dim=1)[0]

def aggregate_min(h):
    """Min aggregation.

    Parameters
    ----------
    h : torch.Tensor
        Incoming messages. The reduction runs over ``dim=1``, the same axis
        the other aggregators in this module reduce.

    Returns
    -------
    torch.Tensor
        Element-wise minimum of ``h`` over ``dim=1``.
    """
    # With ``dim`` given, torch.min returns (values, indices); only the
    # values are an aggregated message.
    return torch.min(h, dim=1)[0]

def aggregate_sum(h):
    """Sum aggregation.

    Parameters
    ----------
    h : torch.Tensor
        Incoming messages. The reduction runs over ``dim=1``, the same axis
        the other aggregators in this module reduce.

    Returns
    -------
    torch.Tensor
        ``h`` summed over ``dim=1`` (that dimension is removed).
    """
    return torch.sum(h, dim=1)

def aggregate_std(h):
    """Standard deviation aggregation.

    Parameters
    ----------
    h : torch.Tensor
        Incoming messages; reduced over ``dim=1``.

    Returns
    -------
    torch.Tensor
        Square root of :func:`aggregate_var` of ``h``.
    """
    # The small constant (the same 1e-30 used by the moment aggregators)
    # keeps the square root away from exactly zero, where its gradient is
    # infinite.
    return torch.sqrt(aggregate_var(h) + 1e-30)


def aggregate_var(h):
    """Variance aggregation.

    Computed as ``E[h^2] - E[h]^2`` over ``dim=1``.

    Parameters
    ----------
    h : torch.Tensor
        Incoming messages; reduced over ``dim=1``.

    Returns
    -------
    torch.Tensor
        Non-negative (population) variance of ``h`` over ``dim=1``.
    """
    h_mean_squares = torch.mean(h * h, dim=1)
    h_mean = torch.mean(h, dim=1)
    # Rounding can push the difference slightly below zero; relu clamps it
    # so the result is never negative.
    return torch.relu(h_mean_squares - h_mean * h_mean)

def _aggregate_moment(h, n):
    """Signed n-th root of the n-th central moment over ``dim=1``.

    For each node this evaluates ``(E[(X - E[X])^n])^{1/n}``.

    Parameters
    ----------
    h : torch.Tensor
        Incoming messages; reduced over ``dim=1``.
    n : int
        Order of the moment.

    Returns
    -------
    torch.Tensor
        The rooted moment, carrying the sign of the un-rooted moment.
    """
    centred = h - torch.mean(h, dim=1, keepdim=True)
    moment = torch.mean(torch.pow(centred, n), dim=1)
    # The root is taken on the magnitude (shifted by a tiny constant) and the
    # sign is re-applied afterwards, so negative odd moments are supported.
    magnitude = torch.pow(torch.abs(moment) + 1e-30, 1. / n)
    return torch.sign(moment) * magnitude


def aggregate_moment_3(h):
    """Moment aggregation of order 3 (see ``_aggregate_moment``)."""
    return _aggregate_moment(h, 3)


def aggregate_moment_4(h):
    """Moment aggregation of order 4 (see ``_aggregate_moment``)."""
    return _aggregate_moment(h, 4)


def aggregate_moment_5(h):
    """Moment aggregation of order 5 (see ``_aggregate_moment``)."""
    return _aggregate_moment(h, 5)

def scale_identity(h, D=None, delta=None):
    """Identity scaling (no scaling operation).

    ``D`` and ``delta`` are accepted and ignored so that this scaler can be
    invoked with the same arguments as ``scale_amplification`` and
    ``scale_attenuation``; calling it with ``h`` alone keeps working.

    Parameters
    ----------
    h : torch.Tensor
        Aggregated messages.
    D : optional
        Unused; present only for signature parity with the other scalers.
    delta : optional
        Unused; present only for signature parity with the other scalers.

    Returns
    -------
    torch.Tensor
        ``h`` itself, unchanged.
    """
    return h

def scale_amplification(h, D, delta):
    """Amplification scaling: multiply ``h`` by ``log(D + 1) / delta``.

    Parameters
    ----------
    h : torch.Tensor
        Aggregated messages.
    D : int
        Degree used in the logarithm.
    delta : float
        Normalization constant the logarithm is divided by.

    Returns
    -------
    torch.Tensor
        The scaled messages.
    """
    factor = np.log(D + 1) / delta
    return h * factor

def scale_attenuation(h, D, delta):
    """Attenuation scaling: multiply ``h`` by ``delta / log(D + 1)``.

    Parameters
    ----------
    h : torch.Tensor
        Aggregated messages.
    D : int
        Degree used in the logarithm.
    delta : float
        Normalization constant divided by the logarithm.

    Returns
    -------
    torch.Tensor
        The scaled messages.
    """
    factor = delta / np.log(D + 1)
    return h * factor

# Registry mapping the aggregator names accepted by the PNA layers to the
# functions implementing them.
AGGREGATORS = {
    'mean': aggregate_mean,
    'sum': aggregate_sum,
    'max': aggregate_max,
    'min': aggregate_min,
    'std': aggregate_std,
    'var': aggregate_var,
    'moment3': aggregate_moment_3,
    'moment4': aggregate_moment_4,
    'moment5': aggregate_moment_5,
}
# Registry mapping the scaler names accepted by the PNA layers to the
# functions implementing them.
SCALERS = {
    'identity': scale_identity,
    'amplification': scale_amplification,
    'attenuation': scale_attenuation,
}

class PNAConvTower(nn.Module):
    """A single PNA tower in PNA layers.

    A tower builds one message per edge with the linear layer ``M``, reduces
    the messages of every node with each configured aggregator, rescales the
    result with each configured scaler, and maps the concatenation of the
    node's own features and all scaled aggregations through the linear
    layer ``U``.

    Parameters
    ----------
    in_size : int
        Input feature size handled by this tower.
    out_size : int
        Output feature size produced by this tower.
    aggregators : list of str
        Keys of ``AGGREGATORS`` to apply.
    scalers : list of str
        Keys of ``SCALERS`` to apply.
    delta : float
        Normalization constant forwarded to the non-identity scalers.
    dropout : float, optional
        Dropout ratio applied to the output. Default: 0.
    edge_feat_size : int, optional
        Edge feature size; 0 means edge features are not used. Default: 0.
    """

    def __init__(self, in_size, out_size, aggregators, scalers,
                 delta, dropout=0., edge_feat_size=0):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.aggregators = aggregators
        self.scalers = scalers
        self.delta = delta
        self.edge_feat_size = edge_feat_size

        # M consumes [source features, destination features, edge features].
        message_in = 2 * in_size + edge_feat_size
        # U consumes the node's own features plus one in_size-wide slice per
        # (scaler, aggregator) combination.
        update_in = (len(aggregators) * len(scalers) + 1) * in_size

        self.M = nn.Linear(message_in, in_size)
        self.U = nn.Linear(update_in, out_size)
        self.dropout = nn.Dropout(dropout)
        self.batchnorm = nn.BatchNorm1d(out_size)

    def reduce_func(self, nodes):
        """Reduce function for the PNA layer.

        Applies every aggregator to the mailbox messages, then every scaler
        to the concatenated aggregations, and concatenates the scaled copies.

        Returns
        -------
        dict
            ``{'h_neigh': tensor}`` holding the combined result.
        """
        msg = nodes.mailbox['msg']
        # The size of dim 1 of the mailbox is what the scalers use as degree.
        degree = msg.size(1)

        aggregated = torch.cat(
            [AGGREGATORS[name](msg) for name in self.aggregators], dim=1
        )

        scaled = []
        for name in self.scalers:
            if name == 'identity':
                # The identity scaler is bypassed rather than called.
                scaled.append(aggregated)
            else:
                scaled.append(SCALERS[name](aggregated, D=degree, delta=self.delta))
        return {'h_neigh': torch.cat(scaled, dim=1)}

    def message(self, edges):
        """Message function for the PNA layer.

        Concatenates source and destination node features (followed by the
        edge features when ``edge_feat_size > 0``) and passes them through
        ``M``.

        Returns
        -------
        dict
            ``{'msg': tensor}`` holding the per-edge messages.
        """
        parts = [edges.src['h'], edges.dst['h']]
        if self.edge_feat_size > 0:
            parts.append(edges.data['a'])
        return {'msg': self.M(torch.cat(parts, dim=-1))}

    def forward(self, graph, node_feat, edge_feat=None):
        """Compute the forward pass of a single tower in a PNA convolution layer.

        Parameters
        ----------
        graph : DGLGraph
            The (possibly batched) graph.
        node_feat : torch.Tensor
            Node features fed to this tower.
        edge_feat : torch.Tensor, optional
            Edge features; required when ``edge_feat_size > 0``.

        Returns
        -------
        torch.Tensor
            Output of ``U``, scaled by the graph-size normalization, then
            batch-normalized and passed through dropout.
        """
        # Graph normalization factor: every node of a graph with N nodes gets
        # sqrt(1 / N), created with node_feat's dtype and device.
        per_graph = [
            torch.ones(N, 1).to(node_feat) / N
            for N in graph.batch_num_nodes()
        ]
        snorm_n = torch.cat(per_graph, dim=0).sqrt()

        with graph.local_scope():
            graph.ndata['h'] = node_feat
            if self.edge_feat_size > 0:
                assert edge_feat is not None, "Edge features must be provided."
                graph.edata['a'] = edge_feat

            graph.update_all(self.message, self.reduce_func)
            combined = torch.cat([node_feat, graph.ndata['h_neigh']], dim=-1)
            h = self.U(combined) * snorm_n
            return self.dropout(self.batchnorm(h))

[docs]class PNAConv(nn.Module):
r"""Principal Neighbourhood Aggregation Layer from Principal Neighbourhood Aggregation
for Graph Nets <https://arxiv.org/abs/2004.05718>__

A PNA layer is composed of multiple PNA towers. Each tower takes as input a split of the
input features, and computes the message passing as below.

.. math::
h_i^{l+1} = U(h_i^l, \oplus_{(i,j)\in E}M(h_i^l, e_{i,j}, h_j^l))

where :math:`h_i` and :math:`e_{i,j}` are node features and edge features, respectively.
:math:`M` and :math:`U` are MLPs, taking the concatenation of input for computing
output features. :math:`\oplus` represents the combination of various aggregators
and scalers. Aggregators aggregate messages from neighbours and scalers scale the
aggregated messages in different ways. :math:`\oplus` concatenates the output features
of each combination.

The output of multiple towers are concatenated and fed into a linear mixing layer for the
final output.

Parameters
----------
in_size : int
Input feature size; i.e. the size of :math:`h_i^l`.
out_size : int
Output feature size; i.e. the size of :math:`h_i^{l+1}`.
aggregators : list of str
List of aggregation function names(each aggregator specifies a way to aggregate
messages from neighbours), selected from:

* mean: the mean of neighbour messages

* max: the maximum of neighbour messages

* min: the minimum of neighbour messages

* std: the standard deviation of neighbour messages

* var: the variance of neighbour messages

* sum: the sum of neighbour messages

* moment3, moment4, moment5: the normalized moments aggregation
:math:`(E[(X-E[X])^n])^{1/n}`
scalers: list of str
List of scaler function names, selected from:

* identity: no scaling

* amplification: multiply the aggregated message by :math:`\log(d+1)/\delta`,
where :math:`d` is the degree of the node.

* attenuation: multiply the aggregated message by :math:`\delta/\log(d+1)`
delta: float
The degree-related normalization factor computed over the training set, used by scalers
for normalization. It equals :math:`E[\log(d+1)]`, where :math:`d` is the degree of each
node in the training set.
dropout: float, optional
The dropout ratio. Default: 0.0.
num_towers: int, optional
The number of towers used. Default: 1. Note that in_size and out_size must be divisible
by num_towers.
edge_feat_size: int, optional
The edge feature size. Default: 0.
residual : bool, optional
The bool flag that determines whether to add a residual connection for the
output. Default: True. If in_size and out_size of the PNA conv layer are not
the same, this flag will be set as False forcibly.

Example
-------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import PNAConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat)
"""
def __init__(self, in_size, out_size, aggregators, scalers, delta,
             dropout=0., num_towers=1, edge_feat_size=0, residual=True):
    """Split the feature sizes across the towers and build the sub-modules."""
    super(PNAConv, self).__init__()

    self.in_size = in_size
    self.out_size = out_size
    assert in_size % num_towers == 0, 'in_size must be divisible by num_towers'
    assert out_size % num_towers == 0, 'out_size must be divisible by num_towers'
    # Every tower handles an equal share of the input and output features.
    self.tower_in_size = in_size // num_towers
    self.tower_out_size = out_size // num_towers
    self.edge_feat_size = edge_feat_size
    # The residual flag is forced off when input and output sizes differ.
    self.residual = residual if in_size == out_size else False

    towers = []
    for _ in range(num_towers):
        towers.append(
            PNAConvTower(
                self.tower_in_size, self.tower_out_size,
                aggregators, scalers, delta,
                dropout=dropout, edge_feat_size=edge_feat_size
            )
        )
    self.towers = nn.ModuleList(towers)

    # Linear mixing of the concatenated tower outputs, then LeakyReLU.
    self.mixing_layer = nn.Sequential(
        nn.Linear(out_size, out_size),
        nn.LeakyReLU()
    )

[docs]    def forward(self, graph, node_feat, edge_feat=None):
r"""
Description
-----------
Compute PNA layer.

Parameters
----------
graph : DGLGraph
The graph.
node_feat : torch.Tensor
The input feature of shape :math:`(N, h_n)`. :math:`N` is the number of
nodes, and :math:`h_n` must be the same as in_size.
edge_feat : torch.Tensor, optional
The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number of
edges, and :math:`h_e` must be the same as edge_feat_size.

Returns
-------
torch.Tensor
The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
should be the same as out_size.
"""
h_cat = torch.cat([
tower(
graph,
node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
edge_feat
)
for ti, tower in enumerate(self.towers)
], dim=1)
h_out = self.mixing_layer(h_cat)