"""Torch Module for EdgeConv Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn

from .... import function as fn
from ....utils import expand_as_pair


class EdgeConv(nn.Module):
    r"""EdgeConv layer.

    Introduced in "`Dynamic Graph CNN for Learning on Point Clouds
    <https://arxiv.org/pdf/1801.07829>`__". Can be described as follows:

    .. math::
       x_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} \mathrm{ReLU}(
       \Theta \cdot (x_j^{(l)} - x_i^{(l)}) + \Phi \cdot x_i^{(l)})

    where :math:`\mathcal{N}(i)` is the set of neighbors of :math:`i`.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    batch_norm : bool
        Whether to include batch normalization on messages.
    """
    def __init__(self, in_feat, out_feat, batch_norm=False):
        super(EdgeConv, self).__init__()
        self.batch_norm = batch_norm

        self.theta = nn.Linear(in_feat, out_feat)
        self.phi = nn.Linear(in_feat, out_feat)

        if batch_norm:
            self.bn = nn.BatchNorm1d(out_feat)

    def message(self, edges):
        """The message computation function."""
        # With edges pointing from each neighbor j (source) to the center
        # node i (destination), theta acts on the feature difference
        # x_j - x_i and phi on the center feature x_i, matching the
        # formula in the class docstring.
        theta_x = self.theta(edges.src['x'] - edges.dst['x'])
        phi_x = self.phi(edges.dst['x'])
        return {'e': theta_x + phi_x}
    def forward(self, g, h):
        """Forward computation.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        h : Tensor or pair of tensors
            :math:`(N, D)` where :math:`N` is the number of nodes and
            :math:`D` is the number of feature dimensions.

            If a pair of tensors is given, the graph must be a
            uni-bipartite graph with only one edge type, and the two
            tensors must have the same dimensionality on all except the
            first axis.

        Returns
        -------
        torch.Tensor
            New node features.
        """
        with g.local_scope():
            h_src, h_dst = expand_as_pair(h)
            g.srcdata['x'] = h_src
            g.dstdata['x'] = h_dst
            if not self.batch_norm:
                g.update_all(self.message, fn.max('e', 'x'))
            else:
                g.apply_edges(self.message)
                # Although the official implementation includes a per-edge
                # batch norm within EdgeConv, I choose to replace it with
                # a global batch norm for two reasons:
                #
                # (1) When the point clouds within a batch do not all have
                #     the same number of points, a per-edge batch norm
                #     would not work.
                #
                # (2) Even if the point clouds always have the same number
                #     of points, the points may well be shuffled within the
                #     same (type of) object (and the official implementation
                #     *does* shuffle the points of each example every epoch).
                #     For example, the first point of an airplane's point
                #     cloud does not necessarily reside at its nose. In this
                #     case, the per-position statistics learned by batch
                #     norm are not as meaningful as those learned from
                #     images.
                g.edata['e'] = self.bn(g.edata['e'])
                g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
            return g.dstdata['x']
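
# A minimal usage sketch (not part of the original module), assuming
# ``dgl.knn_graph`` is available in the installed DGL version: build a
# k-NN graph over a random point cloud and apply EdgeConv to it. The
# point count, feature sizes, and k below are arbitrary illustration
# values.
if __name__ == '__main__':
    import dgl
    import torch

    points = torch.rand(32, 3)     # a point cloud of 32 xyz points
    g = dgl.knn_graph(points, 4)   # connect each point to its 4 nearest neighbors
    conv = EdgeConv(3, 16, batch_norm=True)
    out = conv(g, points)          # new node features, shape (32, 16)
    print(out.shape)               # torch.Size([32, 16])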