PATTERNDataset¶
-
class
dgl.data.
PATTERNDataset
(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)[source]¶ Bases:
dgl.data.dgl_dataset.DGLBuiltinDataset
PATTERN dataset for graph pattern recognition task.
Each graph G contains 5 communities with sizes randomly selected between [5, 35]. The SBM of each community is p = 0.5, q = 0.35, and the node features on G are generated with a uniform random distribution with a vocabulary of size 3, i.e. {0, 1, 2}. Then randomly generate 100 patterns P composed of 20 nodes with intra-probability \(p_P\) = 0.5 and extra-probability \(q_P\) = 0.5 (i.e. 50% of nodes in P are connected to G). The node features for P are also generated as a random signal with values {0, 1, 2}. The graphs are of sizes 44-188 nodes. The output node labels have value 1 if the node belongs to P and value 0 if it is in G.
Reference https://arxiv.org/pdf/2003.00982.pdf
Statistics:
Train examples: 10,000
Valid examples: 2,000
Test examples: 2,000
Number of classes for each node: 2
- Parameters
mode (str) – Must be one of (‘train’, ‘valid’, ‘test’). Default: ‘train’
raw_dir (str) – Raw file directory to download/contains the input data directory. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: False
transform (callable, optional) – A transform that takes in a
DGLGraph
object and returns a transformed version. TheDGLGraph
object will be transformed before every access.
Examples
>>> from dgl.data import PATTERNDataset >>> data = PATTERNDataset(mode='train') >>> data.num_classes 2 >>> len(trainset) 10000 >>> data[0] Graph(num_nodes=108, num_edges=4884, ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int16)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})