TreeGridDataset
- class dgl.data.TreeGridDataset(tree_height=8, num_motifs=80, grid_size=3, perturb_ratio=0.1, seed=None, raw_dir=None, force_reload=False, verbose=True, transform=None)
Bases: DGLBuiltinDataset
TREE-GRIDS dataset from GNNExplainer: Generating Explanations for Graph Neural Networks
This is a synthetic dataset for node classification. It is generated by performing the following steps in order.
1. Construct a balanced binary tree as the base graph.
2. Construct a set of n-by-n grid motifs, where n is grid_size.
3. Attach the motifs to randomly selected nodes of the base graph.
4. Perturb the graph by adding random edges.
5. Assign every node the same constant feature, 1.
Nodes in the tree belong to class 0 and nodes in the grids belong to class 1.
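The steps above can be sketched in plain Python. This is only an illustration of the procedure, not DGL's implementation: the function name `build_tree_grid`, the way each motif is attached (one edge from a random tree node to the motif's first node, sampled with replacement), and the rounding of the number of perturbation edges are all assumptions made for the example.

```python
import random


def build_tree_grid(tree_height=3, num_motifs=4, grid_size=3,
                    perturb_ratio=0.1, seed=0):
    """Illustrative re-creation of the generation steps (hypothetical helper)."""
    rng = random.Random(seed)
    edges = set()  # undirected edges stored as (smaller_id, larger_id)

    # Step 1: balanced binary tree; node i has children 2i+1 and 2i+2.
    num_tree_nodes = 2 ** (tree_height + 1) - 1
    for child in range(1, num_tree_nodes):
        edges.add(((child - 1) // 2, child))
    labels = [0] * num_tree_nodes  # tree nodes are class 0

    # Steps 2-3: build each grid motif and attach it to a random tree node.
    next_id = num_tree_nodes
    for _ in range(num_motifs):
        for r in range(grid_size):
            for c in range(grid_size):
                u = next_id + r * grid_size + c
                if c + 1 < grid_size:
                    edges.add((u, u + 1))          # edge to the right neighbor
                if r + 1 < grid_size:
                    edges.add((u, u + grid_size))  # edge to the neighbor below
        anchor = rng.randrange(num_tree_nodes)     # assumption: with replacement
        edges.add((anchor, next_id))
        labels.extend([1] * grid_size ** 2)        # grid nodes are class 1
        next_id += grid_size ** 2

    # Step 4: add random edges, perturb_ratio times the current edge count.
    num_nodes = next_id
    target = len(edges) + int(perturb_ratio * len(edges))  # assumption: floor
    while len(edges) < target:
        u, v = rng.sample(range(num_nodes), 2)
        edges.add((min(u, v), max(u, v)))

    # Step 5: constant feature of 1 for every node.
    feats = [1.0] * num_nodes
    return sorted(edges), feats, labels


edges, feats, labels = build_tree_grid()
print(len(labels), "nodes,", len(edges), "undirected edges")
```

The sketch uses small default sizes so the result is easy to inspect. The real dataset returns a DGLGraph with the features and labels stored in `ndata`.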
- Parameters:
tree_height (int, optional) – Height of the balanced binary tree. Default: 8
num_motifs (int, optional) – Number of grid motifs to use. Default: 80
grid_size (int, optional) – The number of nodes in a grid motif will be grid_size ^ 2. Default: 3
perturb_ratio (float, optional) – Number of random edges to add in perturbation divided by the number of original edges in the graph. Default: 0.1
seed (integer, random_state, or None, optional) – Indicator of random number generation state. Default: None
raw_dir (str, optional) – Raw file directory to store the processed data. Default: ~/.dgl/
force_reload (bool, optional) – Whether to always generate the data from scratch rather than load a cached version. Default: False
verbose (bool, optional) – Whether to print progress information. Default: True
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access. Default: None
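As a rough guide to how the size parameters interact, the sketch below estimates node and edge counts from the constructor arguments. The helper `tree_grid_size` is hypothetical and rests on several assumptions: a tree of height h has 2^(h+1) - 1 nodes, each motif is joined to the tree by a single edge, edges are counted as undirected, and the number of perturbation edges is rounded down. The graph DGL actually produces may count edges differently (for example, as directed pairs).

```python
def tree_grid_size(tree_height=8, num_motifs=80, grid_size=3, perturb_ratio=0.1):
    """Estimate graph size from the parameters (hypothetical helper)."""
    tree_nodes = 2 ** (tree_height + 1) - 1        # balanced binary tree
    tree_edges = tree_nodes - 1
    motif_nodes = grid_size ** 2                   # n-by-n grid
    motif_edges = 2 * grid_size * (grid_size - 1)  # horizontal + vertical edges
    num_nodes = tree_nodes + num_motifs * motif_nodes
    # Assumption: one extra edge attaches each motif to the tree.
    base_edges = tree_edges + num_motifs * (motif_edges + 1)
    # Assumption: the number of added edges is rounded down.
    extra_edges = int(perturb_ratio * base_edges)
    return num_nodes, base_edges, extra_edges


print(tree_grid_size())  # (1231, 1550, 155) under the stated assumptions
```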
Examples
>>> from dgl.data import TreeGridDataset
>>> dataset = TreeGridDataset()
>>> dataset.num_classes
2
>>> g = dataset[0]
>>> label = g.ndata['label']
>>> feat = g.ndata['feat']