""" QM9 dataset for graph property prediction (regression) """
import os
import numpy as np
from .dgl_dataset import DGLDataset
from .utils import download, extract_archive, _get_dgl_url
from ..convert import graph as dgl_graph
from .. import backend as F
class QM9EdgeDataset(DGLDataset):
    r"""QM9Edge dataset for graph property prediction (regression)

    This dataset consists of 130,831 molecules with 19 regression targets.
    Nodes correspond to atoms and edges correspond to bonds.

    This dataset differs from :class:`~dgl.data.QM9Dataset` in the following aspects:

    1. It includes the bonds of a molecule as the edges of the corresponding graph,
       while the edges in :class:`~dgl.data.QM9Dataset` are purely distance-based.
    2. It provides edge features, and node features in addition to the atoms'
       coordinates and atomic numbers.
    3. It provides 7 additional regression targets (``U0_atom`` through ``C``,
       i.e. the 13th to 19th rows of the table below).

    This class is built on a preprocessed version of the dataset; the
    preprocessing details are available
    `here <https://gist.github.com/hengruizhang98/a2da30213b2356fff18b25385c9d3cd2>`_.

    Reference:

    - `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_
    - `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_

    Statistics:

    - Number of graphs: 130,831.
    - Number of regression targets: 19.

    Node attributes:

    - pos: the 3D coordinates of each atom.
    - attr: the 11D atom features.

    Edge attributes:

    - edge_attr: the 4D bond features.

    Regression targets:

    .. csv-table::
       :header: "Keys", "Property", "Description", "Unit"

       "mu", ":math:`\mu`", "Dipole moment", ":math:`\textrm{D}`"
       "alpha", ":math:`\alpha`", "Isotropic polarizability", ":math:`{a_0}^3`"
       "homo", ":math:`\epsilon_{\textrm{HOMO}}`", "Highest occupied molecular orbital energy", ":math:`\textrm{eV}`"
       "lumo", ":math:`\epsilon_{\textrm{LUMO}}`", "Lowest unoccupied molecular orbital energy", ":math:`\textrm{eV}`"
       "gap", ":math:`\Delta \epsilon`", "Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}`", ":math:`\textrm{eV}`"
       "r2", ":math:`\langle R^2 \rangle`", "Electronic spatial extent", ":math:`{a_0}^2`"
       "zpve", ":math:`\textrm{ZPVE}`", "Zero point vibrational energy", ":math:`\textrm{eV}`"
       "U0", ":math:`U_0`", "Internal energy at 0K", ":math:`\textrm{eV}`"
       "U", ":math:`U`", "Internal energy at 298.15K", ":math:`\textrm{eV}`"
       "H", ":math:`H`", "Enthalpy at 298.15K", ":math:`\textrm{eV}`"
       "G", ":math:`G`", "Free energy at 298.15K", ":math:`\textrm{eV}`"
       "Cv", ":math:`c_{\textrm{v}}`", "Heat capacity at 298.15K", ":math:`\frac{\textrm{cal}}{\textrm{mol K}}`"
       "U0_atom", ":math:`U_0^{\textrm{ATOM}}`", "Atomization energy at 0K", ":math:`\textrm{eV}`"
       "U_atom", ":math:`U^{\textrm{ATOM}}`", "Atomization energy at 298.15K", ":math:`\textrm{eV}`"
       "H_atom", ":math:`H^{\textrm{ATOM}}`", "Atomization enthalpy at 298.15K", ":math:`\textrm{eV}`"
       "G_atom", ":math:`G^{\textrm{ATOM}}`", "Atomization free energy at 298.15K", ":math:`\textrm{eV}`"
       "A", ":math:`A`", "Rotational constant", ":math:`\textrm{GHz}`"
       "B", ":math:`B`", "Rotational constant", ":math:`\textrm{GHz}`"
       "C", ":math:`C`", "Rotational constant", ":math:`\textrm{GHz}`"

    Parameters
    ----------
    label_keys : list, optional
        Names of the regression properties to load, which should be a subset of
        the keys in the table above. If not provided, all 19 labels are loaded.
    raw_dir : str, optional
        Directory to download the raw data into, or that already contains it.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False.
    verbose : bool
        Whether to print out progress information. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_labels : int
        Number of labels for each graph, i.e. number of prediction tasks.

    Raises
    ------
    KeyError
        If ``label_keys`` contains a name that is not in the table above.
    UserWarning
        If the raw data is changed in the remote server by the author.

    Examples
    --------
    >>> data = QM9EdgeDataset(label_keys=['mu', 'alpha'])
    >>> data.num_labels
    2
    >>> # iterate over the dataset
    >>> for graph, labels in data:
    ...     print(graph)  # get information of each graph
    ...     print(labels)  # get labels of the corresponding graph
    ...     # your code here...
    >>>
    """

    # Names of all regression targets; their position is the column index
    # used to pick values out of each row of ``targets`` in ``__getitem__``.
    keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G',
            'Cv', 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C']
    # Target name -> column index. Built with a comprehension rather than a
    # class-body ``for`` loop so the loop variables do not leak into the class
    # namespace as stray ``i`` / ``key`` attributes.
    map_dict = {key: i for i, key in enumerate(keys)}

    def __init__(self,
                 label_keys=None,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True,
                 transform=None):
        if label_keys is None:
            # ``None`` is stored as-is because ``__getitem__`` indexes the
            # target row with it (see the note there).
            self.label_keys = None
            self.num_labels = len(self.keys)
        else:
            # Materialize once so any iterable of names is accepted, not only
            # a list (the original called len() on it after iterating).
            label_keys = list(label_keys)
            unknown = [key for key in label_keys if key not in self.map_dict]
            if unknown:
                raise KeyError(f'Unknown label key(s) {unknown}; '
                               f'valid keys are {self.keys}')
            self.label_keys = [self.map_dict[key] for key in label_keys]
            self.num_labels = len(self.label_keys)
        # Must be set before the base constructor, which triggers download().
        self._url = _get_dgl_url('dataset/qm9_edge.npz')
        super().__init__(name='qm9Edge',
                         raw_dir=raw_dir,
                         url=self._url,
                         force_reload=force_reload,
                         verbose=verbose,
                         transform=transform)

    @property
    def _npz_path(self):
        """Local path of the preprocessed ``qm9_edge.npz`` archive in ``raw_dir``."""
        return os.path.join(self.raw_dir, 'qm9_edge.npz')

    def download(self):
        """Download ``qm9_edge.npz`` into ``raw_dir`` unless it is already present."""
        file_path = self._npz_path
        if not os.path.exists(file_path):
            # ``download`` here is the module-level helper imported from
            # ``.utils``, not this method.
            download(self._url, path=file_path)

    def process(self):
        """Populate the in-memory arrays by loading the preprocessed archive."""
        self.load()

    def has_cache(self):
        """Return whether the ``qm9_edge.npz`` archive exists in ``raw_dir``."""
        return os.path.exists(self._npz_path)

    def save(self):
        """Write the in-memory arrays to ``qm9_edge.npz`` as a compressed archive."""
        # NOTE(review): this overwrites the downloaded archive in place. If the
        # base class calls save() right after process() (which loads these
        # same arrays from the same file), this is a costly re-compression of
        # unchanged data -- confirm it is intended.
        np.savez_compressed(self._npz_path,
                            n_node=self.n_node,
                            n_edge=self.n_edge,
                            node_attr=self.node_attr,
                            node_pos=self.node_pos,
                            edge_attr=self.edge_attr,
                            src=self.src,
                            dst=self.dst,
                            targets=self.targets)

    def load(self):
        """Load all arrays from ``qm9_edge.npz`` and precompute per-graph offsets."""
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code if the downloaded file is tampered
        # with; consider False if the archive only holds numeric arrays.
        # The ``with`` block closes the NpzFile's file handle; indexing an
        # NpzFile reads each array fully, so the arrays remain valid after.
        with np.load(self._npz_path, allow_pickle=True) as data_dict:
            self.n_node = data_dict['n_node']
            self.n_edge = data_dict['n_edge']
            self.node_attr = data_dict['node_attr']
            self.node_pos = data_dict['node_pos']
            self.edge_attr = data_dict['edge_attr']
            self.targets = data_dict['targets']
            self.src = data_dict['src']
            self.dst = data_dict['dst']
        # Prefix sums with a leading 0: graph ``i`` owns the rows
        # ``n_cumsum[i]:n_cumsum[i + 1]`` of the flat node arrays and
        # ``ne_cumsum[i]:ne_cumsum[i + 1]`` of the flat edge arrays.
        self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
        self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])

    def __getitem__(self, idx):
        r"""Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index. Negative indices count from the end, as for a list.

        Returns
        -------
        dgl.DGLGraph
            The graph contains:

            - ``ndata['pos']``: the coordinates of each atom
            - ``ndata['attr']``: the features of each atom
            - ``edata['edge_attr']``: the features of each bond
        Tensor
            Property values of molecular graphs

        Raises
        ------
        IndexError
            If ``idx`` is out of range.
        """
        num_graphs = len(self)
        # Normalize negative indices first: the offset arrays have
        # ``len(self) + 1`` entries, so slicing them with a raw negative index
        # selects the next graph's nodes/edges (or nothing, for -1) while
        # ``targets[idx]`` would pick a different graph's label.
        pos_idx = idx + num_graphs if idx < 0 else idx
        if not 0 <= pos_idx < num_graphs:
            raise IndexError(
                f'index {idx} is out of range for a dataset of {num_graphs} graphs')

        node_slice = slice(self.n_cumsum[pos_idx], self.n_cumsum[pos_idx + 1])
        edge_slice = slice(self.ne_cumsum[pos_idx], self.ne_cumsum[pos_idx + 1])
        float32 = F.data_type_dict['float32']

        g = dgl_graph((self.src[edge_slice], self.dst[edge_slice]))
        g.ndata['pos'] = F.tensor(self.node_pos[node_slice], dtype=float32)
        g.ndata['attr'] = F.tensor(self.node_attr[node_slice], dtype=float32)
        g.edata['edge_attr'] = F.tensor(self.edge_attr[edge_slice], dtype=float32)

        # NOTE(review): when ``label_keys`` is None, numpy indexing with None
        # inserts a new leading axis, so (assuming ``targets`` is 2-D) the label
        # has shape (1, 19) rather than (num_labels,) as with an explicit key
        # list. Kept for backward compatibility -- confirm whether callers rely
        # on it before changing.
        label = F.tensor(self.targets[pos_idx][self.label_keys], dtype=float32)

        if self._transform is not None:
            g = self._transform(g)
        return g, label

    def __len__(self):
        r"""Number of graphs in the dataset.

        Returns
        -------
        int
        """
        return self.n_node.shape[0]
# Shorter alias for QM9EdgeDataset; both names refer to the same class.
QM9Edge = QM9EdgeDataset