BGSDataset

class dgl.data.BGSDataset(print_every=10000, insert_reverse=True, raw_dir=None, force_reload=False, verbose=True, transform=None)

Bases: dgl.data.rdf.RDFGraphDataset

BGS dataset for the node classification task.

BGS namespace convention: http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE. The output graph omits all literal nodes and the relations connecting them. It also omits the relation that marks whether a term is CURRENT or DEPRECATED.
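As an illustration of the namespace convention, the sketch below splits a URI of that shape into its major concept, sub concept, and instance. The function name `parse_bgs_uri`, the regular expression, and the instance name `EXAMPLE` are assumptions made for this example; they are not taken from DGL's preprocessing code.

```python
import re

# Pattern mirroring the convention quoted above; not taken from DGL's source.
BGS_URI = re.compile(
    r"^http://data\.bgs\.ac\.uk/(?P<kind>ref|id)/"
    r"(?P<major>[^/]+)/(?P<sub>[^/]+)/(?P<instance>.+)$"
)


def parse_bgs_uri(uri):
    """Return (major, sub, instance) for a matching URI, or None otherwise."""
    match = BGS_URI.match(uri)
    if match is None:
        return None
    return match.group("major"), match.group("sub"), match.group("instance")


# The instance name "EXAMPLE" is made up for illustration.
parsed = parse_bgs_uri("http://data.bgs.ac.uk/id/Lexicon/NamedRockUnit/EXAMPLE")
print(parsed)
```

Under this reading, the target category Lexicon/NamedRockUnit corresponds to the major and sub concept segments of the URI.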

BGS dataset statistics:

  • Nodes: 94806

  • Edges: 672884 (including reverse edges)

  • Target Category: Lexicon/NamedRockUnit

  • Number of Classes: 2

  • Label Split:

    • Train: 117

    • Test: 29
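The split sizes can be put in proportion with a few lines of arithmetic. This sketch only reuses the figures listed above and assumes that the train and test sets together cover every labeled node.

```python
# Figures copied from the statistics above.
num_nodes = 94806
num_train = 117
num_test = 29

# Assumption: train and test together make up all labeled nodes.
num_labeled = num_train + num_test
train_fraction = num_train / num_labeled
labeled_fraction = num_labeled / num_nodes

print(f"labeled nodes: {num_labeled}")
print(f"train share of labeled nodes: {train_fraction:.1%}")
print(f"labeled share of all nodes: {labeled_fraction:.2%}")
```

Only a small fraction of the nodes carry labels, so the masks matter when computing loss and accuracy.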

Parameters
  • print_every (int) – Print a preprocessing log message every print_every tuples. Default: 10000.

  • insert_reverse (bool) – If True, add reverse edges and reverse relations to the final graph. Default: True.

  • raw_dir (str) – Directory that the input data is downloaded to, or that already contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False.

  • verbose (bool) – Whether to print out progress information. Default: True.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
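To make the effect of insert_reverse concrete, here is a minimal pure-Python sketch that adds a reversed copy of every triple under a separate relation name. The helper `add_reverse_triples`, the `rev-` prefix, and the toy triples are assumptions for illustration, not DGL's actual implementation or naming.

```python
def add_reverse_triples(triples, prefix="rev-"):
    """Return the triples plus one reversed copy of each under a prefixed relation."""
    reversed_triples = [(dst, prefix + rel, src) for src, rel, dst in triples]
    return list(triples) + reversed_triples


# Toy triples with made-up entity and relation names.
forward = [("unitA", "partOf", "unitB"), ("unitB", "hasAge", "ageC")]
both = add_reverse_triples(forward)
print(len(forward), len(both))
```

Each forward edge gains exactly one reverse edge in this sketch, so the edge count doubles.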

num_classes

Number of classes to predict

Type

int

predict_category

The entity category (node type) whose nodes carry the labels to predict

Type

str

Examples

>>> import dgl
>>> dataset = dgl.data.BGSDataset()
>>> g = dataset[0]
>>> category = dataset.predict_category
>>> num_classes = dataset.num_classes
>>>
>>> # Masks and labels are stored on the nodes of the target category.
>>> train_mask = g.nodes[category].data['train_mask']
>>> test_mask = g.nodes[category].data['test_mask']
>>> label = g.nodes[category].data['label']

__getitem__(idx)

Gets the graph object

Parameters

idx (int) – Item index. BGSDataset has only one graph object, so the only valid index is 0.

Returns

The graph contains:

  • ndata['train_mask']: mask for training node set

  • ndata['test_mask']: mask for testing node set

  • ndata['label']: node labels

Return type

dgl.DGLGraph
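The masks listed above are boolean vectors with one entry per node, and they are typically used to pick out the training and testing nodes. The sketch below shows the idea with plain Python lists standing in for the tensors stored on the graph. All values, and the helper `select_indices`, are made up for illustration.

```python
# Plain lists stand in for the tensors stored on the graph; values are made up.
train_mask = [True, False, True, False, False]
test_mask = [False, True, False, False, False]
label = [1, 0, 0, -1, -1]  # -1 marks unlabeled nodes in this toy example only


def select_indices(mask):
    """Return the positions where the mask is set."""
    return [i for i, flag in enumerate(mask) if flag]


train_idx = select_indices(train_mask)
test_idx = select_indices(test_mask)
train_labels = [label[i] for i in train_idx]
print(train_idx, test_idx, train_labels)
```

With real tensors, the same selection would usually be done by boolean indexing rather than a list comprehension.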

__len__()

The number of graphs in the dataset.

Returns

The number of graphs, which is 1 because BGSDataset has only one graph object

Return type

int
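The way `__getitem__`, `__len__`, and the transform parameter fit together can be sketched with a toy single-graph dataset. The class `SingleGraphDataset`, the dictionary used as a stand-in graph, and the choice to raise `IndexError` for other indices are assumptions of this example, not DGL's implementation.

```python
class SingleGraphDataset:
    """Toy dataset holding one graph-like object; not DGL's implementation."""

    def __init__(self, graph, transform=None):
        self._graph = graph
        self._transform = transform

    def __getitem__(self, idx):
        # Only index 0 is valid because the dataset holds a single graph.
        if idx != 0:
            raise IndexError("this dataset has only one graph")
        if self._transform is None:
            return self._graph
        # The transform runs on every access, as described for `transform`.
        return self._transform(self._graph)

    def __len__(self):
        return 1


calls = []


def record_and_return(graph):
    """Hypothetical transform that records each call and returns the graph unchanged."""
    calls.append(graph)
    return graph


dataset = SingleGraphDataset({"num_nodes": 3}, transform=record_and_return)
first = dataset[0]
second = dataset[0]
print(len(dataset), len(calls))
```

Because the transform is applied on each access in this sketch, expensive transforms are better applied once and cached by the caller.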