# Source code for dgl.data.pattern

""" PATTERNDataset for inductive learning. """
import os

from .dgl_dataset import DGLBuiltinDataset
from .utils import _get_dgl_url, load_graphs


class PATTERNDataset(DGLBuiltinDataset):
    r"""PATTERN dataset for graph pattern recognition task.

    Each graph G contains 5 communities with sizes randomly selected between
    [5, 35]. The SBM of each community is p = 0.5, q = 0.35, and the node
    features on G are generated with a uniform random distribution with a
    vocabulary of size 3, i.e. {0, 1, 2}. Then randomly generate 100 patterns
    P composed of 20 nodes with intra-probability :math:`p_P` = 0.5 and
    extra-probability :math:`q_P` = 0.5 (i.e. 50% of nodes in P are connected
    to G). The node features for P are also generated as a random signal with
    values {0, 1, 2}. The graphs are of sizes 44-188 nodes. The output node
    labels have value 1 if the node belongs to P and value 0 if it is in G.

    Reference `<https://arxiv.org/pdf/2003.00982.pdf>`_

    Statistics:

    - Train examples: 10,000
    - Valid examples: 2,000
    - Test examples: 2,000
    - Number of classes for each node: 2

    Parameters
    ----------
    mode : str
        Must be one of ('train', 'valid', 'test').
        Default: 'train'
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset.
        Default: False
    verbose : bool
        Whether to print out progress information.
        Default: False
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_classes : int
        Number of classes for each node.

    Examples
    --------
    >>> from dgl.data import PATTERNDataset
    >>> data = PATTERNDataset(mode='train')
    >>> data.num_classes
    2
    >>> len(data)
    10000
    >>> data[0]
    Graph(num_nodes=108, num_edges=4884,
          ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int16)}
          edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})
    """

    def __init__(
        self,
        mode="train",
        raw_dir=None,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        # Kept as an assertion so an invalid mode still surfaces as
        # AssertionError; the message names the offending value.
        assert mode in (
            "train",
            "valid",
            "test",
        ), "mode must be one of 'train', 'valid' or 'test', got {!r}".format(
            mode
        )
        # Set before the base initializer runs: ``graph_path`` reads it.
        self.mode = mode
        _url = _get_dgl_url("dataset/SBM_PATTERN.zip")
        super().__init__(
            name="pattern",
            url=_url,
            raw_dir=raw_dir,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        r"""Populate the dataset by delegating to :meth:`load`.

        No additional preprocessing is performed here.
        """
        self.load()

    @property
    def graph_path(self):
        r"""Path of the ``SBM_PATTERN_<mode>.bin`` file under ``save_path``."""
        return os.path.join(
            self.save_path, "SBM_PATTERN_{}.bin".format(self.mode)
        )

    def has_cache(self):
        r"""Return whether the file at :attr:`graph_path` exists."""
        return os.path.exists(self.graph_path)

    def load(self):
        r"""Read the graphs stored at :attr:`graph_path` into ``self._graphs``.

        The second value returned by ``load_graphs`` is discarded.
        """
        self._graphs, _ = load_graphs(self.graph_path)

    @property
    def num_classes(self):
        r"""Number of classes for each node."""
        return 2

    def __len__(self):
        r"""The number of examples in the dataset."""
        return len(self._graphs)

    def __getitem__(self, idx):
        r"""Get the idx^th sample.

        Parameters
        ----------
        idx : int
            The sample index.

        Returns
        -------
        :class:`dgl.DGLGraph`
            graph structure, node features, node labels and edge features.

            - ``ndata['feat']``: node features
            - ``ndata['label']``: node labels
            - ``edata['feat']``: edge features
        """
        graph = self._graphs[idx]
        # ``_transform`` is not assigned in this class; presumably the base
        # initializer stores the ``transform`` argument there.
        if self._transform is None:
            return graph
        return self._transform(graph)