"""QM9 dataset for graph property prediction (regression)."""
import os
import numpy as np
import scipy.sparse as sp
from .dgl_dataset import DGLDataset
from .utils import download, _get_dgl_url
from ..convert import graph as dgl_graph
from ..transform import to_bidirected
from .. import backend as F
class QM9Dataset(DGLDataset):
    r"""QM9 dataset for graph property prediction (regression)

    This dataset consists of 130,831 molecules with 12 regression targets.
    Each node is an atom. Two atoms are joined by an edge when the distance
    between them is no larger than ``cutoff`` (edges are distance based, they
    are not read from a bond table).

    Reference: `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
    `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_

    Statistics:

    - Number of graphs: 130,831
    - Number of regression targets: 12

    .. list-table::
       :header-rows: 1

       * - Keys
         - Property
         - Description
         - Unit
       * - mu
         - :math:`\mu`
         - Dipole moment
         - :math:`\textrm{D}`
       * - alpha
         - :math:`\alpha`
         - Isotropic polarizability
         - :math:`{a_0}^3`
       * - homo
         - :math:`\epsilon_{\textrm{HOMO}}`
         - Highest occupied molecular orbital energy
         - :math:`\textrm{eV}`
       * - lumo
         - :math:`\epsilon_{\textrm{LUMO}}`
         - Lowest unoccupied molecular orbital energy
         - :math:`\textrm{eV}`
       * - gap
         - :math:`\Delta \epsilon`
         - Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}`
         - :math:`\textrm{eV}`
       * - r2
         - :math:`\langle R^2 \rangle`
         - Electronic spatial extent
         - :math:`{a_0}^2`
       * - zpve
         - :math:`\textrm{ZPVE}`
         - Zero point vibrational energy
         - :math:`\textrm{eV}`
       * - U0
         - :math:`U_0`
         - Internal energy at 0K
         - :math:`\textrm{eV}`
       * - U
         - :math:`U`
         - Internal energy at 298.15K
         - :math:`\textrm{eV}`
       * - H
         - :math:`H`
         - Enthalpy at 298.15K
         - :math:`\textrm{eV}`
       * - G
         - :math:`G`
         - Free energy at 298.15K
         - :math:`\textrm{eV}`
       * - Cv
         - :math:`c_{\textrm{v}}`
         - Heat capacity at 298.15K
         - :math:`\frac{\textrm{cal}}{\textrm{mol K}}`

    Parameters
    ----------
    label_keys : list
        Names of the regression property, which should be a subset of the
        keys in the table above.
    cutoff : float
        Cutoff distance for interatomic interactions, i.e. two atoms are
        connected in the corresponding graph if the distance between them is
        no larger than this.
        Default: 5.0 Angstrom
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False.

    Attributes
    ----------
    num_labels : int
        Number of labels for each graph, i.e. number of prediction tasks

    Raises
    ------
    UserWarning
        If the raw data is changed in the remote server by the author.

    Examples
    --------
    >>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
    >>> data.num_labels
    2
    >>>
    >>> # iterate over the dataset
    >>> for g, label in data:
    ...     R = g.ndata['R']  # get coordinates of each atom
    ...     Z = g.ndata['Z']  # get atomic numbers of each atom
    ...     # your code here...
    >>>
    """

    def __init__(self,
                 label_keys,
                 cutoff=5.0,
                 raw_dir=None,
                 force_reload=False,
                 verbose=False):
        # These attributes must exist before the base-class constructor runs,
        # because process()/download() below read them.
        self.cutoff = cutoff
        self.label_keys = label_keys
        self._url = _get_dgl_url('dataset/qm9_eV.npz')

        super(QM9Dataset, self).__init__(name='qm9',
                                         url=self._url,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose)

    def process(self):
        """Load the raw ``qm9_eV.npz`` archive into NumPy arrays on ``self``.

        Sets ``N``, ``R``, ``Z``, ``label`` (one column per entry of
        ``label_keys``) and ``N_cumsum`` (offsets of each molecule's atoms
        inside the concatenated ``R``/``Z`` arrays).
        """
        npz_path = os.path.join(self.raw_dir, 'qm9_eV.npz')
        # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
        # which can execute arbitrary code if the downloaded archive has been
        # tampered with. It is kept for compatibility with the original
        # behavior; whether the archive actually needs it is not visible here.
        # The context manager closes the archive's file handle once the
        # arrays have been read (NpzFile members are loaded on access).
        with np.load(npz_path, allow_pickle=True) as data_dict:
            # data_dict['N'] contains the number of atoms in each molecule.
            # Atomic properties (Z and R) of all molecules are concatenated as
            # single tensors, so you need this value to select the correct
            # atoms for each molecule.
            self.N = data_dict['N']
            self.R = data_dict['R']
            self.Z = data_dict['Z']
            self.label = np.stack([data_dict[key] for key in self.label_keys],
                                  axis=1)
        # N_cumsum[i]:N_cumsum[i + 1] is the slice of R/Z owned by molecule i.
        self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])

    def download(self):
        """Download ``qm9_eV.npz`` into ``raw_dir`` unless it already exists."""
        file_path = os.path.join(self.raw_dir, 'qm9_eV.npz')
        if not os.path.exists(file_path):
            download(self._url, path=file_path)

    @property
    def num_labels(self):
        r"""
        Returns
        --------
        int
            Number of labels for each graph, i.e. number of prediction tasks.
        """
        return self.label.shape[1]

    def __getitem__(self, idx):
        r""" Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index. Negative values count from the end of the dataset.

        Returns
        -------
        dgl.DGLGraph
            The graph contains:

            - ``ndata['R']``: the coordinates of each atom
            - ``ndata['Z']``: the atomic number

        Tensor
            Property values of molecular graphs

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_graphs = len(self)
        requested = idx
        if idx < 0:
            # Not ``+=``: avoid mutating the caller's object in place.
            idx = idx + num_graphs
        if not 0 <= idx < num_graphs:
            raise IndexError(
                f'index {requested} is out of range for a dataset '
                f'with {num_graphs} graphs')

        start, stop = self.N_cumsum[idx], self.N_cumsum[idx + 1]
        label = F.tensor(self.label[idx], dtype=F.data_type_dict['float32'])
        n_atoms = int(self.N[idx])
        R = self.R[start:stop]

        # Pairwise distances between all atoms of this molecule.
        dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
        # Connect atoms within the cutoff; subtracting the identity is meant
        # to drop the self-pairs on the diagonal (distance 0).
        # np.bool_ replaces the np.bool alias that NumPy 1.24 removed.
        adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=np.bool_)
        adj = adj.tocoo()
        u, v = F.tensor(adj.row), F.tensor(adj.col)

        # Pass num_nodes explicitly: otherwise the node count is inferred from
        # the edge list and trailing atoms with no neighbor inside the cutoff
        # would be missing, making the ndata assignments below fail.
        g = dgl_graph((u, v), num_nodes=n_atoms)
        g = to_bidirected(g)
        g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32'])
        g.ndata['Z'] = F.tensor(self.Z[start:stop],
                                dtype=F.data_type_dict['int64'])
        return g, label

    def __len__(self):
        r"""Number of graphs in the dataset.

        Return
        -------
        int
        """
        return self.label.shape[0]
# Short alias: ``QM9`` refers to the same class object as ``QM9Dataset``.
QM9 = QM9Dataset