AsNodePredDataset

class dgl.data.AsNodePredDataset(dataset, split_ratio=None, target_ntype=None, **kwargs)

Bases: dgl.data.dgl_dataset.DGLDataset
Repurpose a dataset for a standard semi-supervised transductive node prediction task.
The class converts a given dataset into a new dataset object that:
- Contains only one graph, accessible from dataset[0].
- The graph stores:
  - Node labels in g.ndata['label'].
  - Train/val/test masks in g.ndata['train_mask'], g.ndata['val_mask'], and g.ndata['test_mask'] respectively.
In addition, the dataset contains the following attributes:
- num_classes: the number of classes to predict.
- train_idx, val_idx, test_idx: the train/val/test node indexes.
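The train_idx/val_idx/test_idx attributes and the train_mask/val_mask/test_mask node data describe the same split in two forms: a boolean mask over all nodes and a list of the selected node IDs. Assuming that correspondence, the minimal pure-Python sketch below converts between the two. The real attributes are tensors, and the helpers mask_to_index and index_to_mask are hypothetical, not part of DGL.

```python
from typing import List, Sequence


def mask_to_index(mask: Sequence[bool]) -> List[int]:
    """Return the positions of True entries, i.e. the node IDs selected by a boolean mask."""
    return [i for i, selected in enumerate(mask) if selected]


def index_to_mask(index: Sequence[int], num_nodes: int) -> List[bool]:
    """Build a boolean mask of length num_nodes that is True exactly at the given node IDs."""
    mask = [False] * num_nodes
    for node_id in index:
        mask[node_id] = True
    return mask


# Hypothetical 6-node graph: nodes 0, 2 and 5 belong to the training split.
train_mask = [True, False, True, False, False, True]
train_idx = mask_to_index(train_mask)
print(train_idx)  # [0, 2, 5]

# Converting back recovers the original mask.
assert index_to_mask(train_idx, len(train_mask)) == train_mask
```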
If the input dataset contains heterogeneous graphs, users need to specify the target_ntype argument to indicate which node type to make predictions for. In this case:
- Node labels are stored in g.nodes[target_ntype].data['label'].
- Training masks are stored in g.nodes[target_ntype].data['train_mask']; validation and test masks are stored likewise under 'val_mask' and 'test_mask'.
The class keeps only the first graph in the provided dataset and generates train/val/test masks according to the given split ratio. The generated masks are cached to disk for fast reloading. If the provided split ratio differs from the cached one, the dataset is re-processed so that the masks match the new ratio.
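Generating masks from a split ratio can be sketched as follows. This is a minimal pure-Python illustration, not DGL's implementation: the shuffling strategy, the seed parameter, the rounding rule (the test set takes the remainder), and the helper name split_masks are all assumptions. It also checks that the three ratios sum to one, as the split_ratio parameter requires.

```python
import math
import random
from typing import List, Sequence, Tuple


def split_masks(
    num_nodes: int,
    split_ratio: Sequence[float],
    seed: int = 0,
) -> Tuple[List[bool], List[bool], List[bool]]:
    """Randomly assign every node to exactly one of train/val/test.

    Hypothetical helper: a sketch of ratio-based splitting, not DGL's code.
    """
    if len(split_ratio) != 3:
        raise ValueError("split_ratio must have three entries (train, val, test)")
    # Compare with a tolerance: sums of decimal ratios such as
    # 0.7 + 0.2 + 0.1 may not be exactly 1.0 in floating point.
    if not math.isclose(sum(split_ratio), 1.0):
        raise ValueError("split_ratio must sum to one")

    # Shuffle node IDs with a seeded generator so the split is reproducible.
    perm = list(range(num_nodes))
    random.Random(seed).shuffle(perm)

    # Sizes of the first two splits; the test split takes whatever remains.
    n_train = int(num_nodes * split_ratio[0])
    n_val = int(num_nodes * split_ratio[1])

    train_mask = [False] * num_nodes
    val_mask = [False] * num_nodes
    test_mask = [False] * num_nodes
    for pos, node_id in enumerate(perm):
        if pos < n_train:
            train_mask[node_id] = True
        elif pos < n_train + n_val:
            val_mask[node_id] = True
        else:
            test_mask[node_id] = True
    return train_mask, val_mask, test_mask


train_mask, val_mask, test_mask = split_masks(10, (0.8, 0.1, 0.1))
print(sum(train_mask), sum(val_mask), sum(test_mask))  # 8 1 1
```

Letting the last split absorb the remainder guarantees that every node lands in exactly one set, even when num_nodes times a ratio is not an integer.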
Parameters
- dataset (DGLDataset) – The dataset to be converted.
- split_ratio ((float, float, float), optional) – Split ratios for the training, validation and test sets. Must sum to one.
- target_ntype (str, optional) – The node type to add split masks for.
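The caching rule stated earlier (reuse the cached masks unless the split ratio has changed) amounts to storing the ratio alongside the masks and comparing it on load. The sketch below assumes a hypothetical JSON file named split_info.json in a cache directory, and the helpers needs_reprocess and save_split_info are likewise hypothetical; DGL's actual on-disk format and file names are not documented here.

```python
import json
import os
import tempfile
from typing import Sequence

INFO_FILE = "split_info.json"  # hypothetical file name, not DGL's


def needs_reprocess(cache_dir: str, split_ratio: Sequence[float]) -> bool:
    """Return True if no cached split exists or it was built with a different ratio."""
    info_path = os.path.join(cache_dir, INFO_FILE)
    if not os.path.exists(info_path):
        return True
    with open(info_path) as f:
        cached = json.load(f)
    return cached.get("split_ratio") != list(split_ratio)


def save_split_info(cache_dir: str, split_ratio: Sequence[float]) -> None:
    """Record the ratio used for the cached masks so later loads can compare against it."""
    with open(os.path.join(cache_dir, INFO_FILE), "w") as f:
        json.dump({"split_ratio": list(split_ratio)}, f)


with tempfile.TemporaryDirectory() as cache_dir:
    print(needs_reprocess(cache_dir, [0.8, 0.1, 0.1]))  # True: nothing cached yet
    save_split_info(cache_dir, [0.8, 0.1, 0.1])
    print(needs_reprocess(cache_dir, [0.8, 0.1, 0.1]))  # False: same ratio, reuse cache
    print(needs_reprocess(cache_dir, [0.6, 0.2, 0.2]))  # True: ratio changed
```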
train_idx
A 1-D integer tensor of training node IDs.
Type: Tensor

val_idx
A 1-D integer tensor of validation node IDs.
Type: Tensor

test_idx
A 1-D integer tensor of test node IDs.
Type: Tensor
Examples
>>> import dgl
>>> ds = dgl.data.AmazonCoBuyComputerDataset()
>>> print(ds)
Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds)
Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
>>> print('train_mask' in new_ds[0].ndata)
True