"""RDF datasets
Datasets from "A Collection of Benchmark Datasets for
Systematic Evaluations of Machine Learning on
the Semantic Web"
"""
import os
from collections import OrderedDict
import itertools
import abc
import re
import networkx as nx
import numpy as np
import dgl
import dgl.backend as F
from .dgl_dataset import DGLBuiltinDataset
from .utils import save_graphs, load_graphs, save_info, load_info, _get_dgl_url
from .utils import generate_mask_tensor, idx2mask
__all__ = ['AIFBDataset', 'MUTAGDataset', 'BGSDataset', 'AMDataset']
# Some relation names (e.g. 'type') are reserved and cannot be used as keys
# of nn.Module containers; map them (and their reverse-edge counterparts)
# to safe replacement names.
RENAME_DICT = {
    'type': 'rdftype',
    'rev-type': 'rev-rdftype',
}
class Entity:
    """Class for entities

    Parameters
    ----------
    e_id : str
        ID of this entity
    cls : str
        Type of this entity

    Attributes
    ----------
    id : str
        ID of this entity (the value passed as ``e_id``)
    cls : str
        Type of this entity
    """

    def __init__(self, e_id, cls):
        self.id = e_id
        self.cls = cls

    def __str__(self):
        # "<type>/<id>" is the string form of an entity.
        return '{}/{}'.format(self.cls, self.id)

    def __repr__(self):
        return 'Entity(e_id=%r, cls=%r)' % (self.id, self.cls)
class Relation:
    """Class for relations

    Parameters
    ----------
    cls : str
        Type of this relation
    """

    def __init__(self, cls):
        self.cls = cls

    def __str__(self):
        return str(self.cls)

    def __repr__(self):
        return 'Relation(cls=%r)' % (self.cls,)
