"""Torch Module for Principal Neighbourhood Aggregation Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch
import torch.nn as nn

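# Each aggregator below reduces a node's mailbox tensor of shape
# (num_nodes, in_degree, feat_size) over the neighbour dimension (dim=1),
# producing a (num_nodes, feat_size) result.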
def aggregate_mean(h):
    """mean aggregation"""
    return torch.mean(h, dim=1)

def aggregate_max(h):
    """max aggregation"""
    return torch.max(h, dim=1)[0]

def aggregate_min(h):
    """min aggregation"""
    return torch.min(h, dim=1)[0]

def aggregate_sum(h):
    """sum aggregation"""
    return torch.sum(h, dim=1)

def aggregate_std(h):
    """standard deviation aggregation"""
    return torch.sqrt(aggregate_var(h) + 1e-30)

def aggregate_var(h):
    """variance aggregation"""
    h_mean_squares = torch.mean(h * h, dim=1)
    h_mean = torch.mean(h, dim=1)
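    # clamp at zero: floating-point error can make E[h^2] - E[h]^2 slightly negative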
    var = torch.relu(h_mean_squares - h_mean * h_mean)
    return var

def _aggregate_moment(h, n):
    """moment aggregation: for each node (E[(X-E[X])^n])^{1/n}"""
    h_mean = torch.mean(h, dim=1, keepdim=True)
    h_n = torch.mean(torch.pow(h - h_mean, n), dim=1)
    rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + 1e-30, 1. / n)
    return rooted_h_n

def aggregate_moment_3(h):
    """moment aggregation with n=3"""
    return _aggregate_moment(h, n=3)

def aggregate_moment_4(h):
    """moment aggregation with n=4"""
    return _aggregate_moment(h, n=4)

def aggregate_moment_5(h):
    """moment aggregation with n=5"""
    return _aggregate_moment(h, n=5)

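# Scalers rescale the aggregated messages by the receiving node's in-degree D,
# normalized by ``delta`` (the average of log(d + 1) over training-set degrees);
# ``identity`` leaves the messages unchanged.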
def scale_identity(h):
    """identity scaling (no scaling operation)"""
    return h

def scale_amplification(h, D, delta):
    """amplification scaling"""
    return h * (np.log(D + 1) / delta)

def scale_attenuation(h, D, delta):
    """attenuation scaling"""
    return h * (delta / np.log(D + 1))

AGGREGATORS = {
    'mean': aggregate_mean, 'sum': aggregate_sum, 'max': aggregate_max, 'min': aggregate_min,
    'std': aggregate_std, 'var': aggregate_var, 'moment3': aggregate_moment_3,
    'moment4': aggregate_moment_4, 'moment5': aggregate_moment_5
}
SCALERS = {
    'identity': scale_identity,
    'amplification': scale_amplification,
    'attenuation': scale_attenuation
}
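# Illustrative sketch (shapes only, not executed at import time): an aggregator
# reduces the neighbour-message tensor over dim=1, and a scaler then rescales
# the result by degree, e.g.
#
#   >>> msg = torch.ones(4, 3, 8)   # 4 nodes, 3 neighbours each, 8 features
#   >>> AGGREGATORS['mean'](msg).shape
#   torch.Size([4, 8])
#   >>> SCALERS['amplification'](AGGREGATORS['mean'](msg), D=3, delta=2.5).shape
#   torch.Size([4, 8])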

class PNAConvTower(nn.Module):
    """A single PNA tower in PNA layers"""
    def __init__(self, in_size, out_size, aggregators, scalers,
                 delta, dropout=0., edge_feat_size=0):
        super(PNAConvTower, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.aggregators = aggregators
        self.scalers = scalers
        self.delta = delta
        self.edge_feat_size = edge_feat_size

        self.M = nn.Linear(2 * in_size + edge_feat_size, in_size)
        self.U = nn.Linear((len(aggregators) * len(scalers) + 1) * in_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.batchnorm = nn.BatchNorm1d(out_size)

    def reduce_func(self, nodes):
        """reduce function for PNA layer:
        tensordot of multiple aggregation and scaling operations"""
        msg = nodes.mailbox['msg']
        degree = msg.size(1)
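        # msg: (num_nodes, in_degree, in_size). After aggregation and scaling,
        # h has width len(aggregators) * len(scalers) * in_size, which together
        # with the node's own feature matches the input size of self.U.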
        h = torch.cat([AGGREGATORS[agg](msg) for agg in self.aggregators], dim=1)
        h = torch.cat([
            SCALERS[scaler](h, D=degree, delta=self.delta) if scaler != 'identity' else h
            for scaler in self.scalers
        ], dim=1)
        return {'h_neigh': h}

    def message(self, edges):
        """message function for PNA layer"""
        if self.edge_feat_size > 0:
            f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1)
        else:
            f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
        return {'msg': self.M(f)}

    def forward(self, graph, node_feat, edge_feat=None):
        """compute the forward pass of a single tower in PNA convolution layer"""
        # calculate graph normalization factors
        snorm_n = torch.cat(
            [torch.ones(N, 1).to(node_feat) / N for N in graph.batch_num_nodes()],
            dim=0
        ).sqrt()
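        # snorm_n[i] = 1 / sqrt(N_g), where N_g is the number of nodes of the
        # (batched) graph that node i belongs to; used to normalize the output
        # by graph size.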
        with graph.local_scope():
            graph.ndata['h'] = node_feat
            if self.edge_feat_size > 0:
                assert edge_feat is not None, "Edge features must be provided."
                graph.edata['a'] = edge_feat

            graph.update_all(self.message, self.reduce_func)
            h = self.U(
                torch.cat([node_feat, graph.ndata['h_neigh']], dim=-1)
            )
            h = h * snorm_n
            return self.dropout(self.batchnorm(h))

class PNAConv(nn.Module):
    r"""Principal Neighbourhood Aggregation Layer from `Principal Neighbourhood
    Aggregation for Graph Nets <https://arxiv.org/abs/2004.05718>`__

    A PNA layer is composed of multiple PNA towers. Each tower takes as input a
    split of the input features, and computes the message passing as below.

    .. math::
        h_i^{l+1} = U(h_i^l, \oplus_{(i,j)\in E}M(h_i^l, e_{i,j}, h_j^l))

    where :math:`h_i` and :math:`e_{i,j}` are node features and edge features,
    respectively. :math:`M` and :math:`U` are MLPs, taking the concatenation of
    input for computing output features. :math:`\oplus` represents the
    combination of various aggregators and scalers. Aggregators aggregate
    messages from neighbours and scalers scale the aggregated messages in
    different ways. :math:`\oplus` concatenates the output features of each
    combination.

    The outputs of multiple towers are concatenated and fed into a linear
    mixing layer for the final output.

    Parameters
    ----------
    in_size : int
        Input feature size; i.e. the size of :math:`h_i^l`.
    out_size : int
        Output feature size; i.e. the size of :math:`h_i^{l+1}`.
    aggregators : list of str
        List of aggregation function names (each aggregator specifies a way to
        aggregate messages from neighbours), selected from:

        * ``mean``: the mean of neighbour messages
        * ``max``: the maximum of neighbour messages
        * ``min``: the minimum of neighbour messages
        * ``std``: the standard deviation of neighbour messages
        * ``var``: the variance of neighbour messages
        * ``sum``: the sum of neighbour messages
        * ``moment3``, ``moment4``, ``moment5``: the normalized moments
          aggregation :math:`(E[(X-E[X])^n])^{1/n}`
    scalers : list of str
        List of scaler function names, selected from:

        * ``identity``: no scaling
        * ``amplification``: multiply the aggregated message by
          :math:`\log(d+1)/\delta`, where :math:`d` is the degree of the node.
        * ``attenuation``: multiply the aggregated message by
          :math:`\delta/\log(d+1)`
    delta : float
        The degree-related normalization factor computed over the training set,
        used by scalers for normalization: :math:`E[\log(d+1)]`, where
        :math:`d` is the degree for each node in the training set.
    dropout : float, optional
        The dropout ratio. Default: 0.0.
    num_towers : int, optional
        The number of towers used. Default: 1. Note that in_size and out_size
        must be divisible by num_towers.
    edge_feat_size : int, optional
        The edge feature size. Default: 0.
    residual : bool, optional
        The bool flag that determines whether to add a residual connection for
        the output. Default: True. If in_size and out_size of the PNA conv
        layer are not the same, this flag will be set to False forcibly.

    Example
    -------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import PNAConv
    >>>
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> feat = th.ones(6, 10)
    >>> conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)
    >>> ret = conv(g, feat)
    """
    def __init__(self, in_size, out_size, aggregators, scalers, delta,
                 dropout=0., num_towers=1, edge_feat_size=0, residual=True):
        super(PNAConv, self).__init__()

        self.in_size = in_size
        self.out_size = out_size
        assert in_size % num_towers == 0, 'in_size must be divisible by num_towers'
        assert out_size % num_towers == 0, 'out_size must be divisible by num_towers'
        self.tower_in_size = in_size // num_towers
        self.tower_out_size = out_size // num_towers
        self.edge_feat_size = edge_feat_size
        self.residual = residual
        if self.in_size != self.out_size:
            self.residual = False

        self.towers = nn.ModuleList([
            PNAConvTower(
                self.tower_in_size, self.tower_out_size,
                aggregators, scalers, delta,
                dropout=dropout, edge_feat_size=edge_feat_size
            ) for _ in range(num_towers)
        ])

        self.mixing_layer = nn.Sequential(
            nn.Linear(out_size, out_size),
            nn.LeakyReLU()
        )

    def forward(self, graph, node_feat, edge_feat=None):
        r"""
        Description
        -----------
        Compute PNA layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        node_feat : torch.Tensor
            The input feature of shape :math:`(N, h_n)`. :math:`N` is the
            number of nodes, and :math:`h_n` must be the same as in_size.
        edge_feat : torch.Tensor, optional
            The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number
            of edges, and :math:`h_e` must be the same as edge_feat_size.

        Returns
        -------
        torch.Tensor
            The output node feature of shape :math:`(N, h_n')` where
            :math:`h_n'` should be the same as out_size.
        """
        h_cat = torch.cat([
            tower(
                graph,
                node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
                edge_feat
            )
            for ti, tower in enumerate(self.towers)
        ], dim=1)
        h_out = self.mixing_layer(h_cat)
        # add residual connection
        if self.residual:
            h_out = h_out + node_feat

        return h_out
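
# A minimal, self-contained usage sketch (not part of the original module),
# assuming a working DGL installation: running this file directly exercises
# PNAConv with edge features and two towers.
if __name__ == "__main__":
    import dgl

    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    node_feat = torch.ones(6, 16)
    edge_feat = torch.ones(6, 4)
    conv = PNAConv(
        16, 16, ['mean', 'max', 'sum'], ['identity', 'amplification'],
        delta=2.5, num_towers=2, edge_feat_size=4, residual=True
    )
    out = conv(g, node_feat, edge_feat)
    print(out.shape)  # expected: torch.Size([6, 16])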