TUDataset
class dgl.data.TUDataset(name, raw_dir=None, force_reload=False, verbose=False, transform=None)
Bases: dgl.data.dgl_dataset.DGLBuiltinDataset
TUDataset contains many graph kernel benchmark datasets for graph classification.
- Parameters
name (str) – Dataset name, such as ENZYMES, DD, COLLAB or MUTAG. It can be any dataset name listed at https://chrsmrrs.github.io/datasets/docs/datasets/.
transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
Notes
IMPORTANT: Some of the datasets contain duplicate edges; for example, the edges in IMDB-BINARY are all duplicated. DGL faithfully keeps the duplicates as in the original data, whereas other frameworks such as PyTorch Geometric remove the duplicates by default. You can remove the duplicate edges with dgl.to_simple().
Graphs may have node labels, node attributes, edge labels, and edge attributes, depending on the dataset.
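The duplicate-edge note above can be illustrated with a small, library-free sketch. It assumes a graph is given as two parallel lists of source and destination node IDs (a hypothetical representation, not DGL's internal format) and collapses repeated (src, dst) pairs while keeping a multiplicity count, which is conceptually what removing parallel edges means. For real DGLGraph objects, use dgl.to_simple() rather than this helper.

```python
from collections import Counter


def dedup_edges(src, dst):
    """Collapse parallel edges given as parallel src/dst lists.

    Hypothetical helper for illustration only; real graphs should go
    through dgl.to_simple(). Returns the unique edges in first-seen
    order plus how many times each one appeared.
    """
    if len(src) != len(dst):
        raise ValueError("src and dst must have the same length")
    # Counter preserves insertion order (Python 3.7+), so edges come out
    # in the order they were first seen. Edges are directed: (0, 1) and
    # (1, 0) are counted separately.
    counts = Counter(zip(src, dst))
    new_src = [u for u, _ in counts]
    new_dst = [v for _, v in counts]
    multiplicity = list(counts.values())
    return new_src, new_dst, multiplicity


# Every edge listed twice, as in a fully duplicated dataset.
src = [0, 0, 1, 1, 2, 2]
dst = [1, 1, 2, 2, 0, 0]
print(dedup_edges(src, dst))  # ([0, 1, 2], [1, 2, 0], [2, 2, 2])
```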
Labels are mapped to \(\lbrace 0,\cdots,n-1 \rbrace\), where \(n\) is the number of labels (some datasets have raw labels \(\lbrace -1, 1 \rbrace\), which are mapped to \(\lbrace 0, 1 \rbrace\)). In previous versions, the minimum label was subtracted from every label, so \(\lbrace -1, 1 \rbrace\) was mapped to \(\lbrace 0, 2 \rbrace\).
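The label mapping described above can be sketched in plain Python. This illustrates the documented behavior and is not DGL's own code; it assumes the mapping preserves the sorted order of the raw labels (the smallest raw label becomes 0), and contrasts it with the older min-shift behavior that could leave gaps.

```python
def remap_labels(raw_labels):
    """Map raw labels onto 0..n-1, where n is the number of distinct labels.

    Assumption: sorted order is preserved, so the smallest raw label maps to 0.
    """
    classes = sorted(set(raw_labels))
    index = {c: i for i, c in enumerate(classes)}
    return [index[c] for c in raw_labels]


def legacy_shift(raw_labels):
    """Older behavior: subtract the minimum label, which can leave gaps."""
    lowest = min(raw_labels)
    return [c - lowest for c in raw_labels]


raw = [-1, 1, 1, -1]
print(remap_labels(raw))  # [0, 1, 1, 0]
print(legacy_shift(raw))  # [0, 2, 2, 0]
```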
The dataset sorts graphs by their labels, so shuffle the samples before making a manual train/validation split.
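Because the graphs are sorted by label, taking the first k graphs as the training set would skew the class balance. Below is a minimal sketch of a shuffled index split using only the standard library; the function name, the 80/20 ratio, and the fixed seed are illustrative assumptions. DGL also provides dgl.data.utils.split_dataset, whose shuffle=True option serves the same purpose on a dataset object.

```python
import random


def shuffled_split(num_samples, train_frac=0.8, seed=0):
    """Return (train_indices, val_indices) after shuffling 0..num_samples-1.

    Hypothetical helper; a fixed seed keeps the split reproducible.
    """
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    num_train = int(num_samples * train_frac)
    return indices[:num_train], indices[num_train:]


# Labels sorted the way the dataset stores them: all 0s, then all 1s.
labels = [0] * 60 + [1] * 40

# Without shuffling, the last 20% contains a single class.
print(set(labels[80:]))  # {1}

train_idx, val_idx = shuffled_split(len(labels))
print(len(train_idx), len(val_idx))  # 80 20
```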
Examples
>>> import dgl
>>> import torch
>>> from dgl.data import TUDataset
>>> data = TUDataset('DD')
The dataset instance is iterable and supports indexing:
>>> len(data)
1178
>>> g, label = data[1024]
>>> g
Graph(num_nodes=88, num_edges=410,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'node_labels': Scheme(shape=(1,), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
>>> label
tensor([1])
Batch the graphs and labels for mini-batch training:
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.cat(labels)
>>> batched_graphs
Graph(num_nodes=9539, num_edges=47382,
      ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
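dgl.batch treats a list of graphs as one large graph made of disjoint components, so each graph's node IDs are shifted by the total number of nodes in the graphs before it; the per-graph node counts are kept (see DGLGraph.batch_num_nodes()) so that dgl.unbatch can split the result again. The sketch below reproduces that relabeling on plain edge lists to show why the batched node and edge counts are sums of the per-graph counts. The (num_nodes, src, dst) tuple format and the helper name are assumptions for illustration, not DGL's representation.

```python
def batch_edge_lists(graphs):
    """Merge graphs into one disjoint-union graph.

    Each graph is a hypothetical (num_nodes, src, dst) tuple. Node IDs of
    graph i are offset by the number of nodes in graphs 0..i-1.
    Returns (total_nodes, src, dst, nodes_per_graph).
    """
    offset = 0
    all_src, all_dst, nodes_per_graph = [], [], []
    for num_nodes, src, dst in graphs:
        # Shift this graph's node IDs past all earlier graphs' nodes.
        all_src.extend(u + offset for u in src)
        all_dst.extend(v + offset for v in dst)
        nodes_per_graph.append(num_nodes)
        offset += num_nodes
    return offset, all_src, all_dst, nodes_per_graph


g1 = (3, [0, 1], [1, 2])  # path 0 -> 1 -> 2
g2 = (2, [0], [1])        # single edge 0 -> 1
print(batch_edge_lists([g1, g2]))
# (5, [0, 1, 3], [1, 2, 4], [3, 2])
```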
__getitem__(idx)
Get the idx-th sample.
- Parameters
idx (int) – The sample index.
- Returns
The graph, with node features stored in the feat field and node labels in the node_labels field if available, together with its label.
- Return type
(dgl.DGLGraph, Tensor)
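A rough sketch of the access pattern described for __getitem__ and the transform parameter: indexing returns a (graph, label) pair, and the optional transform is applied to the stored graph on every access rather than once at load time. The class, its (num_nodes, src, dst) graph format, and the add_self_loops transform are hypothetical stand-ins for DGLGraph and real DGL transforms.

```python
class ToyGraphDataset:
    """Hypothetical dataset mimicking the (graph, label) indexing protocol."""

    def __init__(self, graphs, labels, transform=None):
        if len(graphs) != len(labels):
            raise ValueError("graphs and labels must have the same length")
        self.graphs = graphs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        g = self.graphs[idx]
        # The transform runs on each access; the stored graph is untouched.
        if self.transform is not None:
            g = self.transform(g)
        return g, self.labels[idx]


def add_self_loops(graph):
    """Hypothetical transform: return a new (num_nodes, src, dst) with self-loops."""
    num_nodes, src, dst = graph
    loops = list(range(num_nodes))
    # List concatenation builds new lists, so the input graph is not modified.
    return num_nodes, src + loops, dst + loops


data = ToyGraphDataset([(2, [0], [1]), (1, [], [])], [1, 0], transform=add_self_loops)
g, label = data[0]
print(g, label)        # (2, [0, 0, 1], [1, 0, 1]) 1
print(data.graphs[0])  # (2, [0], [1])
```

Because the class follows Python's sequence protocol, the zip(*[data[i] for i in range(16)]) idiom from the batching example works the same way on it.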