"""A mini synthetic dataset for graph classification benchmark."""
import math, os
import networkx as nx
import numpy as np
from .dgl_dataset import DGLDataset
from .utils import save_graphs, load_graphs, makedirs
from .. import backend as F
from ..convert import from_networkx
from ..transforms import add_self_loop
__all__ = ['MiniGCDataset']
class MiniGCDataset(DGLDataset):
    """The synthetic graph classification dataset class.

    The dataset contains 8 different types of graphs.

    - class 0 : cycle graph
    - class 1 : star graph
    - class 2 : wheel graph
    - class 3 : lollipop graph
    - class 4 : hypercube graph
    - class 5 : grid graph
    - class 6 : clique graph
    - class 7 : circular ladder graph

    Parameters
    ----------
    num_graphs: int
        Number of graphs in this dataset.
    min_num_v: int
        Minimum number of nodes for graphs
    max_num_v: int
        Maximum number of nodes for graphs
    seed: int, default is 0
        Random seed for data generation
    save_graph: bool, default is True
        If True, :meth:`save` writes the generated graphs and labels to the
        cache file; if False, :meth:`save` does nothing.
    force_reload: bool, default is False
        Forwarded unchanged to :class:`DGLDataset`.
    verbose: bool, default is False
        Forwarded unchanged to :class:`DGLDataset`.
    transform: callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_graphs : int
        Number of graphs
    min_num_v : int
        The minimum number of nodes
    max_num_v : int
        The maximum number of nodes
    num_classes : int
        The number of classes

    Notes
    -----
    * The target node count of every graph is drawn with
      ``np.random.randint(min_num_v, max_num_v)``, i.e. from the half-open
      range ``[min_num_v, max_num_v)``. Some generators (hypercube, grid,
      circular ladder) round that target down to a size they can realise.
    * When ``seed`` is not None, generation calls ``np.random.seed(seed)``
      and therefore resets NumPy's *global* random state.
    * The lollipop and grid generators additionally draw from
      ``np.random.randint(2, num_v // 2)``, which raises ``ValueError`` when
      ``num_v < 6``. Use ``min_num_v >= 6`` to rule this out.
    * ``num_graphs // 8`` graphs are generated for each of the first seven
      classes; the last class (circular ladder) receives the remainder.

    Examples
    --------
    >>> data = MiniGCDataset(100, 16, 32, seed=0)

    The dataset instance is an iterable

    >>> len(data)
    100
    >>> g, label = data[64]
    >>> g
    Graph(num_nodes=20, num_edges=82,
          ndata_schemes={}
          edata_schemes={})
    >>> label
    tensor(5)

    Batch the graphs and labels for mini-batch training

    >>> graphs, labels = zip(*[data[i] for i in range(16)])
    >>> batched_graphs = dgl.batch(graphs)
    >>> batched_labels = torch.tensor(labels)
    >>> batched_graphs
    Graph(num_nodes=356, num_edges=1060,
          ndata_schemes={}
          edata_schemes={})
    """

    def __init__(self, num_graphs, min_num_v, max_num_v, seed=0,
                 save_graph=True, force_reload=False, verbose=False, transform=None):
        # These attributes must be set before the base-class constructor
        # runs: process()/save() read them.
        self.num_graphs = num_graphs
        self.min_num_v = min_num_v
        self.max_num_v = max_num_v
        self.seed = seed
        self.save_graph = save_graph

        super(MiniGCDataset, self).__init__(
            name="minigc",
            hash_key=(num_graphs, min_num_v, max_num_v, seed),
            force_reload=force_reload,
            verbose=verbose,
            transform=transform)

    def process(self):
        """Generate all graphs and labels from scratch using ``self.seed``."""
        self.graphs = []
        self.labels = []
        self._generate(self.seed)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Get the idx-th sample.

        Parameters
        ---------
        idx : int
            The sample index.

        Returns
        -------
        (:class:`dgl.Graph`, Tensor)
            The graph and its label.
        """
        g = self.graphs[idx]
        if self._transform is not None:
            # The transform is applied on every access, not cached.
            g = self._transform(g)
        return g, self.labels[idx]

    def _cache_file_path(self):
        """Return the path of the cache file holding the graphs and labels."""
        return os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))

    def has_cache(self):
        """Return True if the cache file for this configuration exists."""
        return os.path.exists(self._cache_file_path())

    def save(self):
        """save the graph list and the labels"""
        if self.save_graph:
            save_graphs(str(self._cache_file_path()), self.graphs,
                        {'labels': self.labels})

    def load(self):
        """Load the graph list and the labels from the cache file."""
        graphs, label_dict = load_graphs(self._cache_file_path())
        self.graphs = graphs
        self.labels = label_dict['labels']

    @property
    def num_classes(self):
        """Number of classes."""
        return 8

    def _generate(self, seed):
        """Populate ``self.graphs`` / ``self.labels`` with all eight classes.

        Parameters
        ----------
        seed : int or None
            If not None, NumPy's global RNG is seeded with this value first.
        """
        if seed is not None:
            np.random.seed(seed)

        # The order of these calls fixes the order of RNG draws and hence
        # the dataset produced for a given seed -- do not reorder.
        per_class = self.num_graphs // 8
        self._gen_cycle(per_class)
        self._gen_star(per_class)
        self._gen_wheel(per_class)
        self._gen_lollipop(per_class)
        self._gen_hypercube(per_class)
        self._gen_grid(per_class)
        self._gen_clique(per_class)
        # The last class absorbs the remainder so the total is num_graphs.
        self._gen_circular_ladder(self.num_graphs - len(self.graphs))

        # preprocess: convert to DGLGraph, and add self loops
        self.graphs = [add_self_loop(from_networkx(g)) for g in self.graphs]
        # np.int was removed in NumPy 1.24; use an explicit 64-bit dtype.
        self.labels = F.tensor(np.array(self.labels).astype(np.int64))

    def _sample_num_v(self):
        """Draw a target node count from ``[min_num_v, max_num_v)``."""
        return np.random.randint(self.min_num_v, self.max_num_v)

    def _gen_cycle(self, n):
        """Append ``n`` cycle graphs with label 0."""
        for _ in range(n):
            num_v = self._sample_num_v()
            g = nx.cycle_graph(num_v)
            self.graphs.append(g)
            self.labels.append(0)

    def _gen_star(self, n):
        """Append ``n`` star graphs with label 1."""
        for _ in range(n):
            num_v = self._sample_num_v()
            # nx.star_graph(N) gives a star graph with N+1 nodes
            g = nx.star_graph(num_v - 1)
            self.graphs.append(g)
            self.labels.append(1)

    def _gen_wheel(self, n):
        """Append ``n`` wheel graphs with label 2."""
        for _ in range(n):
            num_v = self._sample_num_v()
            g = nx.wheel_graph(num_v)
            self.graphs.append(g)
            self.labels.append(2)

    def _gen_lollipop(self, n):
        """Append ``n`` lollipop graphs with label 3.

        Raises ``ValueError`` (from ``np.random.randint``) when the sampled
        ``num_v`` is below 6, because the path-length range is then empty.
        """
        for _ in range(n):
            num_v = self._sample_num_v()
            path_len = np.random.randint(2, num_v // 2)
            g = nx.lollipop_graph(m=num_v - path_len, n=path_len)
            self.graphs.append(g)
            self.labels.append(3)

    def _gen_hypercube(self, n):
        """Append ``n`` hypercube graphs with label 4."""
        for _ in range(n):
            num_v = self._sample_num_v()
            # Largest dimension whose hypercube has at most num_v nodes.
            # math.log2 avoids the rounding error math.log(x, 2) can have.
            g = nx.hypercube_graph(int(math.log2(num_v)))
            g = nx.convert_node_labels_to_integers(g)
            self.graphs.append(g)
            self.labels.append(4)

    def _gen_grid(self, n):
        """Append ``n`` grid graphs with label 5.

        Raises ``AssertionError`` when the sampled ``num_v`` is below 4 and
        ``ValueError`` (from ``np.random.randint``) when it is 4 or 5,
        because the row-count range is then empty.
        """
        for _ in range(n):
            num_v = self._sample_num_v()
            assert num_v >= 4, 'We require a grid graph to contain at least two ' \
                               'rows and two columns, thus 4 nodes, got {:d} ' \
                               'nodes'.format(num_v)
            n_rows = np.random.randint(2, num_v // 2)
            n_cols = num_v // n_rows
            g = nx.grid_graph([n_rows, n_cols])
            g = nx.convert_node_labels_to_integers(g)
            self.graphs.append(g)
            self.labels.append(5)

    def _gen_clique(self, n):
        """Append ``n`` complete graphs with label 6."""
        for _ in range(n):
            num_v = self._sample_num_v()
            g = nx.complete_graph(num_v)
            self.graphs.append(g)
            self.labels.append(6)

    def _gen_circular_ladder(self, n):
        """Append ``n`` circular ladder graphs with label 7."""
        for _ in range(n):
            num_v = self._sample_num_v()
            g = nx.circular_ladder_graph(num_v // 2)
            self.graphs.append(g)
            self.labels.append(7)