PubmedGraphDataset

class dgl.data.PubmedGraphDataset(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None)

Bases: dgl.data.citation_graph.CitationGraphDataset

Pubmed citation network dataset.

    Deprecated since version 0.5.0:
  • graph is deprecated; it is replaced by:

    >>> dataset = PubmedGraphDataset()
    >>> graph = dataset[0]
    
  • train_mask is deprecated; it is replaced by:

    >>> dataset = PubmedGraphDataset()
    >>> graph = dataset[0]
    >>> train_mask = graph.ndata['train_mask']
    
  • val_mask is deprecated; it is replaced by:

    >>> dataset = PubmedGraphDataset()
    >>> graph = dataset[0]
    >>> val_mask = graph.ndata['val_mask']
    
  • test_mask is deprecated; it is replaced by:

    >>> dataset = PubmedGraphDataset()
    >>> graph = dataset[0]
    >>> test_mask = graph.ndata['test_mask']
    
  • labels is deprecated; it is replaced by:

    >>> dataset = PubmedGraphDataset()
    >>> graph = dataset[0]
    >>> labels = graph.ndata['label']
    
  • feat is deprecated; it is replaced by:

    >>> dataset = PubmedGraphDataset()
    >>> graph = dataset[0]
    >>> feat = graph.ndata['feat']
    

Nodes represent scientific publications and edges represent citation relationships. Each node has a predefined 500-dimensional feature vector. The dataset is designed for node classification: the task is to predict the category of a given publication.

Statistics:

  • Nodes: 19717

  • Edges: 88651

  • Number of Classes: 3

  • Label Split:

    • Train: 60

    • Valid: 500

    • Test: 1000
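
The label split is exposed as boolean node masks (train_mask, val_mask and test_mask in g.ndata). The pure-Python sketch below illustrates how such masks relate to the split sizes listed above. The index ranges assigned to each split are hypothetical placeholders, not the actual Pubmed assignment, and plain lists stand in for tensors.

```python
# Sketch: boolean split masks over all nodes. The index ranges used here
# are hypothetical; the real dataset defines its own split.
NUM_NODES = 19717

def make_mask(num_nodes, indices):
    """Return a list of booleans that is True exactly at the given indices."""
    mask = [False] * num_nodes
    for i in indices:
        mask[i] = True
    return mask

# Hypothetical, non-overlapping index ranges matching the split sizes.
train_mask = make_mask(NUM_NODES, range(0, 60))
val_mask = make_mask(NUM_NODES, range(60, 560))
test_mask = make_mask(NUM_NODES, range(NUM_NODES - 1000, NUM_NODES))

# Summing a boolean mask counts the nodes in that split.
print(sum(train_mask), sum(val_mask), sum(test_mask))  # 60 500 1000
```

Since 60 + 500 + 1000 nodes is far fewer than 19717, most nodes belong to none of the three splits; they still carry features and labels but are not selected by any mask.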

Parameters
  • raw_dir (str) – Directory to download the raw input data into, or that already contains it. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: True.

  • reverse_edge (bool) – Whether to add reverse edges to the graph. Default: True.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
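
Citations are naturally directed (paper A cites paper B), and reverse_edge controls whether each edge also gets a counterpart pointing the other way. The sketch below illustrates the idea on a plain (src, dst) edge list. It is a conceptual illustration, not DGL's internal implementation; add_reverse_edges is a hypothetical helper, and skipping reversed edges that already exist is a choice made for this sketch.

```python
# Sketch of what "adding reverse edges" means for a directed citation list.
# add_reverse_edges is a hypothetical helper, not part of dgl.data.

def add_reverse_edges(src, dst):
    """Return (src, dst) extended with the reversed copy of every edge,
    skipping reversed edges that already exist."""
    existing = set(zip(src, dst))
    new_src, new_dst = list(src), list(dst)
    for u, v in zip(src, dst):
        if (v, u) not in existing:
            new_src.append(v)
            new_dst.append(u)
            existing.add((v, u))
    return new_src, new_dst

# Tiny hypothetical citation graph: paper 0 cites 1, 1 cites 2, 2 cites 1.
src, dst = [0, 1, 2], [1, 2, 1]
print(add_reverse_edges(src, dst))  # ([0, 1, 2, 1], [1, 2, 1, 0])
```

After the call, every edge (u, v) has a matching (v, u), so message passing can flow in both directions along each citation.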

num_classes

Number of label classes

Type

int

graph

Graph structure

Type

networkx.DiGraph

train_mask

Mask of training nodes

Type

numpy.ndarray

val_mask

Mask of validation nodes

Type

numpy.ndarray

test_mask

Mask of test nodes

Type

numpy.ndarray

labels

Ground truth labels of each node

Type

numpy.ndarray

features

Node features

Type

Tensor

Notes

The node feature is row-normalized.
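
Row normalization scales each node's feature vector so that its entries sum to 1. The pure-Python sketch below shows the operation on a few hypothetical rows; leaving all-zero rows unchanged (to avoid division by zero) is an assumption of this sketch rather than something stated above.

```python
# Sketch of row normalization: each node's feature vector is divided by the
# sum of its entries so the row sums to 1. All-zero rows are left unchanged
# to avoid division by zero (an assumption about edge-case handling).

def row_normalize(features):
    """Return a new list of rows, each scaled to sum to 1 (zero rows kept)."""
    normalized = []
    for row in features:
        total = sum(row)
        if total == 0:
            normalized.append(list(row))
        else:
            normalized.append([x / total for x in row])
    return normalized

# Hypothetical 3-dimensional features for three nodes (Pubmed uses 500).
feats = [[1.0, 1.0, 2.0],
         [0.0, 3.0, 0.0],
         [0.0, 0.0, 0.0]]
print(row_normalize(feats))
# [[0.25, 0.25, 0.5], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
```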

Examples

>>> dataset = PubmedGraphDataset()
>>> g = dataset[0]
>>> num_class = dataset.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
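
A mask is typically used to select the nodes, and their labels, that belong to one split. With DGL's PyTorch backend and boolean masks this corresponds to indexing such as label[train_mask]. The sketch below mimics that selection with plain Python lists; masked_select and mask_to_indices are hypothetical helpers, and the labels and mask are made-up values.

```python
# Sketch: using a boolean mask to pick out the nodes (and labels) of one split.
# Plain Python lists stand in for the tensors stored in g.ndata.

def masked_select(values, mask):
    """Return the elements of values where mask is True."""
    return [v for v, keep in zip(values, mask) if keep]

def mask_to_indices(mask):
    """Return the node IDs where mask is True."""
    return [i for i, keep in enumerate(mask) if keep]

# Hypothetical labels (Pubmed has 3 classes: 0, 1, 2) and a train mask.
label = [2, 0, 1, 1, 0, 2]
train_mask = [True, False, True, False, False, True]

print(mask_to_indices(train_mask))       # [0, 2, 5]
print(masked_select(label, train_mask))  # [2, 1, 2]
```
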

__getitem__(idx)

Gets the graph object

Parameters

idx (int) – Item index. PubmedGraphDataset contains only one graph object, so the only valid index is 0.

Returns

The graph structure, with node features and labels stored in ndata:

  • ndata['train_mask']: mask for training node set

  • ndata['val_mask']: mask for validation node set

  • ndata['test_mask']: mask for test node set

  • ndata['feat']: node feature

  • ndata['label']: ground truth labels

Return type

dgl.DGLGraph

__len__()

The number of graphs in the dataset.
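
Because the dataset holds a single graph, __getitem__ accepts only index 0 and __len__ returns 1. The sketch below illustrates that contract with a hypothetical SingleGraphDataset class; it is not DGL's implementation, and a plain dict stands in for the DGLGraph.

```python
# Hypothetical single-graph dataset illustrating the __getitem__/__len__
# contract described above. Not DGL's implementation; a dict stands in
# for the DGLGraph.

class SingleGraphDataset:
    def __init__(self, graph):
        self._graph = graph

    def __getitem__(self, idx):
        # Only index 0 is valid because the dataset holds exactly one graph.
        if idx != 0:
            raise IndexError("This dataset has only one graph")
        return self._graph

    def __len__(self):
        # The number of graphs in the dataset.
        return 1

dataset = SingleGraphDataset({"ndata": {"label": [0, 1, 2]}})
g = dataset[0]
print(len(dataset), g["ndata"]["label"])  # 1 [0, 1, 2]
```

Raising IndexError for any other index is a design choice of this sketch: it also lets a plain for loop over the dataset stop after the single graph.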