class, raw_dir=None, force_reload=False, verbose=False, transform=None)


Reddit dataset for community detection (node classification)

This is a graph dataset built from Reddit posts made in September 2014. The node label is the community, or “subreddit”, that a post belongs to. The authors sampled 50 large communities and built a post-to-post graph, connecting two posts if the same user commented on both. In total, the dataset contains 232,965 posts with an average degree of 492. We use the first 20 days for training and split the posts from the remaining days into validation (30%) and test (70%) sets.
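The edge rule described above (connect two posts when the same user commented on both) can be sketched in plain Python. This is a minimal illustration, not DGL's preprocessing code: the `(user, post)` input format and the `build_post_graph` name are hypothetical, and the sketch does not attempt to reproduce the exact pipeline used to build the released dataset.

```python
from collections import defaultdict
from itertools import combinations


def build_post_graph(comments):
    """Connect two posts if the same user commented on both.

    comments: iterable of (user_id, post_id) pairs (hypothetical input format).
    Returns a sorted list of undirected edges (u, v) with u < v.
    """
    # Group the posts each user commented on; a set drops repeat comments.
    posts_by_user = defaultdict(set)
    for user, post in comments:
        posts_by_user[user].add(post)

    edges = set()
    for posts in posts_by_user.values():
        # Every pair of distinct posts sharing this commenter becomes an edge.
        for u, v in combinations(sorted(posts), 2):
            edges.add((u, v))
    return sorted(edges)


comments = [("alice", 0), ("alice", 1), ("bob", 1), ("bob", 2), ("carol", 3)]
print(build_post_graph(comments))  # [(0, 1), (1, 2)]
```

Post 3 stays isolated because its only commenter commented on nothing else.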



Statistics:

  • Nodes: 232,965

  • Edges: 114,615,892

  • Node feature size: 602

  • Number of training samples: 153,431

  • Number of validation samples: 23,831

  • Number of test samples: 55,703
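The figures above are mutually consistent, which a few lines of arithmetic confirm: dividing the edge count by the node count gives the stated average degree of 492, the three splits add up to the node count, and validation takes about 30% of the non-training posts. The constant names are illustrative only.

```python
# Figures copied from the statistics list above.
NUM_NODES = 232_965
NUM_EDGES = 114_615_892
NUM_TRAIN = 153_431
NUM_VAL = 23_831
NUM_TEST = 55_703

# Average degree computed as edges per node.
avg_degree = NUM_EDGES / NUM_NODES

# Every node falls into exactly one split.
splits_cover_all = NUM_TRAIN + NUM_VAL + NUM_TEST == NUM_NODES

# Share of the non-training posts assigned to validation.
val_fraction = NUM_VAL / (NUM_VAL + NUM_TEST)

print(f"average degree: {avg_degree:.1f}")            # average degree: 492.0
print(f"splits cover all nodes: {splits_cover_all}")  # splits cover all nodes: True
print(f"validation share: {val_fraction:.1%}")        # validation share: 30.0%
```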

Parameters

  • self_loop (bool) – Whether to load the dataset with self-loop connections. Default: False

  • raw_dir (str) – Directory to download the raw data into, or that already contains it. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: False

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
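To make the `self_loop` and `transform` parameters concrete, the sketch below represents a graph as a hypothetical `(src, dst)` pair of edge lists rather than a `DGLGraph`. `add_self_loops` appends one `(v, v)` edge per node, which raises every in-degree by one, and `transform` shows that a transform is simply a callable from graph to graph. This is a toy under those assumptions; it does not reproduce how `RedditDataset` handles `self_loop=True` internally.

```python
from typing import List, Tuple

# Hypothetical graph representation: (source nodes, destination nodes).
Edges = Tuple[List[int], List[int]]


def add_self_loops(edges: Edges, num_nodes: int) -> Edges:
    """Return a new edge list with one (v, v) edge appended for every node.

    Note: existing self-loops are not deduplicated.
    """
    src, dst = edges
    loops = list(range(num_nodes))
    return src + loops, dst + loops


def in_degrees(edges: Edges, num_nodes: int) -> List[int]:
    """Count incoming edges per node."""
    deg = [0] * num_nodes
    for v in edges[1]:
        deg[v] += 1
    return deg


def transform(edges: Edges) -> Edges:
    """A transform is any graph -> graph callable; this one adds self-loops."""
    return add_self_loops(edges, num_nodes=3)


edges = ([0, 1], [1, 2])  # 0 -> 1, 1 -> 2
print(in_degrees(edges, 3))             # [0, 1, 1]
print(in_degrees(transform(edges), 3))  # [1, 2, 2]
```

Self-loops are commonly added so that message passing includes each node's own features; after the transform, no node has in-degree zero.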


Number of classes for each node




>>> from dgl.data import RedditDataset
>>> data = RedditDataset()
>>> g = data[0]
>>> num_classes = data.num_classes
>>> # get node features
>>> feat = g.ndata['feat']
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>> # get labels
>>> label = g.ndata['label']
>>> # train, validate and test a model using the masks above
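The three masks are typically used to restrict a metric to one split. Below is a minimal pure-Python sketch, assuming predictions, labels and masks are plain lists of equal length; `masked_accuracy` and the toy data are hypothetical.

```python
from typing import Sequence


def masked_accuracy(pred: Sequence[int], label: Sequence[int],
                    mask: Sequence[bool]) -> float:
    """Fraction of correct predictions among the nodes where mask is True."""
    selected = [p == y for p, y, m in zip(pred, label, mask) if m]
    if not selected:
        raise ValueError("mask selects no nodes")
    return sum(selected) / len(selected)


# Toy data: six nodes, each assigned to exactly one split.
label      = [0, 1, 2, 1, 0, 2]
pred       = [0, 1, 1, 1, 2, 2]
train_mask = [True, True, True, False, False, False]
val_mask   = [False, False, False, True, False, False]
test_mask  = [False, False, False, False, True, True]

for name, mask in [("train", train_mask), ("val", val_mask), ("test", test_mask)]:
    print(name, round(masked_accuracy(pred, label, mask), 3))
```

With tensors, the same selection is usually written with boolean indexing, for example `label[train_mask]`.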

Get graph by index


Parameters


idx (int) – Item index


Returns


The graph structure, node labels, node features and split masks:

  • ndata['label']: node label

  • ndata['feat']: node feature

  • ndata['train_mask']: mask for training node set

  • ndata['val_mask']: mask for validation node set

  • ndata['test_mask']: mask for test node set

Return type



Number of graphs in the dataset
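The indexing contract above can be mimicked with a small hypothetical class. It is not DGL's implementation: it assumes, as in the example earlier, that the dataset holds a single graph at index 0, and it applies the optional transform on every access, matching the description of the `transform` parameter.

```python
from typing import Callable, Generic, Optional, TypeVar

G = TypeVar("G")


class SingleGraphDataset(Generic[G]):
    """Hypothetical stand-in mirroring the __getitem__/__len__ contract."""

    def __init__(self, graph: G, transform: Optional[Callable[[G], G]] = None):
        self._graph = graph
        self._transform = transform

    def __getitem__(self, idx: int) -> G:
        if idx != 0:
            raise IndexError("this dataset holds a single graph; only idx 0 is valid")
        # The transform, if any, runs on every access, not once at load time.
        if self._transform is None:
            return self._graph
        return self._transform(self._graph)

    def __len__(self) -> int:
        return 1


calls = []


def count_access(g):
    """Hypothetical transform that records each call and returns g unchanged."""
    calls.append(g)
    return g


ds = SingleGraphDataset("reddit-graph", transform=count_access)
ds[0]
ds[0]
print(len(ds), len(calls))  # 1 2
```

Raising `IndexError` for other indices also lets Python's fallback iteration protocol stop after one item, so `list(ds)` yields exactly one graph.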