Source code for dgl.data.gindt

"""Datasets used in How Powerful Are Graph Neural Networks?
(chen jun)
Datasets include:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
"""

import os
import numpy as np

from .. import backend as F

from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, save_graphs, load_graphs, save_info, load_info, download, extract_archive
from ..utils import retry_method_with_fix
from ..convert import graph as dgl_graph


class GINDataset(DGLBuiltinDataset):
    """Dataset Class for `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_.

    This is adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.

    The class provides an interface for nine datasets used in the paper along
    with the paper-specific settings. The datasets are ``'MUTAG'``,
    ``'COLLAB'``, ``'IMDBBINARY'``, ``'IMDBMULTI'``, ``'NCI1'``,
    ``'PROTEINS'``, ``'PTC'``, ``'REDDITBINARY'``, ``'REDDITMULTI5K'``.

    If ``degree_as_nlabel`` is set to ``False``, then ``ndata['label']``
    stores the provided node label, otherwise ``ndata['label']`` stores the
    node in-degrees.

    For graphs that have node attributes, ``ndata['attr']`` stores the node
    attributes. For graphs that have no attribute, ``ndata['attr']`` stores
    the corresponding one-hot encoding of ``ndata['label']``.

    Parameters
    ---------
    name: str
        dataset name, one of (``'MUTAG'``, ``'COLLAB'``, ``'IMDBBINARY'``,
        ``'IMDBMULTI'``, ``'NCI1'``, ``'PROTEINS'``, ``'PTC'``,
        ``'REDDITBINARY'``, ``'REDDITMULTI5K'``)
    self_loop: bool
        add self to self edge if true
    degree_as_nlabel: bool
        take node degree as label and feature if true
    raw_dir: str, optional
        Directory the archive is downloaded to and extracted under.
        Forwarded to the base class.
    force_reload: bool
        Forwarded to the base class.
    verbose: bool
        Print progress and dataset statistics while processing if true.
    transform: callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    graphs : list
        The loaded graphs.
    labels : Tensor
        The (relabeled) graph labels, one per graph.
    N, n, m : int
        Total number of graphs, nodes and edges.
    gclasses, nclasses, eclasses : int
        Sizes of the graph-, node- and edge-label dictionaries.
    dim_nfeats : int
        Length of ``ndata['attr']`` of the first node of the first graph.

    Examples
    --------
    >>> data = GINDataset(name='MUTAG', self_loop=False)

    The dataset instance is an iterable

    >>> len(data)
    188
    >>> g, label = data[128]
    >>> g
    Graph(num_nodes=13, num_edges=26,
          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float32)}
          edata_schemes={})
    >>> label
    tensor(1)

    Batch the graphs and labels for mini-batch training

    >>> graphs, labels = zip(*[data[i] for i in range(16)])
    >>> batched_graphs = dgl.batch(graphs)
    >>> batched_labels = torch.tensor(labels)
    >>> batched_graphs
    Graph(num_nodes=330, num_edges=748,
          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float32)}
          edata_schemes={})
    """

    def __init__(self, name, self_loop, degree_as_nlabel=False,
                 raw_dir=None, force_reload=False, verbose=False, transform=None):
        # All state is initialised *before* the base constructor runs.
        # NOTE(review): presumably the base constructor triggers
        # download/process/load, which read these fields -- confirm against
        # DGLBuiltinDataset.
        self._name = name  # e.g. 'MUTAG'
        gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
        self.ds_name = 'nig'
        self.self_loop = self_loop

        self.graphs = []
        self.labels = []

        # relabel: raw label -> contiguous index, in order of first appearance
        self.glabel_dict = {}
        self.nlabel_dict = {}
        # elabel_dict and ndegree_dict are never filled in by this class;
        # they are only reported and cached.
        self.elabel_dict = {}
        self.ndegree_dict = {}

        # global num
        self.N = 0  # total graphs number
        self.n = 0  # total nodes number
        self.m = 0  # total edges number

        # global num of classes
        self.gclasses = 0
        self.nclasses = 0
        self.eclasses = 0
        self.dim_nfeats = 0

        # flags
        self.degree_as_nlabel = degree_as_nlabel
        self.nattrs_flag = False
        self.nlabels_flag = False

        super(GINDataset, self).__init__(
            name=name, url=gin_url,
            hash_key=(name, self_loop, degree_as_nlabel),
            raw_dir=raw_dir, force_reload=force_reload,
            verbose=verbose, transform=transform)

    @property
    def raw_path(self):
        """Directory the downloaded archive is extracted into."""
        return os.path.join(self.raw_dir, 'GINDataset')

    def download(self):
        r"""Automatically download data and extract it."""
        zip_file_path = os.path.join(self.raw_dir, 'GINDataset.zip')
        download(self.url, path=zip_file_path)
        extract_archive(zip_file_path, self.raw_path)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Get the idx-th sample.

        Parameters
        ---------
        idx : int
            The sample index.

        Returns
        -------
        (:class:`dgl.Graph`, Tensor)
            The graph and its label. The graph is passed through the
            ``transform`` callable first when one was given.
        """
        if self._transform is None:
            g = self.graphs[idx]
        else:
            g = self._transform(self.graphs[idx])
        return g, self.labels[idx]

    def _file_path(self):
        """Return the path of the ``dataset/NAME/NAME.txt`` file to parse."""
        return os.path.join(self.raw_dir, "GINDataset", 'dataset',
                            self.name, "{}.txt".format(self.name))

    @staticmethod
    def _parse_node_row(tokens):
        """Split the tokens of one node row into integer fields and attributes.

        A row is ``label degree nbr_1 ... nbr_degree [attr_1 attr_2 ...]``.

        Parameters
        ---------
        tokens : list of str
            The whitespace-separated tokens of the row.

        Returns
        -------
        (list of int, list of float or None)
            ``[label, degree, nbr_1, ...]`` and the trailing attributes, or
            ``None`` when the row carries no attributes.

        Raises
        ------
        Exception
            If the row holds fewer neighbours than its degree field declares.
        """
        n_fields = int(tokens[1]) + 2  # label + degree + #edges
        if n_fields > len(tokens):
            # Truncated row: fewer neighbours than declared.
            raise Exception('edge number is incorrect!')
        # The attributes are sliced off the *original* tokens; everything
        # after the declared neighbours is an attribute value.
        attrs = None
        if n_fields < len(tokens):
            attrs = [float(w) for w in tokens[n_fields:]]
        fields = [int(w) for w in tokens[:n_fields]]
        return fields, attrs

    def process(self):
        """Loads input dataset from dataset/NAME/NAME.txt file.

        The file starts with the number of graphs. Each graph is a header
        row ``n_nodes graph_label`` followed by ``n_nodes`` node rows (see
        :meth:`_parse_node_row`). When no graph carries node attributes,
        ``ndata['attr']`` is generated as a one-hot encoding of
        ``ndata['label']`` (which is replaced by the in-degrees first when
        ``degree_as_nlabel`` is set).
        """
        if self.verbose:
            print('loading data...')
        self.file = self._file_path()
        with open(self.file, 'r') as f:
            # line_1 == N, total number of graphs
            self.N = int(f.readline().strip())

            for i in range(self.N):
                if (i + 1) % 10 == 0 and self.verbose is True:
                    print('processing graph {}...'.format(i + 1))

                # line_2 == [n_nodes, l] is equal to
                # [node number of a graph, class label of a graph]
                grow = f.readline().strip().split()
                n_nodes, glabel = [int(w) for w in grow]

                # relabel graphs
                if glabel not in self.glabel_dict:
                    self.glabel_dict[glabel] = len(self.glabel_dict)
                self.labels.append(self.glabel_dict[glabel])

                g = dgl_graph(([], []))
                g.add_nodes(n_nodes)

                nlabels = []  # node labels
                nattrs = []  # node attributes if it has
                m_edges = 0

                for j in range(n_nodes):
                    # handle edges and attributes (if has)
                    nrow, nattr = self._parse_node_row(
                        f.readline().strip().split())
                    if nattr is not None:
                        nattrs.append(nattr)

                    # relabel nodes if it has labels
                    # if it doesn't have node labels, then every nrow[0]==0
                    if nrow[0] not in self.nlabel_dict:
                        self.nlabel_dict[nrow[0]] = len(self.nlabel_dict)
                    nlabels.append(self.nlabel_dict[nrow[0]])

                    m_edges += nrow[1]
                    g.add_edges(j, nrow[2:])

                    # add self loop
                    if self.self_loop:
                        m_edges += 1
                        g.add_edges(j, j)

                    if (j + 1) % 10 == 0 and self.verbose is True:
                        print(
                            'processing node {} of graph {}...'.format(
                                j + 1, i + 1))
                        print('this node has {} edgs.'.format(
                            nrow[1]))

                if nattrs != []:
                    nattrs = np.stack(nattrs)
                    g.ndata['attr'] = F.tensor(nattrs, F.float32)
                    self.nattrs_flag = True

                g.ndata['label'] = F.tensor(nlabels)
                if len(self.nlabel_dict) > 1:
                    self.nlabels_flag = True

                assert g.number_of_nodes() == n_nodes

                # update statistics of graphs
                self.n += n_nodes
                self.m += m_edges

                self.graphs.append(g)

        self.labels = F.tensor(self.labels)

        # if no attr
        if not self.nattrs_flag:
            if self.verbose:
                print('there are no node features in this dataset!')
            # generate node attr by node degree
            if self.degree_as_nlabel:
                if self.verbose:
                    print('generate node features by node degree...')
                for g in self.graphs:
                    # actually this label shouldn't be updated
                    # in case users want to keep it
                    # but usually no features means no labels, fine.
                    g.ndata['label'] = g.in_degrees()

            # extracting unique node labels
            # in case the labels/degrees are not continuous number
            # NOTE: the set is built exactly this way on purpose -- its
            # iteration order decides the one-hot column of every label in
            # the non-contiguous case below.
            nlabel_set = set([])
            for g in self.graphs:
                nlabel_set = nlabel_set.union(
                    set([F.as_scalar(nl) for nl in g.ndata['label']]))
            nlabel_set = list(nlabel_set)

            is_label_valid = all([label in self.nlabel_dict for label in nlabel_set])
            if is_label_valid and len(nlabel_set) == np.max(nlabel_set) + 1 and np.min(nlabel_set) == 0:
                # Note this is different from the author's implementation. In weihua916's implementation,
                # the labels are relabeled anyway. But here we didn't relabel it if the labels are contiguous
                # to make it consistent with the original dataset
                label2idx = self.nlabel_dict
            else:
                label2idx = {nl: i for i, nl in enumerate(nlabel_set)}

            # generate node attr by node label
            for g in self.graphs:
                attr = np.zeros((g.number_of_nodes(), len(label2idx)))
                attr[range(g.number_of_nodes()),
                     [label2idx[nl] for nl in F.asnumpy(g.ndata['label']).tolist()]] = 1
                g.ndata['attr'] = F.tensor(attr, F.float32)

        # after load, get the #classes and #dim
        self.gclasses = len(self.glabel_dict)
        self.nclasses = len(self.nlabel_dict)
        self.eclasses = len(self.elabel_dict)
        self.dim_nfeats = len(self.graphs[0].ndata['attr'][0])

        if self.verbose:
            print('Done.')
            print(
                """
                -------- Data Statistics --------'
                #Graphs: %d
                #Graph Classes: %d
                #Nodes: %d
                #Node Classes: %d
                #Node Features Dim: %d
                #Edges: %d
                #Edge Classes: %d
                Avg. of #Nodes: %.2f
                Avg. of #Edges: %.2f
                Graph Relabeled: %s
                Node Relabeled: %s
                Degree Relabeled(If degree_as_nlabel=True): %s \n """ % (
                    self.N, self.gclasses, self.n, self.nclasses,
                    self.dim_nfeats, self.m, self.eclasses,
                    self.n / self.N, self.m / self.N, self.glabel_dict,
                    self.nlabel_dict, self.ndegree_dict))

    def _gin_cache_paths(self):
        """Return the ``(graph_path, info_path)`` pair of the cache files.

        Shared by :meth:`save`, :meth:`load` and :meth:`has_cache` so the
        three always agree on the file names.
        """
        stem = 'gin_{}_{}'.format(self.name, self.hash)
        graph_path = os.path.join(self.save_path, stem + '.bin')
        info_path = os.path.join(self.save_path, stem + '.pkl')
        return graph_path, info_path

    def save(self):
        """Write the graphs, their labels and the dataset statistics to the cache files."""
        graph_path, info_path = self._gin_cache_paths()
        label_dict = {'labels': self.labels}
        info_dict = {'N': self.N,
                     'n': self.n,
                     'm': self.m,
                     'self_loop': self.self_loop,
                     'gclasses': self.gclasses,
                     'nclasses': self.nclasses,
                     'eclasses': self.eclasses,
                     'dim_nfeats': self.dim_nfeats,
                     'degree_as_nlabel': self.degree_as_nlabel,
                     'glabel_dict': self.glabel_dict,
                     'nlabel_dict': self.nlabel_dict,
                     'elabel_dict': self.elabel_dict,
                     'ndegree_dict': self.ndegree_dict}
        save_graphs(str(graph_path), self.graphs, label_dict)
        save_info(str(info_path), info_dict)

    def load(self):
        """Restore the graphs, their labels and the dataset statistics from the cache files."""
        graph_path, info_path = self._gin_cache_paths()
        graphs, label_dict = load_graphs(str(graph_path))
        info_dict = load_info(str(info_path))

        self.graphs = graphs
        self.labels = label_dict['labels']

        self.N = info_dict['N']
        self.n = info_dict['n']
        self.m = info_dict['m']
        self.self_loop = info_dict['self_loop']
        self.gclasses = info_dict['gclasses']
        self.nclasses = info_dict['nclasses']
        self.eclasses = info_dict['eclasses']
        self.dim_nfeats = info_dict['dim_nfeats']
        self.glabel_dict = info_dict['glabel_dict']
        self.nlabel_dict = info_dict['nlabel_dict']
        self.elabel_dict = info_dict['elabel_dict']
        self.ndegree_dict = info_dict['ndegree_dict']
        self.degree_as_nlabel = info_dict['degree_as_nlabel']

    def has_cache(self):
        """Return True when both cache files exist."""
        graph_path, info_path = self._gin_cache_paths()
        return os.path.exists(graph_path) and os.path.exists(info_path)