SquirrelDataset

class dgl.data.SquirrelDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)

Bases: dgl.data.geom_gcn.GeomGCNDataset

Wikipedia page-page network about squirrels, introduced in Multi-scale Attributed Node Embedding and later modified by Geom-GCN: Geometric Graph Convolutional Networks.

Nodes represent articles from the English Wikipedia, and edges reflect mutual links between them. Node features indicate the presence of particular nouns in the articles. Nodes are classified into 5 classes based on their average monthly traffic.

Statistics:

  • Nodes: 5201

  • Edges: 217073

  • Number of Classes: 5

  • 10 train/val/test splits (node counts per split):

    • Train: 2496

    • Val: 1664

    • Test: 1041
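As a quick arithmetic check on the statistics above, the per-split counts add up to the node count (2496 + 1664 + 1041 = 5201). Assuming the three sets are disjoint, each split therefore assigns every node to exactly one of train, validation, or test, in proportions of roughly 48%/32%/20%. The sketch below recomputes those proportions; NUM_NODES, SPLIT_SIZES and split_fractions are illustrative names, not part of DGL.

```python
# Figures copied from the statistics above; these names are illustrative, not DGL API.
NUM_NODES = 5201
SPLIT_SIZES = {"train": 2496, "val": 1664, "test": 1041}


def split_fractions(sizes: dict, num_nodes: int) -> dict:
    """Return each split's share of the nodes.

    Raises ValueError if the split sizes do not add up to ``num_nodes``,
    i.e. if the (assumed disjoint) splits would not cover every node.
    """
    total = sum(sizes.values())
    if total != num_nodes:
        raise ValueError(f"splits cover {total} nodes, expected {num_nodes}")
    return {name: count / num_nodes for name, count in sizes.items()}


for name, frac in split_fractions(SPLIT_SIZES, NUM_NODES).items():
    print(f"{name}: {frac:.1%}")  # train: 48.0%, val: 32.0%, test: 20.0%
```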

Parameters
  • raw_dir (str, optional) – Directory in which the raw files are downloaded and the processed data is stored. Default: ~/.dgl/

  • force_reload (bool, optional) – Whether to re-download the data source. Default: False

  • verbose (bool, optional) – Whether to print progress information. Default: True

  • transform (callable, optional) – A transform that takes a DGLGraph object and returns a transformed version. The transform is applied each time the graph is accessed. Default: None
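The documented behaviour of transform (the stored graph is left as-is and the transform runs again on every access) can be illustrated with the stand-in class below. LazyTransformDataset is hypothetical, not a DGL class, and a plain dict takes the place of a DGLGraph; it only sketches the access pattern described in the parameter list.

```python
from typing import Callable, Generic, Optional, TypeVar

G = TypeVar("G")


class LazyTransformDataset(Generic[G]):
    """Hypothetical single-graph dataset (not a DGL class) that mimics the
    documented ``transform`` behaviour: the stored graph is never replaced,
    and ``transform`` is applied afresh on every ``__getitem__`` call."""

    def __init__(self, graph: G, transform: Optional[Callable[[G], G]] = None):
        self._graph = graph
        self._transform = transform

    def __len__(self) -> int:
        return 1  # a single graph, as in SquirrelDataset

    def __getitem__(self, idx: int) -> G:
        if idx != 0:
            raise IndexError("this dataset has only one graph")
        if self._transform is None:
            return self._graph
        return self._transform(self._graph)  # recomputed on each access


# Usage: a plain dict stands in for a DGLGraph; count how often the transform runs.
calls = []


def add_tag(graph: dict) -> dict:
    calls.append(1)
    return {**graph, "tagged": True}  # new dict; the input is left untouched


base = {"name": "squirrel"}
ds = LazyTransformDataset(base, transform=add_tag)
first, second = ds[0], ds[0]
print(len(calls))  # 2: one transform call per access
print(base)        # {'name': 'squirrel'}: stored graph unchanged
```

One practical consequence is that an expensive transform is recomputed each time dataset[0] is evaluated, so it can pay to fetch g = dataset[0] once and reuse it. With the real dataset, a built-in transform such as dgl.transforms.AddReverse() could be passed, e.g. SquirrelDataset(transform=dgl.transforms.AddReverse()).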

num_classes

Number of node classes (5 for this dataset).

Type

int
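Assuming the labels in g.ndata["label"] are integer class IDs in the range [0, num_classes), num_classes fixes the length of a per-class histogram, which is a quick way to inspect class balance. The helper below is an illustrative NumPy sketch; class_counts is not a DGL function.

```python
import numpy as np


def class_counts(labels, num_classes: int) -> np.ndarray:
    """Count how many nodes carry each label ID in [0, num_classes)."""
    labels = np.asarray(labels, dtype=np.int64)
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError("labels must lie in [0, num_classes)")
    # minlength keeps a zero entry for classes that never occur
    return np.bincount(labels, minlength=num_classes)


print(class_counts([0, 2, 2, 4, 1, 2], num_classes=5))  # [1 1 3 0 1]
```

With the real dataset this might be called as class_counts(g.ndata["label"].numpy(), dataset.num_classes).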

Notes

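The graph does not include edges in both directions: for a stored edge, the reverse edge is not necessarily present. Add reverse edges if your model expects a symmetric graph.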
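Because reverse edges are not guaranteed, messages on the raw graph only flow along the stored edge directions. One way to symmetrise the edge list is to append every edge reversed and drop duplicates; for a DGLGraph, DGL provides dgl.to_bidirected and dgl.add_reverse_edges for this purpose. The NumPy function below is a hypothetical sketch of the same idea on plain (src, dst) arrays, not DGL's implementation.

```python
import numpy as np


def to_bidirected_edges(src, dst):
    """Return (src, dst) arrays containing every edge in both directions.

    Each input edge (u, v) also contributes (v, u); duplicate pairs, for
    example a link already stored in both directions, are kept once.
    The output is sorted by (src, dst) because np.unique sorts its rows.
    """
    src = np.asarray(src, dtype=np.int64)
    dst = np.asarray(dst, dtype=np.int64)
    # Append the reversed edges to the original ones.
    all_src = np.concatenate([src, dst])
    all_dst = np.concatenate([dst, src])
    # Deduplicate (src, dst) pairs row-wise.
    pairs = np.unique(np.stack([all_src, all_dst], axis=1), axis=0)
    return pairs[:, 0], pairs[:, 1]


u, v = to_bidirected_edges([0, 1, 2], [1, 2, 2])
print(u.tolist(), v.tolist())  # [0, 1, 1, 2, 2] [1, 0, 2, 1, 2]
```

For the real graph, the input arrays could come from src, dst = g.edges(), converted with .numpy().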

Examples

>>> from dgl.data import SquirrelDataset
>>> dataset = SquirrelDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get data split
>>> train_mask = g.ndata["train_mask"]
>>> val_mask = g.ndata["val_mask"]
>>> test_mask = g.ndata["test_mask"]
>>> # get labels
>>> label = g.ndata["label"]
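Since the dataset ships 10 train/val/test splits, the masks presumably hold one column per split. The sketch below assumes each mask has shape (num_nodes, 10) and passes a 1-D mask through unchanged. select_split, node_ids and summarize are hypothetical helpers, and averaging over all 10 splits reflects common practice rather than anything this page prescribes.

```python
import statistics

import numpy as np


def select_split(mask, split_idx: int):
    """Return the boolean mask for one split.

    Assumes a 2-D mask of shape (num_nodes, num_splits); a 1-D mask is
    returned unchanged. Works for NumPy arrays and PyTorch tensors alike,
    since both expose ``.ndim`` and ``[:, i]`` indexing.
    """
    return mask[:, split_idx] if mask.ndim == 2 else mask


def node_ids(mask) -> np.ndarray:
    """Indices of the nodes selected by a boolean mask."""
    return np.nonzero(np.asarray(mask))[0]


def summarize(accuracies):
    """Mean and sample standard deviation of per-split accuracies."""
    return statistics.mean(accuracies), statistics.stdev(accuracies)


# Toy stand-in for g.ndata["train_mask"]: 4 nodes, 2 splits.
toy_train_mask = np.array([[True, False], [False, True], [True, True], [False, False]])
print(node_ids(select_split(toy_train_mask, 0)))  # [0 2]

# Made-up accuracies, for illustration only.
mean_acc, std_acc = summarize([0.30, 0.34, 0.32])
```

With the real dataset, this might read train_idx = node_ids(select_split(g.ndata["train_mask"], split_idx)), repeating training and evaluation for split_idx in range(10).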
__getitem__(idx)

Get the graph at index idx. The dataset contains a single graph, so the only valid index is 0.

__len__()

The number of graphs in the dataset, which is always 1.