# Source code for dgl.nn.pytorch.conv.atomicconv

"""Torch Module for Atomic Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import torch as th
import torch.nn as nn

class RadialPooling(nn.Module):
    r"""Radial pooling from `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__

    We denote the distance between atom :math:`i` and :math:`j` by :math:`r_{ij}`.

    A radial pooling layer transforms distances with radial filters. For the radial
    filter indexed by :math:`k`, it projects edge distances with

    .. math::
        h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)

    If :math:`r_{ij} \leq c_k`,

    .. math::
        f_{ij}^{k} = 0.5 \left(\cos\left(\frac{\pi r_{ij}}{c_k}\right) + 1\right),

    else,

    .. math::
        f_{ij}^{k} = 0.

    Finally,

    .. math::
        e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}

    At :math:`r_{ij} = c_k` both branches of :math:`f` equal 0, so the cutoff
    is continuous.

    Parameters
    ----------
    interaction_cutoffs : float32 tensor of shape (K)
        :math:`c_k` in the equations above. Roughly they can be considered as learnable cutoffs
        and two atoms are considered as connected if the distance between them is smaller than
        the cutoffs. K for the number of radial filters.
    rbf_kernel_means : float32 tensor of shape (K)
        :math:`r_k` in the equations above. K for the number of radial filters.
    rbf_kernel_scaling : float32 tensor of shape (K)
        :math:`\gamma_k` in the equations above. K for the number of radial filters.

    Notes
    -----
    All three tensors are reshaped to ``(K, 1, 1)`` and registered as learnable
    ``nn.Parameter`` objects. The reshape lets them broadcast against edge
    distances of shape ``(E, 1)``, producing one ``(E, 1)`` slice per filter.
    """

    def __init__(
        self, interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling
    ):
        super(RadialPooling, self).__init__()

        # Reshape to (K, 1, 1) so the parameters broadcast over (E, 1) distances.
        self.interaction_cutoffs = nn.Parameter(
            interaction_cutoffs.reshape(-1, 1, 1), requires_grad=True
        )
        self.rbf_kernel_means = nn.Parameter(
            rbf_kernel_means.reshape(-1, 1, 1), requires_grad=True
        )
        self.rbf_kernel_scaling = nn.Parameter(
            rbf_kernel_scaling.reshape(-1, 1, 1), requires_grad=True
        )

    def forward(self, distances):
        """Apply the layer to transform edge distances.

        Parameters
        ----------
        distances : Float32 tensor of shape (E, 1)
            Distance between end nodes of edges. E for the number of edges.

        Returns
        -------
        Float32 tensor of shape (K, E, 1)
            Transformed edge distances. K for the number of radial filters.
        """
        # (K, 1, 1) parameters broadcast against (E, 1) distances -> (K, E, 1).
        scaled_euclidean_distance = (
            -self.rbf_kernel_scaling * (distances - self.rbf_kernel_means) ** 2
        )  # (K, E, 1)
        rbf_kernel_results = th.exp(scaled_euclidean_distance)  # (K, E, 1)

        cos_values = 0.5 * (
            th.cos(np.pi * distances / self.interaction_cutoffs) + 1
        )  # (K, E, 1)
        # Keep the smooth cosine decay inside the cutoff; zero beyond it.
        cutoff_values = th.where(
            distances <= self.interaction_cutoffs,
            cos_values,
            th.zeros_like(cos_values),
        )  # (K, E, 1)

        # Note that there appears to be an inconsistency between the paper and
        # DeepChem's implementation. In the paper, the scaled_euclidean_distance first
        # gets multiplied by cutoff_values, followed by exponentiation. Here we follow
        # the practice of DeepChem.
        return rbf_kernel_results * cutoff_values

def msg_func(edges):
    """Send messages along edges.

    Each message is the outer product of the source node's feature vector
    ``edges.src["hv"]`` (shape ``(E, T)``) and the edge's radial-pooling
    vector ``edges.data["he"]`` (shape ``(E, K)``), flattened per edge.

    Parameters
    ----------
    edges : EdgeBatch
        A batch of edges.

    Returns
    -------
    dict mapping 'm' to Float32 tensor of shape (E, K * T)
        Messages computed. E for the number of edges, K for the number of
        radial filters and T for the number of features to use
        (types of atomic number in the paper). The flattened layout is
        T-major: the index ``t * K + k`` holds ``hv[t] * he[k]``.
    """
    # (E, T) x (E, K) -> (E, T, K) per-edge outer product, then flatten.
    outer = th.einsum("ij,ik->ijk", edges.src["hv"], edges.data["he"])
    return {"m": outer.view(len(edges), -1)}

def reduce_func(nodes):
    """Collect messages and update node representations.

    Incoming messages in ``nodes.mailbox["m"]`` are summed over the mailbox
    dimension (dim 1), i.e. over each node's incoming edges.

    Parameters
    ----------
    nodes : NodeBatch
        A batch of nodes.

    Returns
    -------
    dict mapping 'hv_new' to Float32 tensor of shape (V, K * T)
        Updated node representations. V for the number of nodes, K for the number of
        radial filters and T for the number of features to use
        (types of atomic number in the paper).
    """
    return {"hv_new": nodes.mailbox["m"].sum(1)}

class AtomicConv(nn.Module):
    r"""Atomic Convolution Layer from `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__

    We denote the type of atom :math:`i` by :math:`z_i` and the distance between
    atom :math:`i` and :math:`j` by :math:`r_{ij}`.

    **Distance Transformation**

    An atomic convolution layer first transforms distances with radial filters and
    then performs a pooling operation.

    For the radial filter indexed by :math:`k`, it projects edge distances with

    .. math::
        h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)

    If :math:`r_{ij} \leq c_k`,

    .. math::
        f_{ij}^{k} = 0.5 \left(\cos\left(\frac{\pi r_{ij}}{c_k}\right) + 1\right),

    else,

    .. math::
        f_{ij}^{k} = 0.

    Finally,

    .. math::
        e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}

    **Aggregation**

    For each type :math:`t`, each atom collects distance information from all neighbor
    atoms of type :math:`t`:

    .. math::
        p_{i, t}^{k} = \sum_{j\in N(i)} e_{ij}^{k} * 1(z_j == t)

    Then concatenate the results for all RBF kernels and atom types.

    Parameters
    ----------
    interaction_cutoffs : float32 tensor of shape (K)
        :math:`c_k` in the equations above. Roughly they can be considered as learnable cutoffs
        and two atoms are considered as connected if the distance between them is smaller than
        the cutoffs. K for the number of radial filters.
    rbf_kernel_means : float32 tensor of shape (K)
        :math:`r_k` in the equations above. K for the number of radial filters.
    rbf_kernel_scaling : float32 tensor of shape (K)
        :math:`\gamma_k` in the equations above. K for the number of radial filters.
    features_to_use : None or float tensor of shape (T)
        In the original paper, these are atomic numbers to consider, representing the types
        of atoms. T for the number of types of atomic numbers. Defaults to None.

    Note
    ----

    * This convolution operation is designed for molecular graphs in Chemistry, but it might
      be possible to extend it to more general graphs.

    * There seems to be an inconsistency about the definition of :math:`e_{ij}^{k}` in the
      paper and the author's implementation. We follow the author's implementation. In the
      paper, :math:`e_{ij}^{k}` was defined as
      :math:`\exp(-\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})`.

    * :math:`\gamma_{k}`, :math:`r_k` and :math:`c_k` are all learnable.

    Example
    -------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import AtomicConv

    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> feat = th.ones(6, 1)
    >>> edist = th.ones(6, 1)
    >>> interaction_cutoffs = th.ones(3).float() * 2
    >>> rbf_kernel_means = th.ones(3).float()
    >>> rbf_kernel_scaling = th.ones(3).float()
    >>> conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling)
    >>> res = conv(g, feat, edist)
    >>> res
    tensor([[0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000],
            [1.0000, 1.0000, 1.0000],
            [0.5000, 0.5000, 0.5000],
            [0.0000, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
    """

    def __init__(
        self,
        interaction_cutoffs,
        rbf_kernel_means,
        rbf_kernel_scaling,
        features_to_use=None,
    ):
        super(AtomicConv, self).__init__()

        self.radial_pooling = RadialPooling(
            interaction_cutoffs=interaction_cutoffs,
            rbf_kernel_means=rbf_kernel_means,
            rbf_kernel_scaling=rbf_kernel_scaling,
        )
        if features_to_use is None:
            self.num_channels = 1
            self.features_to_use = None
        else:
            self.num_channels = len(features_to_use)
            # Stored as a frozen Parameter so it is saved with the module and moves
            # with it across devices, but is never updated by the optimizer.
            self.features_to_use = nn.Parameter(
                features_to_use, requires_grad=False
            )

    def forward(self, graph, feat, distances):
        """
        Description
        -----------
        Apply the atomic convolution layer.

        Parameters
        ----------
        graph : DGLGraph
            Topology based on which message passing is performed.
        feat : Float32 tensor of shape :math:`(V, 1)`
            Initial node features, which are atomic numbers in the paper.
            :math:`V` for the number of nodes.
        distances : Float32 tensor of shape :math:`(E, 1)`
            Distance between end nodes of edges. E for the number of edges.

        Returns
        -------
        Float32 tensor of shape :math:`(V, K * T)`
            Updated node representations. :math:`V` for the number of nodes, :math:`K` for the
            number of radial filters, and :math:`T` for the number of types of atomic numbers.
        """
        with graph.local_scope():
            # Match feat's dtype and device.
            radial_pooled_values = self.radial_pooling(distances).to(
                feat
            )  # (K, E, 1)
            if self.features_to_use is not None:
                # (V, 1) == (T,) broadcasts to a (V, T) one-hot type indicator.
                feat = (feat == self.features_to_use).to(feat)  # (V, T)
            graph.ndata["hv"] = feat
            graph.edata["he"] = radial_pooled_values.transpose(1, 0).squeeze(
                -1
            )  # (E, K)
            graph.update_all(msg_func, reduce_func)

            return graph.ndata["hv_new"].view(
                graph.number_of_nodes(), -1
            )  # (V, K * T)