CLUSTERDataset

class dgl.data.CLUSTERDataset(mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None)

Bases: dgl.data.dgl_dataset.DGLBuiltinDataset

CLUSTER dataset for semi-supervised clustering task.

Each graph contains 6 SBM (stochastic block model) clusters whose sizes are drawn at random from [5, 35], with intra-cluster edge probability p = 0.55 and inter-cluster edge probability q = 0.25. Graphs have between 40 and 190 nodes. Each node carries an input feature value in {0, 1, 2, …, 6}: values 1 to 6 correspond to classes 0 to 5 respectively, while value 0 means that the node's class is unknown. In each community exactly one randomly chosen node is labeled, so most node features are 0.
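To make the construction above concrete, here is a minimal pure-Python sketch that samples one CLUSTER-style graph. It is an illustration under stated assumptions, not the script used to build the released dataset: the function name make_cluster_graph is hypothetical, nodes are numbered community by community, edges are returned as undirected pairs (a DGLGraph would typically store both directions), and the sketch does not enforce the 40 to 190 node range.

```python
import random


def make_cluster_graph(num_clusters=6, size_range=(5, 35), p=0.55, q=0.25, seed=None):
    """Hypothetical sketch: sample a toy CLUSTER-style SBM graph.

    Returns (num_nodes, edges, feat, label), where edges holds undirected
    (u, v) pairs with u < v, label[i] is the community of node i, and
    feat[i] is label[i] + 1 for the single seeded node of each community
    and 0 for every other node.
    """
    rng = random.Random(seed)
    # Draw each community size uniformly from the inclusive range.
    sizes = [rng.randint(*size_range) for _ in range(num_clusters)]
    # Assumption: nodes are numbered community by community.
    label = [c for c, size in enumerate(sizes) for _ in range(size)]
    num_nodes = len(label)

    # Stochastic block model: connect i and j with probability p inside
    # a community and with probability q across communities.
    edges = []
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            prob = p if label[i] == label[j] else q
            if rng.random() < prob:
                edges.append((i, j))

    # Reveal exactly one randomly chosen node per community; 0 means unknown.
    feat = [0] * num_nodes
    for c in range(num_clusters):
        members = [i for i in range(num_nodes) if label[i] == c]
        feat[rng.choice(members)] = c + 1
    return num_nodes, edges, feat, label


num_nodes, edges, feat, label = make_cluster_graph(seed=0)
print(num_nodes, len(edges), sum(f != 0 for f in feat))
```

Seeding the generator makes a sample reproducible; the final line should report exactly 6 seeded nodes, one per community.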

Reference: https://arxiv.org/pdf/2003.00982.pdf

Statistics:

  • Train examples: 10,000

  • Valid examples: 1,000

  • Test examples: 1,000

  • Number of classes for each node: 6

Parameters
  • mode (str) – Must be one of (‘train’, ‘valid’, ‘test’). Default: ‘train’

  • raw_dir (str) – Directory to download the raw data into, or that already contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: False

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access. Default: None

num_classes

Number of classes for each node.

Type

int

Examples

>>> from dgl.data import CLUSTERDataset
>>>
>>> trainset = CLUSTERDataset(mode='train')
>>>
>>> trainset.num_classes
6
>>> len(trainset)
10000
>>> trainset[0]
Graph(num_nodes=117, num_edges=4104,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int16),
                     'feat': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)})

__getitem__(idx)

Get the idx-th sample.

Parameters

idx (int) – The sample index.

Returns

The graph, carrying its node features, node labels and edge features:

  • ndata['feat']: node features

  • ndata['label']: node labels

  • edata['feat']: edge features

Return type

dgl.DGLGraph
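Because a seeded node's feature reveals its class, it can be useful to separate seeded nodes from the rest, for example to measure accuracy only on nodes whose class is not given away by the input. The helpers below (split_nodes, seeds_match_labels, unseeded_accuracy) are a hypothetical sketch on plain Python lists; with a real sample g, the lists would presumably come from g.ndata['feat'].tolist() and g.ndata['label'].tolist(), and pred from a model's per-node argmax.

```python
def split_nodes(feat):
    """Partition node ids into seeded (class revealed) and unseeded ones."""
    seeded = [i for i, f in enumerate(feat) if f != 0]
    unseeded = [i for i, f in enumerate(feat) if f == 0]
    return seeded, unseeded


def seeds_match_labels(feat, label):
    """True if every non-zero feature f equals label + 1, as documented."""
    return all(f == 0 or f == y + 1 for f, y in zip(feat, label))


def unseeded_accuracy(pred, feat, label):
    """Accuracy over nodes whose class is not revealed by their feature."""
    _, unseeded = split_nodes(feat)
    if not unseeded:
        return 0.0
    correct = sum(1 for i in unseeded if pred[i] == label[i])
    return correct / len(unseeded)


# Toy values; in practice these would be taken from a CLUSTERDataset sample.
feat = [0, 3, 0, 1, 0]
label = [2, 2, 0, 0, 1]
pred = [2, 2, 1, 0, 1]
print(split_nodes(feat), seeds_match_labels(feat, label))
print(unseeded_accuracy(pred, feat, label))
```

Excluding seeded nodes keeps the metric from rewarding a model that merely copies the revealed class back out.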

__len__()

The number of examples in the dataset.