RedditDataset
class dgl.data.RedditDataset(self_loop=False, raw_dir=None, force_reload=False, verbose=False, transform=None)
Bases: dgl.data.dgl_dataset.DGLBuiltinDataset
Reddit dataset for community detection (node classification)
Deprecated since version 0.5.0:

graph is deprecated; it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]

num_labels is deprecated; it is replaced by:
>>> dataset = RedditDataset()
>>> num_classes = dataset.num_classes

train_mask is deprecated; it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> train_mask = graph.ndata['train_mask']

val_mask is deprecated; it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> val_mask = graph.ndata['val_mask']

test_mask is deprecated; it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> test_mask = graph.ndata['test_mask']

features is deprecated; it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> features = graph.ndata['feat']

labels is deprecated; it is replaced by:
>>> dataset = RedditDataset()
>>> graph = dataset[0]
>>> labels = graph.ndata['label']
This is a graph dataset from Reddit posts made in the month of September, 2014. The node label in this case is the community, or “subreddit”, that a post belongs to. The authors sampled 50 large communities and built a post-to-post graph, connecting posts if the same user comments on both. In total this dataset contains 232,965 posts with an average degree of 492. We use the first 20 days for training and the remaining days for testing (with 30% used for validation).
Reference: http://snap.stanford.edu/graphsage/
Statistics
Nodes: 232,965
Edges: 114,615,892
Node feature size: 602
Number of training samples: 153,431
Number of validation samples: 23,831
Number of test samples: 55,703
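The figures in the Statistics list can be cross-checked with a few lines of arithmetic. The sketch below uses only the numbers quoted on this page; the variable names are illustrative, and reading "30% used for validation" as 30% of the held-out (non-training) nodes is an assumption that happens to agree with the counts.

```python
# Figures copied from the Statistics list on this page.
num_nodes = 232_965
num_edges = 114_615_892
num_train = 153_431
num_val = 23_831
num_test = 55_703

# The three splits should partition the node set exactly.
assert num_train + num_val + num_test == num_nodes

# Average degree, taken here as the listed edge count divided by the node count.
avg_degree = num_edges / num_nodes

# Share of the held-out (non-training) nodes that is used for validation.
# Assumption: this is what "30% used for validation" refers to.
val_share_of_heldout = num_val / (num_val + num_test)

print(f"average degree: {avg_degree:.1f}")
print(f"validation share of held-out nodes: {val_share_of_heldout:.1%}")
```

Dividing the listed edge count by the node count reproduces the average degree of roughly 492 quoted in the description, which suggests the edge count already includes both directions of each connection.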
Parameters
self_loop (bool) – Whether to load the dataset with self-loop connections. Default: False
raw_dir (str) – Directory that the raw input data is downloaded to, or that already contains it. Default: ~/.dgl/
force_reload (bool) – Whether to reload the dataset. Default: False
verbose (bool) – Whether to print progress information. Default: False
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
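To illustrate what "transformed before every access" means for the transform parameter, here is a minimal sketch built on a hypothetical stand-in class. ToyDataset and add_self_loops are not part of DGL, and the dict-based graph is only a placeholder for a DGLGraph. The sketch assumes the transform is called each time the item is indexed, with no caching, which is how the parameter description reads.

```python
class ToyDataset:
    """Hypothetical stand-in for a single-graph dataset; not DGL code."""

    def __init__(self, graph, transform=None):
        self._graph = graph
        self._transform = transform

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError("this toy dataset holds exactly one graph")
        if self._transform is None:
            return self._graph
        # The transform runs on every access; the stored graph is left untouched.
        return self._transform(self._graph)


def add_self_loops(graph):
    """Return a copy of the toy graph with one (v, v) edge added per node."""
    loops = [(v, v) for v in range(graph["num_nodes"])]
    return {"num_nodes": graph["num_nodes"], "edges": graph["edges"] + loops}


raw = {"num_nodes": 3, "edges": [(0, 1), (1, 2)]}
plain = ToyDataset(raw)
looped = ToyDataset(raw, transform=add_self_loops)

print(len(plain[0]["edges"]), len(looped[0]["edges"]))
```

Because the callable is applied at access time in this sketch, an expensive transform would be paid for on every `dataset[0]` call, so it may be worth keeping a reference to the returned graph rather than indexing repeatedly.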
graph
Graph of the dataset

train_mask
Mask of training nodes

val_mask
Mask of validation nodes

test_mask
Mask of test nodes

features
Node features
Type: Tensor

labels
Node labels
Type: Tensor

Examples
>>> from dgl.data import RedditDataset
>>> data = RedditDataset()
>>> g = data[0]
>>> num_classes = data.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test

__getitem__(idx)
Get graph by index

Parameters
idx (int) – Item index

Returns
graph structure, node labels, node features and splitting masks:
ndata['label']: node label
ndata['feat']: node feature
ndata['train_mask']: mask for training node set
ndata['val_mask']: mask for validation node set
ndata['test_mask']: mask for test node set

Return type