class RDFGraphDataset(DGLBuiltinDataset):
    """Base graph dataset class from RDF tuples.

    To derive from this, implement the following abstract methods:

    * ``parse_entity``
    * ``parse_relation``
    * ``process_tuple``
    * ``process_idx_file_line``

    The category to predict is passed to the constructor as
    ``predict_category``.

    Preprocessed graph and other data will be cached in the download folder
    to speedup data loading.

    The dataset should contain a "trainingSet.tsv" and a "testSet.tsv" file
    for training and testing samples.

    Attributes
    ----------
    num_classes : int
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction

    Parameters
    ----------
    name : str
        Name of the dataset
    url : str or path
        URL to download the raw dataset.
    predict_category : str
        Predict category.
    print_every : int, optional
        Preprocessing log for every X tuples.
    insert_reverse : bool, optional
        If true, add reverse edge and reverse relations to the final graph.
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool, optional
        If true, force load and process from raw data. Ignore cached pre-processed data.
    verbose : bool
        Whether to print out progress information. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.
    """

    def __init__(self, name, url, predict_category,
                 print_every=10000,
                 insert_reverse=True,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True,
                 transform=None):
        # These are set before calling the base constructor.
        # NOTE(review): presumably the base constructor triggers
        # process()/load(), which read them -- confirm in DGLBuiltinDataset.
        self._insert_reverse = insert_reverse
        self._print_every = print_every
        self._predict_category = predict_category
        super(RDFGraphDataset, self).__init__(name, url,
                                              raw_dir=raw_dir,
                                              force_reload=force_reload,
                                              verbose=verbose,
                                              transform=transform)

    def process(self):
        """Load the raw RDF tuples under ``self.raw_path`` and build the graph."""
        raw_tuples = self.load_raw_tuples(self.raw_path)
        self.process_raw_tuples(raw_tuples, self.raw_path)

    def load_raw_tuples(self, root_path):
        """Loading raw RDF dataset

        Every file in ``root_path`` whose name ends with ``.nt`` or ``.n3``
        is parsed with rdflib; all other files are skipped.

        Parameters
        ----------
        root_path : str
            Root path containing the data

        Returns
        -------
        Loaded rdf data
            An iterator chaining the tuples of all parsed RDF graphs.
        """
        import rdflib as rdf
        raw_rdf_graphs = []
        for filename in os.listdir(root_path):
            # Match on the real extension (with the dot) so that unrelated
            # files whose names merely end in "nt"/"n3" are not parsed.
            if filename.endswith('.nt'):
                fmt = 'nt'
            elif filename.endswith('.n3'):
                fmt = 'n3'
            else:
                continue
            g = rdf.Graph()
            if self.verbose:
                print('Parsing file %s ...' % filename)
            g.parse(os.path.join(root_path, filename), format=fmt)
            raw_rdf_graphs.append(g)
        return itertools.chain(*raw_rdf_graphs)

    def process_raw_tuples(self, raw_tuples, root_path):
        """Processing raw RDF dataset

        Builds the heterograph from the tuples, loads the train/test split
        and stores the graph in ``self._hg`` and the number of classes in
        ``self._num_classes``.

        Parameters
        ----------
        raw_tuples:
            Raw rdf tuples
        root_path: str
            Root path containing the data
        """
        mg = nx.MultiDiGraph()
        ent_classes = OrderedDict()
        rel_classes = OrderedDict()
        entities = OrderedDict()
        src = []
        dst = []
        ntid = []
        etid = []
        # Ids below are handed out in first-seen order, so process the
        # tuples in sorted order rather than in parser iteration order.
        sorted_tuples = sorted(raw_tuples)
        for i, (sbj, pred, obj) in enumerate(sorted_tuples):
            if self.verbose and i % self._print_every == 0:
                print('Processed %d tuples, found %d valid tuples.' % (i, len(src)))
            sbjent = self.parse_entity(sbj)
            rel = self.parse_relation(pred)
            objent = self.parse_entity(obj)
            # Only the None-ness of the result is used; the (possibly
            # modified in place) sbjent/rel/objent objects are used below.
            processed = self.process_tuple((sbj, pred, obj), sbjent, rel, objent)
            if processed is None:
                # ignored
                continue

            # meta graph
            sbjclsid = _get_id(ent_classes, sbjent.cls)
            objclsid = _get_id(ent_classes, objent.cls)
            relclsid = _get_id(rel_classes, rel.cls)
            mg.add_edge(sbjent.cls, objent.cls, key=rel.cls)
            if self._insert_reverse:
                mg.add_edge(objent.cls, sbjent.cls, key='rev-%s' % rel.cls)

            # instance graph
            # ``ntid`` holds one class id per known entity, so ``entities``
            # being longer than ``ntid`` means _get_id just added a new one.
            src_id = _get_id(entities, str(sbjent))
            if len(entities) > len(ntid):  # found new entity
                ntid.append(sbjclsid)
            dst_id = _get_id(entities, str(objent))
            if len(entities) > len(ntid):  # found new entity
                ntid.append(objclsid)
            src.append(src_id)
            dst.append(dst_id)
            etid.append(relclsid)

        src = np.asarray(src)
        dst = np.asarray(dst)
        ntid = np.asarray(ntid)
        etid = np.asarray(etid)
        ntypes = list(ent_classes.keys())
        etypes = list(rel_classes.keys())

        # add reverse edge with reverse relation
        if self._insert_reverse:
            if self.verbose:
                print('Adding reverse edges ...')
            newsrc = np.hstack([src, dst])
            newdst = np.hstack([dst, src])
            src = newsrc
            dst = newdst
            # Reverse relation ids are the forward ids shifted by the number
            # of forward relation types.
            etid = np.hstack([etid, etid + len(etypes)])
            etypes.extend(['rev-%s' % t for t in etypes])

        hg = self.build_graph(mg, src, dst, ntid, etid, ntypes, etypes)

        if self.verbose:
            print('Load training/validation/testing split ...')
        # Map ids in the whole graph to ids local to the predict category.
        idmap = F.asnumpy(hg.nodes[self.predict_category].data[dgl.NID])
        glb2lcl = {glbid: lclid for lclid, glbid in enumerate(idmap)}

        def findidfn(ent):
            if ent not in entities:
                return None
            else:
                return glb2lcl[entities[ent]]

        # load_data_split reads self._hg, so it must be set first.
        self._hg = hg
        train_idx, test_idx, labels, num_classes = self.load_data_split(findidfn, root_path)

        train_mask = idx2mask(train_idx, self._hg.number_of_nodes(self.predict_category))
        test_mask = idx2mask(test_idx, self._hg.number_of_nodes(self.predict_category))
        labels = F.tensor(labels, F.data_type_dict['int64'])

        train_mask = generate_mask_tensor(train_mask)
        test_mask = generate_mask_tensor(test_mask)
        self._hg.nodes[self.predict_category].data['train_mask'] = train_mask
        self._hg.nodes[self.predict_category].data['test_mask'] = test_mask
        # TODO(minjie): Deprecate 'labels', use 'label' for consistency.
        self._hg.nodes[self.predict_category].data['labels'] = labels
        self._hg.nodes[self.predict_category].data['label'] = labels
        self._num_classes = num_classes

    def build_graph(self, mg, src, dst, ntid, etid, ntypes, etypes):
        """Build the graphs

        Parameters
        ----------
        mg: MultiDiGraph
            Input graph
        src: Numpy array
            Source nodes
        dst: Numpy array
            Destination nodes
        ntid: Numpy array
            Node types for each node
        etid: Numpy array
            Edge types for each edge
        ntypes: list
            Node types
        etypes: list
            Edge types

        Returns
        -------
        g: DGLGraph
        """
        # create homo graph
        if self.verbose:
            print('Creating one whole graph ...')
        g = dgl.graph((src, dst))
        g.ndata[dgl.NTYPE] = F.tensor(ntid)
        g.edata[dgl.ETYPE] = F.tensor(etid)
        if self.verbose:
            print('Total #nodes:', g.number_of_nodes())
            print('Total #edges:', g.number_of_edges())

        # rename names such as 'type' so that they can be used as keys
        # to nn.ModuleDict
        etypes = [RENAME_DICT.get(ty, ty) for ty in etypes]
        # Rebuild the metagraph with the renamed edge keys.
        mg_edges = mg.edges(keys=True)
        mg = nx.MultiDiGraph()
        for sty, dty, ety in mg_edges:
            mg.add_edge(sty, dty, key=RENAME_DICT.get(ety, ety))

        # convert to heterograph
        if self.verbose:
            print('Convert to heterograph ...')
        hg = dgl.to_heterogeneous(g,
                                  ntypes,
                                  etypes,
                                  metagraph=mg)
        if self.verbose:
            print('#Node types:', len(hg.ntypes))
            print('#Canonical edge types:', len(hg.etypes))
            print('#Unique edge type names:', len(set(hg.etypes)))
        return hg

    def load_data_split(self, ent2id, root_path):
        """Load data split

        Reads ``trainingSet.tsv`` and ``testSet.tsv`` from ``root_path``.

        Parameters
        ----------
        ent2id: func
            A function mapping entity to id
        root_path: str
            Root path containing the data

        Return
        ------
        train_idx: Numpy array
            Training set
        test_idx: Numpy array
            Testing set
        labels: Numpy array
            Labels; entries for nodes without a label are -1
        num_classes: int
            Number of classes
        """
        label_dict = {}
        # Start with -1 everywhere; parse_idx_file fills in labelled nodes.
        labels = np.zeros((self._hg.number_of_nodes(self.predict_category),)) - 1
        train_idx = self.parse_idx_file(
            os.path.join(root_path, 'trainingSet.tsv'),
            ent2id, label_dict, labels)
        test_idx = self.parse_idx_file(
            os.path.join(root_path, 'testSet.tsv'),
            ent2id, label_dict, labels)
        train_idx = np.array(train_idx)
        test_idx = np.array(test_idx)
        labels = np.array(labels)
        num_classes = len(label_dict)
        return train_idx, test_idx, labels, num_classes

    def parse_idx_file(self, filename, ent2id, label_dict, labels):
        """Parse idx files

        The first line of the file is treated as a header and skipped, and
        blank lines are ignored.

        Parameters
        ----------
        filename: str
            File to parse
        ent2id: func
            A function mapping entity to id
        label_dict: dict
            Map label to label id; updated in place
        labels: Numpy array
            Label id per entity id; updated in place

        Return
        ------
        idx: list
            Entity ids
        """
        idx = []
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                if i == 0:
                    continue  # first line is the header
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                sample, label = self.process_idx_file_line(line)
                ent = self.parse_entity(sample)
                entid = ent2id(str(ent))
                if entid is None:
                    print('Warning: entity "%s" does not have any valid links associated. Ignored.' % str(ent))
                else:
                    idx.append(entid)
                    lblid = _get_id(label_dict, label)
                    labels[entid] = lblid
        return idx

    def _cache_paths(self):
        """Return the ``(graph_path, info_path)`` pair of the cache files."""
        graph_path = os.path.join(self.save_path, self.save_name + '.bin')
        info_path = os.path.join(self.save_path, self.save_name + '.pkl')
        return graph_path, info_path

    def has_cache(self):
        """check if there is a processed data"""
        graph_path, info_path = self._cache_paths()
        return os.path.exists(graph_path) and os.path.exists(info_path)

    def save(self):
        """save the graph list and the labels"""
        graph_path, info_path = self._cache_paths()
        save_graphs(str(graph_path), self._hg)
        save_info(str(info_path), {'num_classes': self.num_classes,
                                   'predict_category': self.predict_category})

    def load(self):
        """load the graph list and the labels from disk"""
        graph_path, info_path = self._cache_paths()
        graphs, _ = load_graphs(str(graph_path))
        info = load_info(str(info_path))
        self._num_classes = info['num_classes']
        self._predict_category = info['predict_category']
        self._hg = graphs[0]
        # For backward compatibility
        if 'label' not in self._hg.nodes[self.predict_category].data:
            self._hg.nodes[self.predict_category].data['label'] = \
                self._hg.nodes[self.predict_category].data['labels']

    def __getitem__(self, idx):
        r"""Gets the graph object

        ``idx`` is not used; the single graph of the dataset is returned,
        after applying the transform if one was given.
        """
        g = self._hg
        if self._transform is not None:
            g = self._transform(g)
        return g

    def __len__(self):
        r"""The number of graphs in the dataset."""
        return 1

    @property
    def save_name(self):
        """Base file name (without extension) of the cache files."""
        return self.name + '_dgl_graph'

    @property
    def predict_category(self):
        """The entity category (node type) that has labels for prediction."""
        return self._predict_category

    @property
    def num_classes(self):
        """Number of classes to predict."""
        return self._num_classes

    @abc.abstractmethod
    def parse_entity(self, term):
        """Parse one entity from an RDF term.

        Return None if the term does not represent a valid entity and the
        whole tuple should be ignored. The result is handed to
        ``process_tuple``, which decides whether the tuple is dropped.

        Parameters
        ----------
        term : rdflib.term.Identifier
            RDF term

        Returns
        -------
        Entity or None
            An entity.
        """

    @abc.abstractmethod
    def parse_relation(self, term):
        """Parse one relation from an RDF term.

        Return None if the term does not represent a valid relation and the
        whole tuple should be ignored. The result is handed to
        ``process_tuple``, which decides whether the tuple is dropped.

        Parameters
        ----------
        term : rdflib.term.Identifier
            RDF term

        Returns
        -------
        Relation or None
            A relation
        """

    @abc.abstractmethod
    def process_tuple(self, raw_tuple, sbj, rel, obj):
        """Process the tuple.

        Return (Entity, Relation, Entity) tuple for as the final tuple.
        Return None if the tuple should be ignored.

        ``process_raw_tuples`` only checks the result against None and then
        keeps using the ``sbj``/``rel``/``obj`` objects it passed in, so any
        adjustment must be made to those objects in place.

        Parameters
        ----------
        raw_tuple : tuple of rdflib.term.Identifier
            (subject, predicate, object) tuple
        sbj : Entity
            Subject entity
        rel : Relation
            Relation
        obj : Entity
            Object entity

        Returns
        -------
        (Entity, Relation, Entity)
            The final tuple or None if should be ignored
        """

    @abc.abstractmethod
    def process_idx_file_line(self, line):
        """Process one line of ``trainingSet.tsv`` or ``testSet.tsv``.

        Parameters
        ----------
        line : str
            One line of the file

        Returns
        -------
        (str, str)
            One sample and its label
        """
def _get_id(dict, key):
    """Return the integer id of ``key`` in ``dict``, assigning one if needed.

    A key that is not present yet (or is mapped to None) gets the current
    size of the mapping as its id, so ids are handed out as 0, 1, 2, ... in
    first-seen order. The mapping is updated in place.

    Parameters
    ----------
    dict : dict
        Mapping from key to integer id. (The parameter name shadows the
        builtin; it is kept for backward compatibility.)
    key : hashable
        Key to look up

    Returns
    -------
    int
        The id of ``key``
    """
    key_id = dict.get(key)
    if key_id is None:
        key_id = len(dict)
        dict[key] = key_id
    return key_id
class AIFBDataset(RDFGraphDataset):
    r"""AIFB dataset for node classification task

    AIFB DataSet is a Semantic Web (RDF) dataset used as a benchmark in
    data mining. It records the organizational structure of AIFB at the
    University of Karlsruhe.

    AIFB dataset statistics:

    - Nodes: 7262
    - Edges: 48810 (including reverse edges)
    - Target Category: Personen
    - Number of Classes: 4
    - Label Split:

        - Train: 140
        - Test: 36

    Parameters
    ----------
    print_every : int
        Preprocessing log for every X tuples. Default: 10000.
    insert_reverse : bool
        If true, add reverse edge and reverse relations to the final graph. Default: True.
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_classes : int
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction

    Examples
    --------
    >>> dataset = dgl.data.rdf.AIFBDataset()
    >>> graph = dataset[0]
    >>> category = dataset.predict_category
    >>> num_classes = dataset.num_classes
    >>>
    >>> train_mask = graph.nodes[category].data['train_mask']
    >>> test_mask = graph.nodes[category].data['test_mask']
    >>> label = graph.nodes[category].data['label']
    """

    entity_prefix = 'http://www.aifb.uni-karlsruhe.de/'
    relation_prefix = 'http://swrc.ontoware.org/'

    def __init__(self,
                 print_every=10000,
                 insert_reverse=True,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True,
                 transform=None):
        import rdflib as rdf
        # Relations that parse_relation drops from the graph.
        self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
        self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
        url = _get_dgl_url('dataset/rdf/aifb-hetero.zip')
        name = 'aifb-hetero'
        predict_category = 'Personen'
        super(AIFBDataset, self).__init__(name, url, predict_category,
                                          print_every=print_every,
                                          insert_reverse=insert_reverse,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose,
                                          transform=transform)

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        ----------
        idx: int
            Item index, AIFBDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`

            The graph contains:

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['test_mask']``: mask for testing node set
            - ``ndata['label']``: node labels
        """
        return super(AIFBDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset.

        Return
        ------
        int
        """
        return super(AIFBDataset, self).__len__()

    def parse_entity(self, term):
        """Parse an RDF term into an :class:`Entity`.

        Literals become entities of class ``_Literal``; blank nodes and
        terms outside ``entity_prefix`` yield None.
        """
        import rdflib as rdf
        if isinstance(term, rdf.Literal):
            return Entity(e_id=str(term), cls="_Literal")
        if isinstance(term, rdf.BNode):
            return None
        entstr = str(term)
        if not entstr.startswith(self.entity_prefix):
            return None
        # After splitting on '/', index 3 is the first path component after
        # the host (used as the class) and index 5 is the third (used as id).
        sp = entstr.split('/')
        return Entity(e_id=sp[5], cls=sp[3])

    def parse_relation(self, term):
        """Parse an RDF predicate into a :class:`Relation`.

        The ``employs`` and ``affiliation`` predicates yield None.
        """
        if term == self.employs or term == self.affiliation:
            return None
        relstr = str(term)
        if relstr.startswith(self.relation_prefix):
            # First path component after the host.
            return Relation(cls=relstr.split('/')[3])
        # Foreign namespace: keep only the last path component.
        return Relation(cls=relstr.split('/')[-1])

    def process_tuple(self, raw_tuple, sbj, rel, obj):
        """Drop the tuple if any of its three parts failed to parse."""
        if sbj is None or rel is None or obj is None:
            return None
        return (sbj, rel, obj)

    def process_idx_file_line(self, line):
        """Split a tab-separated line into (first column, third column)."""
        person, _, label = line.strip().split('\t')
        return person, label
class MUTAGDataset(RDFGraphDataset):
    r"""MUTAG dataset for node classification task

    Mutag dataset statistics:

    - Nodes: 27163
    - Edges: 148100 (including reverse edges)
    - Target Category: d
    - Number of Classes: 2
    - Label Split:

        - Train: 272
        - Test: 68

    Parameters
    ----------
    print_every : int
        Preprocessing log for every X tuples. Default: 10000.
    insert_reverse : bool
        If true, add reverse edge and reverse relations to the final graph. Default: True.
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_classes : int
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction
    graph : :class:`dgl.DGLGraph`
        Graph structure

    Examples
    --------
    >>> dataset = dgl.data.rdf.MUTAGDataset()
    >>> graph = dataset[0]
    >>> category = dataset.predict_category
    >>> num_classes = dataset.num_classes
    >>>
    >>> train_mask = graph.nodes[category].data['train_mask']
    >>> test_mask = graph.nodes[category].data['test_mask']
    >>> label = graph.nodes[category].data['label']
    """

    # Instance names starting with "d<digit>" / "bond<digit>" get the
    # classes 'd' and 'bond' respectively (see parse_entity).
    d_entity = re.compile(r"d[0-9]")
    bond_entity = re.compile(r"bond[0-9]")

    entity_prefix = 'http://dl-learner.org/carcinogenesis#'
    relation_prefix = entity_prefix

    def __init__(self,
                 print_every=10000,
                 insert_reverse=True,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True,
                 transform=None):
        import rdflib as rdf
        # parse_relation drops this predicate from the graph.
        self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
        self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
        self.rdf_subclassof = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#subClassOf")
        self.rdf_domain = rdf.term.URIRef("http://www.w3.org/2000/01/rdf-schema#domain")
        url = _get_dgl_url('dataset/rdf/mutag-hetero.zip')
        name = 'mutag-hetero'
        predict_category = 'd'
        super(MUTAGDataset, self).__init__(name, url, predict_category,
                                           print_every=print_every,
                                           insert_reverse=insert_reverse,
                                           raw_dir=raw_dir,
                                           force_reload=force_reload,
                                           verbose=verbose,
                                           transform=transform)

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        ----------
        idx: int
            Item index, MUTAGDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`

            The graph contains:

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['test_mask']``: mask for testing node set
            - ``ndata['label']``: node labels
        """
        return super(MUTAGDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset.

        Return
        ------
        int
        """
        return super(MUTAGDataset, self).__len__()

    def parse_entity(self, term):
        """Parse an RDF term into an :class:`Entity`.

        Literals become entities of class ``_Literal``; blank nodes and
        terms outside ``entity_prefix`` yield None. For other terms the
        class is 'd', 'bond' or None; a None class is filled in later by
        ``process_tuple``.
        """
        import rdflib as rdf
        if isinstance(term, rdf.Literal):
            return Entity(e_id=str(term), cls="_Literal")
        if isinstance(term, rdf.BNode):
            return None
        entstr = str(term)
        if not entstr.startswith(self.entity_prefix):
            return None
        inst = entstr[len(self.entity_prefix):]
        if self.d_entity.match(inst):
            cls = 'd'
        elif self.bond_entity.match(inst):
            cls = 'bond'
        else:
            cls = None
        return Entity(e_id=inst, cls=cls)

    def parse_relation(self, term):
        """Parse an RDF predicate into a :class:`Relation`.

        The ``isMutagenic`` predicate yields None.
        """
        if term == self.is_mutagenic:
            return None
        relstr = str(term)
        if relstr.startswith(self.relation_prefix):
            return Relation(cls=relstr[len(self.relation_prefix):])
        # Foreign namespace: keep only the last path component.
        return Relation(cls=relstr.split('/')[-1])

    def process_tuple(self, raw_tuple, sbj, rel, obj):
        """Drop unparsable tuples and fill in missing entity classes.

        ``sbj`` and ``obj`` are modified in place.
        """
        if sbj is None or rel is None or obj is None:
            return None
        # Objects of predicates outside the dataset namespace are put in
        # the 'SCHEMA' class.
        if not raw_tuple[1].startswith(self.relation_prefix):
            obj.cls = 'SCHEMA'
        if sbj.cls is None:
            sbj.cls = 'SCHEMA'
        if obj.cls is None:
            # Fall back to the relation name as the object's class.
            obj.cls = rel.cls
        assert sbj.cls is not None and obj.cls is not None
        return (sbj, rel, obj)

    def process_idx_file_line(self, line):
        """Split a tab-separated line into (first column, third column)."""
        sample, _, label = line.strip().split('\t')
        return sample, label
class BGSDataset(RDFGraphDataset):
    r"""BGS dataset for node classification task

    BGS namespace convention:
    ``http://data.bgs.ac.uk/(ref|id)/<Major Concept>/<Sub Concept>/INSTANCE``.
    We ignored all literal nodes and the relations connecting them in the
    output graph. We also ignored the relation used to mark whether a
    term is CURRENT or DEPRECATED.

    BGS dataset statistics:

    - Nodes: 94806
    - Edges: 672884 (including reverse edges)
    - Target Category: Lexicon/NamedRockUnit
    - Number of Classes: 2
    - Label Split:

        - Train: 117
        - Test: 29

    Parameters
    ----------
    print_every : int
        Preprocessing log for every X tuples. Default: 10000.
    insert_reverse : bool
        If true, add reverse edge and reverse relations to the final graph. Default: True.
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_classes : int
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction

    Examples
    --------
    >>> dataset = dgl.data.rdf.BGSDataset()
    >>> graph = dataset[0]
    >>> category = dataset.predict_category
    >>> num_classes = dataset.num_classes
    >>>
    >>> train_mask = graph.nodes[category].data['train_mask']
    >>> test_mask = graph.nodes[category].data['test_mask']
    >>> label = graph.nodes[category].data['label']
    """

    entity_prefix = 'http://data.bgs.ac.uk/'
    status_prefix = 'http://data.bgs.ac.uk/ref/CurrentStatus'
    relation_prefix = 'http://data.bgs.ac.uk/ref'

    def __init__(self,
                 print_every=10000,
                 insert_reverse=True,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True,
                 transform=None):
        import rdflib as rdf
        url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
        name = 'bgs-hetero'
        predict_category = 'Lexicon/NamedRockUnit'
        # parse_relation drops this predicate from the graph.
        self.lith = rdf.term.URIRef("http://data.bgs.ac.uk/ref/Lexicon/hasLithogenesis")
        super(BGSDataset, self).__init__(name, url, predict_category,
                                         print_every=print_every,
                                         insert_reverse=insert_reverse,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose,
                                         transform=transform)

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        ----------
        idx: int
            Item index, BGSDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`

            The graph contains:

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['test_mask']``: mask for testing node set
            - ``ndata['label']``: node labels
        """
        return super(BGSDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset.

        Return
        ------
        int
        """
        return super(BGSDataset, self).__len__()

    def parse_entity(self, term):
        """Parse an RDF term into an :class:`Entity`.

        Literals, blank nodes, status terms, terms outside
        ``entity_prefix`` and terms that do not split into exactly seven
        '/'-separated parts all yield None.
        """
        import rdflib as rdf
        if isinstance(term, (rdf.Literal, rdf.BNode)):
            return None
        entstr = str(term)
        if entstr.startswith(self.status_prefix):
            return None
        if not entstr.startswith(self.entity_prefix):
            return None
        sp = entstr.split('/')
        if len(sp) != 7:
            return None
        # instance: class is "<sp[4]>/<sp[5]>", id is the last component
        cls = '%s/%s' % (sp[4], sp[5])
        inst = sp[6]
        return Entity(e_id=inst, cls=cls)

    def parse_relation(self, term):
        """Parse an RDF predicate into a :class:`Relation`.

        The ``hasLithogenesis`` predicate and predicates under
        ``relation_prefix`` with fewer than six '/'-separated parts yield
        None.
        """
        if term == self.lith:
            return None
        relstr = str(term)
        if relstr.startswith(self.relation_prefix):
            sp = relstr.split('/')
            if len(sp) < 6:
                return None
            assert len(sp) == 6, relstr
            cls = '%s/%s' % (sp[4], sp[5])
            return Relation(cls=cls)
        # Foreign namespace: keep the full URI, with '.' replaced by '_'.
        return Relation(cls=relstr.replace('.', '_'))

    def process_tuple(self, raw_tuple, sbj, rel, obj):
        """Drop the tuple if any of its three parts failed to parse."""
        if sbj is None or rel is None or obj is None:
            return None
        return (sbj, rel, obj)

    def process_idx_file_line(self, line):
        """Split a tab-separated line into (second column, third column)."""
        _, rock, label = line.strip().split('\t')
        return rock, label
class AMDataset(RDFGraphDataset):
    r"""AM dataset for node classification task

    Namespace convention:

    - Instance: ``http://purl.org/collections/nl/am/<type>-<id>``
    - Relation: ``http://purl.org/collections/nl/am/<name>``

    We ignored all literal nodes and the relations connecting them in the
    output graph.

    AM dataset statistics:

    - Nodes: 881680
    - Edges: 5668682 (including reverse edges)
    - Target Category: proxy
    - Number of Classes: 11
    - Label Split:

        - Train: 802
        - Test: 198

    Parameters
    ----------
    print_every : int
        Preprocessing log for every X tuples. Default: 10000.
    insert_reverse : bool
        If true, add reverse edge and reverse relations to the final graph. Default: True.
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: True.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_classes : int
        Number of classes to predict
    predict_category : str
        The entity category (node type) that has labels for prediction

    Examples
    --------
    >>> dataset = dgl.data.rdf.AMDataset()
    >>> graph = dataset[0]
    >>> category = dataset.predict_category
    >>> num_classes = dataset.num_classes
    >>>
    >>> train_mask = graph.nodes[category].data['train_mask']
    >>> test_mask = graph.nodes[category].data['test_mask']
    >>> label = graph.nodes[category].data['label']
    """

    entity_prefix = 'http://purl.org/collections/nl/am/'
    relation_prefix = entity_prefix

    def __init__(self,
                 print_every=10000,
                 insert_reverse=True,
                 raw_dir=None,
                 force_reload=False,
                 verbose=True,
                 transform=None):
        import rdflib as rdf
        # Relations that parse_relation drops from the graph.
        self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
        self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
        url = _get_dgl_url('dataset/rdf/am-hetero.zip')
        name = 'am-hetero'
        predict_category = 'proxy'
        super(AMDataset, self).__init__(name, url, predict_category,
                                        print_every=print_every,
                                        insert_reverse=insert_reverse,
                                        raw_dir=raw_dir,
                                        force_reload=force_reload,
                                        verbose=verbose,
                                        transform=transform)

    def __getitem__(self, idx):
        r"""Gets the graph object

        Parameters
        ----------
        idx: int
            Item index, AMDataset has only one graph object

        Return
        ------
        :class:`dgl.DGLGraph`

            The graph contains:

            - ``ndata['train_mask']``: mask for training node set
            - ``ndata['test_mask']``: mask for testing node set
            - ``ndata['label']``: node labels
        """
        return super(AMDataset, self).__getitem__(idx)

    def __len__(self):
        r"""The number of graphs in the dataset.

        Return
        ------
        int
        """
        return super(AMDataset, self).__len__()

    def parse_entity(self, term):
        """Parse an RDF term into an :class:`Entity`.

        Literals and terms outside ``entity_prefix`` yield None; blank
        nodes become entities of class ``_BNode``.
        """
        import rdflib as rdf
        if isinstance(term, rdf.Literal):
            return None
        if isinstance(term, rdf.BNode):
            return Entity(e_id=str(term), cls='_BNode')
        entstr = str(term)
        if not entstr.startswith(self.entity_prefix):
            return None
        sp = entstr.split('/')
        assert len(sp) == 7, entstr
        spp = sp[6].split('-')
        if len(spp) == 2:
            # instance: "<type>-<id>"
            cls, inst = spp
        else:
            cls = 'TYPE'
            # NOTE(review): the id is the *list* of '-'-separated parts
            # here, not a string. It is kept as is because the entity's
            # string form is used as its lookup key.
            inst = spp
        return Entity(e_id=inst, cls=cls)

    def parse_relation(self, term):
        """Parse an RDF predicate into a :class:`Relation`.

        The ``objectCategory`` and ``material`` predicates yield None.
        """
        if term == self.objectCategory or term == self.material:
            return None
        relstr = str(term)
        if relstr.startswith(self.relation_prefix):
            sp = relstr.split('/')
            assert len(sp) == 7, relstr
            return Relation(cls=sp[6])
        # Foreign namespace: keep the full URI, with '.' replaced by '_'.
        return Relation(cls=relstr.replace('.', '_'))

    def process_tuple(self, raw_tuple, sbj, rel, obj):
        """Drop the tuple if any of its three parts failed to parse."""
        if sbj is None or rel is None or obj is None:
            return None
        return (sbj, rel, obj)

    def process_idx_file_line(self, line):
        """Split a tab-separated line into (first column, third column)."""
        proxy, _, label = line.strip().split('\t')
        return proxy, label