WikiCSDataset

class dgl.data.WikiCSDataset(raw_dir=None, force_reload=False, verbose=False, transform=None)

Bases: dgl.data.dgl_dataset.DGLBuiltinDataset

Wiki-CS is a Wikipedia-based dataset for node classification, introduced in the paper "Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks".

Each node corresponds to a Computer Science article, edges are based on hyperlinks between articles, and the 10 classes represent different branches of the field.

WikiCS dataset statistics:

  • Nodes: 11,701

  • Edges: 431,726 (the original dataset has 216,123 edges; DGL adds the reverse edges and removes duplicate edges, which yields a different count)

  • Number of classes: 10

  • Node feature size: 300

  • Number of distinct train/validation/stopping splits: 20

  • Number of test splits: 1
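The edge count above can be illustrated with a small plain-Python sketch. It assumes the raw edges are a list of (src, dst) pairs and mimics the effect described: add the reverse of every edge, then drop duplicates. The `symmetrize` helper is hypothetical and is not DGL's implementation.

```python
def symmetrize(edges):
    """Add the reverse of every (src, dst) edge and drop duplicate edges.

    Hypothetical helper; returns a sorted list so the result is deterministic.
    """
    both_directions = set()
    for src, dst in edges:
        both_directions.add((src, dst))
        both_directions.add((dst, src))
    return sorted(both_directions)


# Toy edge list: (0, 1) and (1, 0) are already reverses of each other,
# so doubling the 3 edges would give 6, but only 4 are unique.
toy_edges = [(0, 1), (1, 0), (1, 2)]
sym = symmetrize(toy_edges)
print(len(toy_edges), len(sym))  # 3 4

# Wiki-CS statistics: doubling 216,123 edges gives 432,246,
# which is 520 more than the 431,726 edges DGL reports.
print(2 * 216_123 - 431_726)  # 520
```

The gap of 520 is consistent with the statement above that duplicates are removed: simply doubling the raw edge count overcounts whenever an edge already appears in both directions (or more than once) in the raw data.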

Parameters
  • raw_dir (str) – Directory to download the raw input data to, or that already contains it. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: False

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
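To illustrate what "transformed before every access" means for `transform`, here is a plain-Python sketch. `ToyDataset` and `add_marker` are hypothetical stand-ins, not DGL classes, and a dict stands in for the DGLGraph. The point is that the stored graph stays unchanged and the transform is applied anew on each `dataset[0]` call.

```python
class ToyDataset:
    """Hypothetical single-graph dataset mirroring the transform semantics."""

    def __init__(self, graph, transform=None):
        self._graph = graph          # the stored, untransformed graph
        self._transform = transform  # optional callable: graph -> graph

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph"
        if self._transform is None:
            return self._graph
        # The transform runs on every access; its result is not cached.
        return self._transform(self._graph)

    def __len__(self):
        return 1


calls = []


def add_marker(graph):
    """Hypothetical transform: returns a copy of the graph with an extra key."""
    calls.append(1)
    return {**graph, "transformed": True}


ds = ToyDataset({"num_nodes": 3}, transform=add_marker)
g1 = ds[0]
g2 = ds[0]
print(g1["transformed"], len(calls))  # True 2
print("transformed" in ds._graph)     # False: the stored graph is untouched
```

With DGL itself, a built-in transform such as `dgl.AddSelfLoop()` can be passed as the `transform` argument.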

num_classes

Number of node classes

Type

int

Examples

>>> from dgl.data import WikiCSDataset
>>> dataset = WikiCSDataset()
>>> dataset.num_classes
10
>>> g = dataset[0]
>>> # get node feature
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> stopping_mask = g.ndata['stopping_mask']
>>> test_mask = g.ndata['test_mask']
>>> # The train, val and stopping masks have shape (num_nodes, num_splits),
>>> # where num_splits is the number of distinct train/validation/stopping splits.
>>> # Since there is only one test split, the test mask has shape (num_nodes,).
>>> print(train_mask.shape, val_mask.shape, stopping_mask.shape)
(11701, 20) (11701, 20) (11701, 20)
>>> print(test_mask.shape)
(11701,)
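
Because the train, validation, and stopping masks have one column per split, a single split is selected by indexing one column. The sketch below uses small NumPy boolean arrays as stand-ins for the mask tensors (DGL returns tensors of the active backend, PyTorch by default, where the same `[:, i]` indexing works). The toy data, the hypothetical `masked_accuracy` helper, and the hard-coded predictions are placeholders for a trained model; averaging test accuracy over all splits is one common way to report results, assumed here rather than prescribed by DGL.

```python
import numpy as np


def masked_accuracy(preds, labels, mask):
    """Fraction of nodes selected by the boolean mask whose prediction is correct."""
    return float((preds[mask] == labels[mask]).mean())


# Toy stand-ins for g.ndata[...] with 6 nodes and 2 splits.
labels = np.array([0, 1, 2, 0, 1, 2])
train_mask = np.array(
    [[1, 0], [1, 0], [0, 1], [0, 1], [0, 0], [0, 0]], dtype=bool
)  # shape (num_nodes, num_splits)
test_mask = np.array([0, 0, 0, 0, 1, 1], dtype=bool)  # shape (num_nodes,)

num_splits = train_mask.shape[1]
test_accs = []
for split in range(num_splits):
    # Column `split` is the training mask of one split, shape (num_nodes,).
    train_ids = np.nonzero(train_mask[:, split])[0]
    print(f"split {split}: train nodes {train_ids.tolist()}")
    # Hypothetical predictions standing in for a model trained on train_ids.
    if split == 0:
        preds = np.array([0, 1, 2, 0, 1, 0])
    else:
        preds = np.array([0, 1, 2, 0, 1, 2])
    test_accs.append(masked_accuracy(preds, labels, test_mask))

print(test_accs, float(np.mean(test_accs)))  # [0.5, 1.0] 0.75
```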

__getitem__(idx)

Get graph object

Parameters

idx (int) – Item index. WikiCSDataset contains only one graph object.

Returns

The graph contains:

  • ndata['feat']: node features

  • ndata['label']: node labels

  • ndata['train_mask']: mask for selecting the training nodes

  • ndata['val_mask']: mask for selecting the nodes used for hyperparameter tuning

  • ndata['stopping_mask']: mask for selecting the nodes used for the early-stopping criterion

  • ndata['test_mask']: mask for selecting the test nodes

Return type

dgl.DGLGraph
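The `stopping_mask` nodes are meant for an early-stopping criterion. One common choice, assumed here rather than taken from DGL, is patience-based stopping: record the loss on the stopping nodes of the current split after each epoch, and stop once it has not improved for `patience` consecutive epochs. The helper `early_stopping_epoch`, the loss values, and `patience=3` are all hypothetical.

```python
def early_stopping_epoch(stopping_losses, patience=3):
    """Return (stop_epoch, best_epoch) for patience-based early stopping.

    Hypothetical helper. stopping_losses[i] is the loss on the stopping
    nodes after epoch i. Training stops at the first epoch where the loss
    has failed to improve on the best value for `patience` consecutive
    epochs; if that never happens, stop_epoch is the last epoch.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, loss in enumerate(stopping_losses):
        if loss < best_loss:
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(stopping_losses) - 1, best_epoch


# Made-up per-epoch losses on the stopping nodes of one split.
losses = [1.20, 0.95, 0.80, 0.82, 0.81, 0.85, 0.70]
print(early_stopping_epoch(losses, patience=3))  # (5, 2)
```

In a real training loop, the model parameters saved at `best_epoch` would typically be restored before evaluating on the nodes selected by `test_mask`.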

__len__()

The number of graphs in the dataset.