import os
from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs
class ZINCDataset(DGLBuiltinDataset):
    r"""ZINC dataset for the graph regression task.

    A subset (12K) of ZINC molecular graphs (250K) dataset is used to
    regress a molecular property known as the constrained solubility.
    For each molecular graph, the node features are the types of heavy
    atoms, between which the edge features are the types of bonds.
    Each graph contains 9-37 nodes and 16-84 edges.

    Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_

    Statistics:

        Train examples: 10,000
        Valid examples: 1,000
        Test examples: 1,000
        Average number of nodes: 23.16
        Average number of edges: 39.83
        Number of atom types: 28
        Number of bond types: 4

    Parameters
    ----------
    mode : str, optional
        Should be chosen from ["train", "valid", "test"]
        Default: "train".
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: "~/.dgl/".
    force_reload : bool
        Whether to reload the dataset.
        Default: False.
    verbose : bool
        Whether to print out progress information.
        Default: False.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_atom_types : int
        Number of atom types.
    num_bond_types : int
        Number of bond types.

    Examples
    --------
    >>> from dgl.data import ZINCDataset

    >>> training_set = ZINCDataset(mode="train")
    >>> training_set.num_atom_types
    28
    >>> len(training_set)
    10000
    >>> graph, label = training_set[0]
    >>> graph
    Graph(num_nodes=29, num_edges=64,
          ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}
          edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)})
    """

    def __init__(
        self,
        mode="train",
        raw_dir=None,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        self._url = _get_dgl_url("dataset/ZINC12k.zip")
        # ``mode`` is assigned before the base-class constructor runs because
        # ``graph_path`` (and therefore ``has_cache``/``load``) reads it.
        # NOTE(review): presumably the base constructor triggers the
        # download/load pipeline that calls those hooks -- confirm against
        # DGLBuiltinDataset.
        self.mode = mode
        super(ZINCDataset, self).__init__(
            name="zinc",
            url=self._url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Populate the dataset by calling :meth:`load`.

        No additional processing is performed here; the graphs and labels
        are read directly from the file at :attr:`graph_path`.
        """
        self.load()

    @property
    def graph_path(self):
        """str: Path of the ``ZincDGL_<mode>.bin`` file under ``save_path``."""
        return os.path.join(self.save_path, "ZincDGL_{}.bin".format(self.mode))

    def has_cache(self):
        """Return whether the graph file for the current mode exists on disk."""
        return os.path.exists(self.graph_path)

    def load(self):
        """Read graphs and labels from :attr:`graph_path`.

        The two values returned by ``load_graphs`` are stored in
        ``self._graphs`` and ``self._labels`` respectively.
        """
        self._graphs, self._labels = load_graphs(self.graph_path)

    @property
    def num_atom_types(self):
        """int: Number of atom types (always 28)."""
        return 28

    @property
    def num_bond_types(self):
        """int: Number of bond types (always 4)."""
        return 4

    def __len__(self):
        """Return the number of graphs held by this dataset split."""
        return len(self._graphs)

    def __getitem__(self, idx):
        r"""Get one example by index.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        dgl.DGLGraph
            Each graph contains:

            - ``ndata['feat']``: Types of heavy atoms as node features
            - ``edata['feat']``: Types of bonds as edge features
        Tensor
            Constrained solubility as graph label
        """
        # Labels are stored under the "g_label" key of the loaded label dict.
        labels = self._labels["g_label"]
        if self._transform is None:
            return self._graphs[idx], labels[idx]
        else:
            # The transform is applied on every access; the stored graph
            # object itself is not replaced.
            return self._transform(self._graphs[idx]), labels[idx]