Source code for dgl.model_zoo.chem.jtnn.jtnn_vae

# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

import rdkit.Chem as Chem

from dgl import batch, unbatch
from dgl.data.utils import get_download_dir

from .chemutils import (attach_mols_nx, copy_edit_mol, decode_stereo,
                        enum_assemble_nx, set_atommap)
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mol_tree import Vocab
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .nnutils import cuda, move_dgl_to_cuda


[docs]class DGLJTNNVAE(nn.Module): """ `Junction Tree Variational Autoencoder for Molecular Graph Generation <https://arxiv.org/abs/1802.04364>`__ """ def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None): super(DGLJTNNVAE, self).__init__() if vocab is None: if vocab_file is None: vocab_file = '{}/jtnn/{}.txt'.format( get_download_dir(), 'vocab') self.vocab = Vocab([x.strip("\r\n ") for x in open(vocab_file)]) else: self.vocab = vocab self.hidden_size = hidden_size self.latent_size = latent_size self.depth = depth self.embedding = nn.Embedding(self.vocab.size(), hidden_size) self.mpn = DGLMPN(hidden_size, depth) self.jtnn = DGLJTNNEncoder(self.vocab, hidden_size, self.embedding) self.decoder = DGLJTNNDecoder( self.vocab, hidden_size, latent_size // 2, self.embedding) self.jtmpn = DGLJTMPN(hidden_size, depth) self.T_mean = nn.Linear(hidden_size, latent_size // 2) self.T_var = nn.Linear(hidden_size, latent_size // 2) self.G_mean = nn.Linear(hidden_size, latent_size // 2) self.G_var = nn.Linear(hidden_size, latent_size // 2) self.n_nodes_total = 0 self.n_passes = 0 self.n_edges_total = 0 self.n_tree_nodes_total = 0 @staticmethod def move_to_cuda(mol_batch): for t in mol_batch['mol_trees']: move_dgl_to_cuda(t) move_dgl_to_cuda(mol_batch['mol_graph_batch']) if 'cand_graph_batch' in mol_batch: move_dgl_to_cuda(mol_batch['cand_graph_batch']) if mol_batch.get('stereo_cand_graph_batch') is not None: move_dgl_to_cuda(mol_batch['stereo_cand_graph_batch']) def encode(self, mol_batch): mol_graphs = mol_batch['mol_graph_batch'] mol_vec = self.mpn(mol_graphs) mol_tree_batch, tree_vec = self.jtnn(mol_batch['mol_trees']) self.n_nodes_total += mol_graphs.number_of_nodes() self.n_edges_total += mol_graphs.number_of_edges() self.n_tree_nodes_total += sum(t.number_of_nodes() for t in mol_batch['mol_trees']) self.n_passes += 1 return mol_tree_batch, tree_vec, mol_vec def sample(self, tree_vec, mol_vec, e1=None, e2=None): tree_mean = self.T_mean(tree_vec) tree_log_var = -torch.abs(self.T_var(tree_vec)) mol_mean = self.G_mean(mol_vec) mol_log_var = -torch.abs(self.G_var(mol_vec)) epsilon = cuda(torch.randn(*tree_mean.shape)) if e1 is None else e1 tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon epsilon = cuda(torch.randn(*mol_mean.shape)) if e2 is None else e2 mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon z_mean = torch.cat([tree_mean, mol_mean], 1) z_log_var = torch.cat([tree_log_var, mol_log_var], 1) return tree_vec, mol_vec, z_mean, z_log_var
[docs] def forward(self, mol_batch, beta=0, e1=None, e2=None): self.move_to_cuda(mol_batch) mol_trees = mol_batch['mol_trees'] batch_size = len(mol_trees) mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch) tree_vec, mol_vec, z_mean, z_log_var = self.sample( tree_vec, mol_vec, e1, e2) kl_loss = -0.5 * torch.sum( 1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size word_loss, topo_loss, word_acc, topo_acc = self.decoder( mol_trees, tree_vec) assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec) stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec) loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
def assm(self, mol_batch, mol_tree_batch, mol_vec): cands = [mol_batch['cand_graph_batch'], mol_batch['tree_mess_src_e'], mol_batch['tree_mess_tgt_e'], mol_batch['tree_mess_tgt_n']] cand_vec = self.jtmpn(cands, mol_tree_batch) cand_vec = self.G_mean(cand_vec) batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx'])) mol_vec = mol_vec[batch_idx] mol_vec = mol_vec.view(-1, 1, self.latent_size // 2) cand_vec = cand_vec.view(-1, self.latent_size // 2, 1) scores = (mol_vec @ cand_vec)[:, 0, 0] cnt, tot, acc = 0, 0, 0 all_loss = [] for i, mol_tree in enumerate(mol_batch['mol_trees']): comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items() if len(node['cands']) > 1 and not node['is_leaf']] cnt += len(comp_nodes) # segmented accuracy and cross entropy for node_id in comp_nodes: node = mol_tree.nodes_dict[node_id] label = node['cands'].index(node['label']) ncand = len(node['cands']) cur_score = scores[tot:tot + ncand] tot += ncand if cur_score[label].item() >= cur_score.max().item(): acc += 1 label = cuda(torch.LongTensor([label])) all_loss.append( F.cross_entropy(cur_score.view(1, -1), label, size_average=False)) all_loss = sum(all_loss) / len(mol_batch['mol_trees']) return all_loss, acc / cnt def stereo(self, mol_batch, mol_vec): stereo_cands = mol_batch['stereo_cand_graph_batch'] batch_idx = mol_batch['stereo_cand_batch_idx'] labels = mol_batch['stereo_cand_labels'] lengths = mol_batch['stereo_cand_lengths'] if len(labels) == 0: # Only one stereoisomer exists; do nothing return cuda(torch.tensor(0.)), 1. batch_idx = cuda(torch.LongTensor(batch_idx)) stereo_cands = self.mpn(stereo_cands) stereo_cands = self.G_mean(stereo_cands) stereo_labels = mol_vec[batch_idx] scores = F.cosine_similarity(stereo_cands, stereo_labels) st, acc = 0, 0 all_loss = [] for label, le in zip(labels, lengths): cur_scores = scores[st:st + le] if cur_scores.data[label].item() >= cur_scores.max().item(): acc += 1 label = cuda(torch.LongTensor([label])) all_loss.append( F.cross_entropy(cur_scores.view(1, -1), label, size_average=False)) st += le all_loss = sum(all_loss) / len(labels) return all_loss, acc / len(labels) def decode(self, tree_vec, mol_vec): mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec) effective_nodes_list = effective_nodes.tolist() nodes_dict = [nodes_dict[v] for v in effective_nodes_list] for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)): node['idx'] = i node['nid'] = i + 1 node['is_leaf'] = True if mol_tree.in_degree(node_id) > 1: node['is_leaf'] = False set_atommap(node['mol'], node['nid']) mol_tree_sg = mol_tree.subgraph(effective_nodes) mol_tree_sg.copy_from_parent() mol_tree_msg, _ = self.jtnn([mol_tree_sg]) mol_tree_msg = unbatch(mol_tree_msg)[0] mol_tree_msg.nodes_dict = nodes_dict cur_mol = copy_edit_mol(nodes_dict[0]['mol']) global_amap = [{}] + [{} for node in nodes_dict] global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()} cur_mol = self.dfs_assemble( mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None) if cur_mol is None: return None cur_mol = cur_mol.GetMol() set_atommap(cur_mol) cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol)) if cur_mol is None: return None smiles2D = Chem.MolToSmiles(cur_mol) stereo_cands = decode_stereo(smiles2D) if len(stereo_cands) == 1: return stereo_cands[0] stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands] stereo_cand_graphs, atom_x, bond_x = \ zip(*stereo_graphs) stereo_cand_graphs = batch(stereo_cand_graphs) atom_x = cuda(torch.cat(atom_x)) bond_x = cuda(torch.cat(bond_x)) stereo_cand_graphs.ndata['x'] = atom_x stereo_cand_graphs.edata['x'] = bond_x stereo_cand_graphs.edata['src_x'] = atom_x.new( bond_x.shape[0], atom_x.shape[1]).zero_() stereo_vecs = self.mpn(stereo_cand_graphs) stereo_vecs = self.G_mean(stereo_vecs) scores = F.cosine_similarity(stereo_vecs, mol_vec) _, max_id = scores.max(0) return stereo_cands[max_id.item()] def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id): nodes_dict = mol_tree_msg.nodes_dict fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None cur_node = nodes_dict[cur_node_id] fa_nid = fa_node['nid'] if fa_node is not None else -1 prev_nodes = [fa_node] if fa_node is not None else [] children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist() if nodes_dict[v]['nid'] != fa_nid] children = [nodes_dict[v] for v in children_node_id] neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1] neighbors = sorted( neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True) singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1] neighbors = singletons + neighbors cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']] cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap) if len(cands) == 0: return None cand_smiles, cand_mols, cand_amap = list(zip(*cands)) cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols] cand_graphs, atom_x, bond_x, tree_mess_src_edges, \ tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec( cands) cand_graphs = batch(cand_graphs) atom_x = cuda(atom_x) bond_x = cuda(bond_x) cand_graphs.ndata['x'] = atom_x cand_graphs.edata['x'] = bond_x cand_graphs.edata['src_x'] = atom_x.new( bond_x.shape[0], atom_x.shape[1]).zero_() cand_vecs = self.jtmpn( (cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes), mol_tree_msg, ) cand_vecs = self.G_mean(cand_vecs) mol_vec = mol_vec.squeeze() scores = cand_vecs @ mol_vec _, cand_idx = torch.sort(scores, descending=True) backup_mol = Chem.RWMol(cur_mol) for i in range(len(cand_idx)): cur_mol = Chem.RWMol(backup_mol) pred_amap = cand_amap[cand_idx[i].item()] new_global_amap = copy.deepcopy(global_amap) for nei_id, ctr_atom, nei_atom in pred_amap: if nei_id == fa_nid: continue new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom] cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap) new_mol = cur_mol.GetMol() new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol)) if new_mol is None: continue result = True for nei_node_id, nei_node in zip(children_node_id, children): if nei_node['is_leaf']: continue cur_mol = self.dfs_assemble( mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap, nei_node_id, cur_node_id) if cur_mol is None: result = False break if result: return cur_mol return None