AmazonRatingsDataset

class dgl.data.AmazonRatingsDataset(raw_dir=None, force_reload=False, verbose=True, transform=None)

Bases: HeterophilousGraphDataset

Amazon-ratings dataset from the paper ‘A Critical Look at the Evaluation of GNNs under Heterophily: Are We Really Making Progress?’ (https://arxiv.org/abs/2302.11640).

This dataset is based on the Amazon product co-purchasing data. Nodes are products (books, music CDs, DVDs, VHS video tapes), and edges connect products that are frequently bought together. The task is to predict the average rating given to a product by reviewers. All possible rating values were grouped into five classes. Node features are the mean of word embeddings for words in the product description.
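To make the node-feature construction concrete, the sketch below averages word vectors for a short product description. It is only an illustration: the toy three-dimensional embedding table, the lowercase whitespace tokenization, and the helper `mean_word_embedding` are assumptions, not the preprocessing actually used to build the dataset's 300-dimensional features.

```python
from typing import Dict, List


def mean_word_embedding(text: str, embeddings: Dict[str, List[float]], dim: int) -> List[float]:
    """Average the vectors of all in-vocabulary words in `text`.

    Hypothetical helper: tokenization is lowercase + whitespace split, and
    words missing from `embeddings` are skipped. If no word is found, a
    zero vector of length `dim` is returned (an assumed convention).
    """
    vectors = [embeddings[w] for w in text.lower().split() if w in embeddings]
    if not vectors:
        return [0.0] * dim
    # Element-wise mean over all found word vectors.
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


# Toy 3-dimensional table; the real features are 300-dimensional.
toy = {
    "classic": [1.0, 0.0, 2.0],
    "novel": [3.0, 2.0, 0.0],
    "paperback": [2.0, 4.0, 1.0],
}
feat = mean_word_embedding("Classic novel in paperback", toy, dim=3)
print(feat)  # [2.0, 2.0, 1.0]  ("in" is not in the table, so it is skipped)
```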

Statistics:

  • Nodes: 24492

  • Edges: 186100

  • Classes: 5

  • Node features: 300

  • 10 train/val/test splits
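The Examples section selects the first split with `[:, 0]`, which suggests each mask stores one column per split. Assuming that layout, a common way to use all 10 splits is to train one model per split and report the mean and standard deviation of its test accuracy. The sketch below shows only the aggregation step with plain Python lists; `per_split_accuracy`, the toy labels and predictions, and the use of the sample standard deviation are illustrative assumptions.

```python
import statistics
from typing import List, Sequence


def per_split_accuracy(
    preds: Sequence[Sequence[int]],
    label: Sequence[int],
    test_mask: Sequence[Sequence[bool]],
) -> List[float]:
    """Test accuracy for each split (hypothetical helper).

    preds[k] holds predictions of the model trained on split k, and
    test_mask[i][k] is True when node i is a test node of split k,
    mirroring test_mask[:, k] on a (num_nodes, num_splits) mask.
    """
    num_splits = len(test_mask[0])
    accs = []
    for k in range(num_splits):
        test_nodes = [i for i, row in enumerate(test_mask) if row[k]]
        correct = sum(preds[k][i] == label[i] for i in test_nodes)
        accs.append(correct / len(test_nodes))
    return accs


# Toy data: 6 nodes and 2 splits (the real dataset has 24492 nodes and 10 splits).
label = [0, 1, 2, 3, 4, 0]
test_mask = [
    [True, False],
    [True, False],
    [False, True],
    [False, True],
    [True, True],
    [False, False],
]
preds = [
    [0, 1, 0, 0, 4, 0],  # model trained on split 0
    [0, 0, 2, 1, 4, 0],  # model trained on split 1
]

accs = per_split_accuracy(preds, label, test_mask)
mean_acc = statistics.mean(accs)
std_acc = statistics.stdev(accs)  # sample standard deviation (an assumption)
print(f"accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
```

Keeping one prediction list per split reflects that each split normally gets its own freshly trained model, so results are not leaked across splits.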

Parameters:
  • raw_dir (str, optional) – Directory in which to download and store the raw data. Default: ~/.dgl/

  • force_reload (bool, optional) – Whether to re-download the data source. Default: False

  • verbose (bool, optional) – Whether to print progress information. Default: True

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access. Default: None
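The `transform` parameter is applied on every access rather than once at load time. The hypothetical `TransformOnAccess` wrapper below sketches that contract in plain Python, with a graph modeled as an edge list. It is not DGL's implementation, and `add_self_loops` merely stands in for a real transform such as `AddSelfLoop` from `dgl.transforms`.

```python
from typing import Callable, Generic, List, Optional, Tuple, TypeVar

G = TypeVar("G")


class TransformOnAccess(Generic[G]):
    """Hypothetical dataset wrapper: graphs are stored untouched and
    `transform`, if given, is applied each time an item is read."""

    def __init__(self, graphs: List[G], transform: Optional[Callable[[G], G]] = None):
        self._graphs = graphs
        self._transform = transform

    def __getitem__(self, idx: int) -> G:
        g = self._graphs[idx]
        # Apply the transform on every access; nothing is cached.
        return g if self._transform is None else self._transform(g)

    def __len__(self) -> int:
        return len(self._graphs)


Edges = List[Tuple[int, int]]
calls = [0]  # counts how many times the transform runs


def add_self_loops(edges: Edges) -> Edges:
    """Toy stand-in for a self-loop transform; returns a new edge list."""
    calls[0] += 1
    nodes = sorted({n for edge in edges for n in edge})
    return edges + [(n, n) for n in nodes]


ds = TransformOnAccess([[(0, 1), (1, 2)]], transform=add_self_loops)
g1 = ds[0]
g2 = ds[0]
print(g1)        # [(0, 1), (1, 2), (0, 0), (1, 1), (2, 2)]
print(calls[0])  # 2: the transform ran once per access
```

Because the stored graph is never modified, each access starts from the original data, so a transform with random behavior would produce a fresh result every time.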

num_classes

Number of node classes

Type:

int

Examples

>>> from dgl.data import AmazonRatingsDataset
>>> dataset = AmazonRatingsDataset()
>>> g = dataset[0]
>>> num_classes = dataset.num_classes
>>> # get node features
>>> feat = g.ndata["feat"]
>>> # get the first data split
>>> train_mask = g.ndata["train_mask"][:, 0]
>>> val_mask = g.ndata["val_mask"][:, 0]
>>> test_mask = g.ndata["test_mask"][:, 0]
>>> # get labels
>>> label = g.ndata["label"]

__getitem__(idx)

Gets the data object at index.

__len__()

The number of examples in the dataset.