RedditDataset(self_loop=False, raw_dir=None, force_reload=False, verbose=False, transform=None)
Reddit dataset for community detection (node classification)
This is a graph dataset from Reddit posts made in the month of September, 2014. The node label in this case is the community, or “subreddit”, that a post belongs to. The authors sampled 50 large communities and built a post-to-post graph, connecting posts if the same user comments on both. In total this dataset contains 232,965 posts with an average degree of 492. We use the first 20 days for training and the remaining days for testing (with 30% used for validation).
Node feature size: 602
Number of training samples: 153,431
Number of validation samples: 23,831
Number of test samples: 55,703
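As a quick sanity check on the statistics above (a toy verification script, not part of the DGL API), the three split sizes partition the full set of posts:

```python
# Split sizes quoted in the statistics above; plain Python, no DGL required.
num_train, num_val, num_test = 153_431, 23_831, 55_703
total_posts = 232_965

# The train/val/test node sets together cover every post exactly once.
assert num_train + num_val + num_test == total_posts
print("splits cover all", total_posts, "posts")
```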
self_loop (bool) – Whether to load the dataset with self-loop connections. Default: False
raw_dir (str) – Raw file directory to download/store the input data. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print out progress information. Default: False
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
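The "transformed before every access" semantics can be sketched with a minimal stand-in dataset. `ToyDataset` and the dict-based "graphs" below are hypothetical illustrations, not DGL code (DGL's real datasets store and return DGLGraph objects):

```python
# Minimal sketch of the transform-before-every-access behaviour.
# ToyDataset is hypothetical; it only mimics the access pattern.
class ToyDataset:
    def __init__(self, graphs, transform=None):
        self._graphs = graphs          # stored, untransformed graphs
        self._transform = transform

    def __getitem__(self, idx):
        g = self._graphs[idx]
        # The transform runs on every access; the stored graph is untouched.
        return g if self._transform is None else self._transform(g)

# A "graph" here is just a dict; the transform tags it with a marker.
add_flag = lambda g: {**g, "self_loop": True}
data = ToyDataset([{"nodes": 3}], transform=add_flag)

print(data[0])           # {'nodes': 3, 'self_loop': True}
print(data._graphs[0])   # stored copy unchanged: {'nodes': 3}
```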
>>> data = RedditDataset()
>>> g = data[0]
>>> num_classes = data.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test
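A common pattern with the boolean masks in the example above is to select per-split labels by boolean indexing. The NumPy arrays below are toy stand-ins for `g.ndata` entries (hypothetical 10-node graph, not the real dataset):

```python
import numpy as np

# Toy stand-ins for g.ndata['label'] and the split masks (10 nodes).
label = np.arange(10) % 3                # fake "subreddit" labels
train_mask = np.zeros(10, dtype=bool)
train_mask[:6] = True                    # first 6 nodes train
val_mask = np.zeros(10, dtype=bool)
val_mask[6:8] = True                     # next 2 validate
test_mask = ~(train_mask | val_mask)     # remaining 2 test

# The masks partition the nodes; boolean indexing selects each split.
train_labels = label[train_mask]
print(train_labels.shape)                # (6,)
assert train_mask.sum() + val_mask.sum() + test_mask.sum() == 10
```

In DGL the real masks are tensors on `g.ndata`, but the indexing idiom is the same.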
Get graph by index
idx (int) – Item index
graph structure, node labels, node features and splitting masks:
ndata['label']: node label
ndata['feat']: node feature
ndata['train_mask']: mask for training node set
ndata['val_mask']: mask for validation node set
ndata['test_mask']: mask for test node set
Return type: DGLGraph