AsNodePredDataset

class dgl.data.AsNodePredDataset(dataset, split_ratio=None, target_ntype=None, **kwargs)

Bases: dgl.data.dgl_dataset.DGLDataset

Repurpose a dataset for a standard semi-supervised transductive node prediction task.

The class converts a given dataset into a new dataset object with the following properties:

  • Contains only one graph, accessible from dataset[0].

  • The graph stores:

    • Node labels in g.ndata['label'].

    • Train/val/test masks in g.ndata['train_mask'], g.ndata['val_mask'], and g.ndata['test_mask'] respectively.

  • In addition, the dataset contains the following attributes:

    • num_classes, the number of classes to predict.

    • train_idx, val_idx, test_idx, the training/validation/test node IDs.
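Conceptually, each index attribute holds the IDs of the nodes whose entry in the matching mask is true. The plain-Python sketch below illustrates that correspondence with lists standing in for tensors. The helpers `mask_to_idx` and `idx_to_mask` are hypothetical illustrations and are not part of the DGL API.

```python
from typing import List


def mask_to_idx(mask: List[bool]) -> List[int]:
    """Hypothetical helper: return the node IDs whose mask entry is True."""
    return [i for i, m in enumerate(mask) if m]


def idx_to_mask(idx: List[int], num_nodes: int) -> List[bool]:
    """Hypothetical helper: build a boolean mask of length num_nodes from node IDs."""
    mask = [False] * num_nodes
    for i in idx:
        mask[i] = True
    return mask


# Toy 6-node graph in which nodes 0, 2 and 5 are training nodes.
train_mask = [True, False, True, False, False, True]
train_idx = mask_to_idx(train_mask)
print(train_idx)  # [0, 2, 5]

# The two representations carry the same information.
assert idx_to_mask(train_idx, len(train_mask)) == train_mask
```

With PyTorch tensors the same mask-to-index conversion is commonly written as `mask.nonzero(as_tuple=True)[0]`.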

If the input dataset contains heterogeneous graphs, users need to specify the target_ntype argument to indicate which node type to make predictions for. In this case:

  • Node labels are stored in g.nodes[target_ntype].data['label'].

  • Train/val/test masks are stored in g.nodes[target_ntype].data['train_mask'], g.nodes[target_ntype].data['val_mask'], and g.nodes[target_ntype].data['test_mask'] respectively.

The class keeps only the first graph in the provided dataset and generates train/val/test masks according to the given split ratio. The generated masks are cached to disk for fast reloading. If the provided split ratio differs from the cached one, the dataset is re-processed and new masks are generated.
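The mask generation and cache check described above can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not DGL's actual implementation: `random_split_masks` and `needs_reprocess` are hypothetical names, the seeded shuffle is an assumption made for reproducibility, and here the test split absorbs any rounding remainder, which may differ from how DGL rounds.

```python
import random
from typing import List, Optional, Sequence, Tuple


def random_split_masks(
    num_nodes: int,
    split_ratio: Sequence[float],
    seed: int = 0,
) -> Tuple[List[bool], List[bool], List[bool]]:
    """Hypothetical helper: shuffle node IDs and cut them into train/val/test masks."""
    ids = list(range(num_nodes))
    random.Random(seed).shuffle(ids)  # seeded so the sketch is reproducible
    n_train = int(num_nodes * split_ratio[0])
    n_val = int(num_nodes * split_ratio[1])
    # Nodes left over after train and val go to test, so each node lands in exactly one split.
    parts = (ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:])
    masks = []
    for part in parts:
        mask = [False] * num_nodes
        for i in part:
            mask[i] = True
        masks.append(mask)
    return masks[0], masks[1], masks[2]


def needs_reprocess(
    cached_ratio: Optional[Sequence[float]],
    requested_ratio: Sequence[float],
) -> bool:
    """Hypothetical helper: regenerate masks only if no cache exists or the ratio changed."""
    return cached_ratio is None or list(cached_ratio) != list(requested_ratio)


train_mask, val_mask, test_mask = random_split_masks(10, (0.8, 0.1, 0.1))
print(sum(train_mask), sum(val_mask), sum(test_mask))  # 8 1 1
```

Storing the ratio next to the cached masks is what makes a check like `needs_reprocess` possible: reloading with the same ratio reuses the masks, while a different ratio triggers a new split.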

Parameters
  • dataset (DGLDataset) – The dataset to be converted.

  • split_ratio ((float, float, float), optional) – Split ratios for training, validation and test sets. Must sum to one.

  • target_ntype (str, optional) – The node type to add split masks for. Must be specified when the input graph is heterogeneous.
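The constraint on split_ratio can be checked before building a dataset. The sketch below is a hypothetical validator (`check_split_ratio` is not a DGL function); the non-negativity check and the floating-point tolerance are assumptions added for illustration, since the parameter description only states that the three ratios must sum to one.

```python
import math
from typing import Optional, Sequence


def check_split_ratio(split_ratio: Optional[Sequence[float]]) -> None:
    """Hypothetical validator: accept None or three non-negative ratios that sum to one."""
    if split_ratio is None:
        return  # nothing to validate when no ratio is given
    if len(split_ratio) != 3:
        raise ValueError("split_ratio needs exactly three values: (train, val, test)")
    if any(r < 0 for r in split_ratio):
        raise ValueError("split ratios must be non-negative")
    # Compare with a tolerance because floating-point addition is inexact (0.1 + 0.2 != 0.3).
    total = sum(split_ratio)
    if not math.isclose(total, 1.0, abs_tol=1e-9):
        raise ValueError(f"split ratios must sum to one, got {total}")


check_split_ratio((0.8, 0.1, 0.1))  # passes silently
```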

num_classes

Number of classes to predict.

Type

int

train_idx

A 1-D integer tensor of training node IDs.

Type

Tensor

val_idx

A 1-D integer tensor of validation node IDs.

Type

Tensor

test_idx

A 1-D integer tensor of test node IDs.

Type

Tensor
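A common use of these index attributes is to restrict evaluation to one split, for example scoring predictions only on the nodes in test_idx. The sketch below uses plain lists in place of tensors; `accuracy_on` is a hypothetical helper and the toy labels and predictions are made up for illustration.

```python
from typing import Sequence


def accuracy_on(idx: Sequence[int], preds: Sequence[int], labels: Sequence[int]) -> float:
    """Hypothetical helper: fraction of the nodes listed in idx whose predicted class equals the label."""
    if len(idx) == 0:
        return 0.0
    correct = sum(1 for i in idx if preds[i] == labels[i])
    return correct / len(idx)


# Toy data: 5 nodes and 3 classes (so num_classes would be 3).
labels = [0, 1, 2, 1, 0]
preds = [0, 1, 1, 1, 2]
test_idx = [2, 3, 4]
print(accuracy_on(test_idx, preds, labels))  # only node 3 is predicted correctly
```

With PyTorch tensors the same computation is usually written as `(logits[test_idx].argmax(1) == labels[test_idx]).float().mean()`.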

Examples

>>> import dgl
>>> ds = dgl.data.AmazonCoBuyComputerDataset()
>>> print(ds)
Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds)
Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
>>> print('train_mask' in new_ds[0].ndata)
True

__getitem__(idx)

Gets the data object at index.

__len__()

The number of examples in the dataset.