# Source code for dgl.model_zoo.chem.mpnn

#!/usr/bin/env python
# coding: utf-8
# pylint: disable=C0103, C0111, E1101, W0612
"""Implementation of MPNN model."""
import torch.nn as nn
import torch.nn.functional as F

from ...nn.pytorch import Set2Set, NNConv

class MPNNModel(nn.Module):
    """
    MPNN from `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__

    Parameters
    ----------
    node_input_dim : int
        Dimension of input node feature, default to be 15.
    edge_input_dim : int
        Dimension of input edge feature, default to be 5.
    output_dim : int
        Dimension of prediction, default to be 12.
    node_hidden_dim : int
        Dimension of node feature in hidden layers, default to be 64.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers, default to be 128.
    num_step_message_passing : int
        Number of message passing steps, default to be 6.
    num_step_set2set : int
        Number of set2set steps, default to be 6.
    num_layer_set2set : int
        Number of set2set layers, default to be 3.
    """
    def __init__(self,
                 node_input_dim=15,
                 edge_input_dim=5,
                 output_dim=12,
                 node_hidden_dim=64,
                 edge_hidden_dim=128,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3):
        super(MPNNModel, self).__init__()

        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        # The edge network outputs node_hidden_dim * node_hidden_dim values
        # per edge, i.e. in_feats * out_feats for the NNConv below, which
        # uses them as an edge-conditioned weight matrix.
        edge_network = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim),
            nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        self.conv = NNConv(in_feats=node_hidden_dim,
                           out_feats=node_hidden_dim,
                           edge_func=edge_network,
                           aggregator_type='sum')
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

        self.set2set = Set2Set(node_hidden_dim, num_step_set2set,
                               num_layer_set2set)
        # Set2Set output is twice its input width, hence 2 * node_hidden_dim.
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)

    def forward(self, g, n_feat, e_feat):
        """Predict molecule labels

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph for molecule(s)
        n_feat : tensor of dtype float32 and shape (B1, D1)
            Node features. B1 for number of nodes and D1 for
            the node feature size.
        e_feat : tensor of dtype float32 and shape (B2, D2)
            Edge features. B2 for number of edges and D2 for
            the edge feature size.

        Returns
        -------
        res : tensor of shape (B, output_dim)
            Predicted labels. B is presumably the number of graphs in ``g``
            (one row per graph after the Set2Set readout) -- confirm against
            the Set2Set implementation.
        """
        out = F.relu(self.lin0(n_feat))  # (B1, H1)
        # Initial GRU hidden state (num_layers=1, B1, H1); the nodes act as
        # the GRU batch dimension.
        h = out.unsqueeze(0)  # (1, B1, H1)

        for _ in range(self.num_step_message_passing):
            m = F.relu(self.conv(g, out, e_feat))  # (B1, H1)
            # One GRU step per round: the input sequence has length 1.
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        out = self.set2set(g, out)  # (B, 2 * H1)
        out = F.relu(self.lin1(out))
        out = self.lin2(out)
        return out