WikiCSDataset
- class dgl.data.WikiCSDataset(raw_dir=None, force_reload=False, verbose=False, transform=None)
Bases: DGLBuiltinDataset
Wiki-CS is a Wikipedia-based dataset for node classification, introduced in "Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks".
The dataset consists of nodes corresponding to Computer Science articles, with edges based on hyperlinks and 10 classes representing different branches of the field.
WikiCS dataset statistics:
Nodes: 11,701
Edges: 431,726 (the original dataset has 216,123 edges; DGL adds the reverse edges and removes duplicate edges, which changes the count)
Number of classes: 10
Node feature size: 300
Number of different train/validation/stopping splits: 20
Number of test splits: 1
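The edge-count note above can be illustrated with a minimal sketch on a toy edge list. This is not DGL's implementation; the helper name `symmetrize` and the toy edges are hypothetical, and the sketch only shows how adding reverse edges and then dropping duplicates can yield fewer than twice the original number of edges (here 6 rather than 8, because one pair was already reciprocal).

```python
def symmetrize(edges):
    """Add the reverse of every (src, dst) edge and drop duplicates.

    Hypothetical helper for illustration only, not part of DGL.
    """
    unique = set()
    for src, dst in edges:
        unique.add((src, dst))
        unique.add((dst, src))
    return sorted(unique)


# Toy directed edge list; (0, 1) and (1, 0) are already reciprocal.
toy_edges = [(0, 1), (1, 0), (1, 2), (2, 3)]
sym_edges = symmetrize(toy_edges)

print(len(toy_edges), len(sym_edges))  # prints "4 6"
```

The same effect is consistent with the statistics above: 431,726 is slightly below 2 × 216,123 = 432,246.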
- Parameters:
raw_dir (str) – Directory to which the raw input data is downloaded, or which already contains it. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: False
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
Examples
>>> from dgl.data import WikiCSDataset
>>> dataset = WikiCSDataset()
>>> dataset.num_classes
10
>>> g = dataset[0]
>>> # get node features
>>> feat = g.ndata['feat']
>>> # get node labels
>>> labels = g.ndata['label']
>>> # get data splits
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> stopping_mask = g.ndata['stopping_mask']
>>> test_mask = g.ndata['test_mask']
>>> # The train, val and stopping masks have shape (num_nodes, num_splits),
>>> # where num_splits is the number of different train/validation/stopping splits.
>>> # Since there is only one test split, the test mask has shape (num_nodes,).
>>> print(train_mask.shape, val_mask.shape, stopping_mask.shape)
(11701, 20) (11701, 20) (11701, 20)
>>> print(test_mask.shape)
(11701,)
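Because the train, val and stopping masks carry one column per split, a single split is typically selected by indexing that column. The sketch below uses small NumPy boolean arrays as stand-ins for the mask tensors so that it runs without downloading the dataset; the toy sizes, the random masks, and the variable `split` are assumptions for illustration, and the real masks come from `g.ndata`.

```python
import numpy as np

# Toy sizes; the real dataset has 11,701 nodes and 20 splits.
num_nodes, num_splits = 8, 3
rng = np.random.default_rng(0)

# Stand-ins for g.ndata['train_mask'] (num_nodes, num_splits)
# and g.ndata['test_mask'] (num_nodes,).
train_mask = rng.random((num_nodes, num_splits)) < 0.5
test_mask = rng.random(num_nodes) < 0.25

# Pick one of the splits by selecting its column.
split = 1
train_mask_one = train_mask[:, split]      # shape (num_nodes,)
train_idx = np.nonzero(train_mask_one)[0]  # node IDs in this training split

# The test mask is already one-dimensional, so no column selection is needed.
test_idx = np.nonzero(test_mask)[0]
```

Results on this dataset are often reported as an average over all split columns; whether to do so is left to the caller.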
- __getitem__(idx)
Get the graph object.
- Parameters:
idx (int) – Item index. WikiCSDataset has only one graph object.
- Returns:
The graph contains:
ndata['feat']: node features
ndata['label']: node labels
ndata['train_mask']: train mask, for retrieving the nodes used for training
ndata['val_mask']: val mask, for retrieving the nodes used for hyperparameter tuning
ndata['stopping_mask']: stopping mask, for retrieving the nodes used for the early stopping criterion
ndata['test_mask']: test mask, for retrieving the nodes used for testing
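To make the role of the stopping mask concrete, here is a minimal sketch of a patience-based early stopping rule. It assumes the caller has already computed one loss per epoch on the nodes selected by `stopping_mask`; the function name `pick_best_epoch`, the `patience` parameter, and the loss values are hypothetical and not part of DGL.

```python
def pick_best_epoch(stopping_losses, patience=2):
    """Scan per-epoch losses measured on the stopping nodes.

    Stops once the loss has not improved for `patience` consecutive
    epochs and returns (best_epoch, best_loss). Hypothetical helper.
    """
    best_epoch, best_loss = -1, float("inf")
    for epoch, loss in enumerate(stopping_losses):
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
        elif epoch - best_epoch >= patience:
            break  # no improvement for `patience` epochs
    return best_epoch, best_loss


# Made-up losses on the stopping nodes, one per epoch.
losses = [0.9, 0.7, 0.6, 0.65, 0.66, 0.5]
best_epoch, best_loss = pick_best_epoch(losses, patience=2)
print(best_epoch, best_loss)  # prints "2 0.6"
```

Training halts at epoch 4, so the lower loss at epoch 5 is never observed. Keeping the stopping nodes separate from the validation nodes means the early stopping decision does not reuse the data used for hyperparameter tuning.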
- Return type: