""" BitcoinOTC dataset for fraud detection """
import numpy as np
import os
import datetime
import gzip
import shutil
from .dgl_dataset import DGLBuiltinDataset
from .utils import download, makedirs, save_graphs, load_graphs, check_sha1
from ..convert import graph as dgl_graph
from .. import backend as F
[docs]class BitcoinOTCDataset(DGLBuiltinDataset):
r"""BitcoinOTC dataset for fraud detection
This is who-trusts-whom network of people who trade using Bitcoin on
a platform called Bitcoin OTC. Since Bitcoin users are anonymous,
there is a need to maintain a record of users' reputation to prevent
transactions with fraudulent and risky users.
Offical website: `<https://snap.stanford.edu/data/soc-sign-bitcoin-otc.html>`_
Bitcoin OTC dataset statistics:
- Nodes: 5,881
- Edges: 35,592
- Range of edge weight: -10 to +10
- Percentage of positive edges: 89%
Parameters
----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset.
Default: False
verbose: bool
Whether to print out progress information.
Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
graphs : list
A list of DGLGraph objects
is_temporal : bool
Indicate whether the graphs are temporal graphs
Raises
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> dataset = BitcoinOTCDataset()
>>> len(dataset)
136
>>> for g in dataset:
.... # get edge feature
.... edge_weights = g.edata['h']
.... # your code here
>>>
"""
_url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'
_sha1_str = 'c14281f9e252de0bd0b5f1c6e2bae03123938641'
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(BitcoinOTCDataset, self).__init__(name='bitcoinotc',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
transform=transform)
def download(self):
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
download(self.url, path=gz_file_path)
if not check_sha1(gz_file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.'
'The repo may be outdated or download may be incomplete. '
'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
self._extract_gz(gz_file_path, self.raw_path)
def process(self):
filename = os.path.join(self.save_path, self.name + '.csv')
data = np.loadtxt(filename, delimiter=',').astype(np.int64)
data[:, 0:2] = data[:, 0:2] - data[:, 0:2].min()
delta = datetime.timedelta(days=14).total_seconds()
# The source code is not released, but the paper indicates there're
# totally 137 samples. The cutoff below has exactly 137 samples.
time_index = np.around(
(data[:, 3] - data[:, 3].min()) / delta).astype(np.int64)
self._graphs = []
for i in range(time_index.max()):
row_mask = time_index <= i
edges = data[row_mask][:, 0:2]
rate = data[row_mask][:, 2]
g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['h'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64'])
self._graphs.append(g)
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
return os.path.exists(graph_path)
def save(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
save_graphs(graph_path, self.graphs)
def load(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
self._graphs = load_graphs(graph_path)[0]
@property
def graphs(self):
return self._graphs
[docs] def __len__(self):
r""" Number of graphs in the dataset.
Return
-------
int
"""
return len(self.graphs)
[docs] def __getitem__(self, item):
r""" Get graph by index
Parameters
----------
item : int
Item index
Returns
-------
:class:`dgl.DGLGraph`
The graph contains:
- ``edata['h']`` : edge weights
"""
if self._transform is None:
return self.graphs[item]
else:
return self._transform(self.graphs[item])
@property
def is_temporal(self):
r""" Are the graphs temporal graphs
Returns
-------
bool
"""
return True
def _extract_gz(self, file, target_dir, overwrite=False):
if os.path.exists(target_dir) and not overwrite:
return
print('Extracting file to {}'.format(target_dir))
fname = os.path.basename(file)
makedirs(target_dir)
out_file_path = os.path.join(target_dir, fname[:-3])
with gzip.open(file, 'rb') as f_in:
with open(out_file_path, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
BitcoinOTC = BitcoinOTCDataset