# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
from .. import backend as F
from .dist_context import get_sampler_pool
__all__ = ["DistDataLoader"]
DATALOADER_ID = 0
class DistDataLoader:
"""DGL customized multiprocessing dataloader.
DistDataLoader provides a similar interface to Pytorch's DataLoader to generate mini-batches
with multiprocessing. It utilizes the worker processes created by
:func:`dgl.distributed.initialize` to parallelize sampling.
Parameters
----------
dataset: a tensor
A tensor of node IDs or edge IDs.
batch_size: int
The number of samples per batch to load.
shuffle: bool, optional
Set to ``True`` to have the data reshuffled at every epoch (default: ``False``).
collate_fn: callable, optional
The function is typically used to sample neighbors of the nodes in a batch
or the endpoint nodes of the edges in a batch.
drop_last: bool, optional
Set to ``True`` to drop the last incomplete batch, if the dataset size is not
divisible by the batch size. If ``False`` and the size of dataset is not divisible
by the batch size, then the last batch will be smaller. (default: ``False``)
queue_size: int, optional
Size of the multiprocessing result queue, i.e., the number of batches to prefetch.
Defaults to 4 times the number of sampler workers, or 4 when there are no workers.

Examples
--------
>>> g = dgl.distributed.DistGraph('graph-name')
>>> def sample(seeds):
... seeds = th.LongTensor(np.asarray(seeds))
... frontier = dgl.distributed.sample_neighbors(g, seeds, 10)
... return dgl.to_block(frontier, seeds)
>>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
... collate_fn=sample, shuffle=True)
>>> for block in dataloader:
... feat = g.ndata['features'][block.srcdata[dgl.NID]]
... labels = g.ndata['labels'][block.dstdata[dgl.NID]]
... pred = model(block, feat)

Note
----
When performing DGL's distributed sampling with multiprocessing, users have to use this class
instead of PyTorch's DataLoader because DGL's RPC requires that all processes establish
connections with the servers before invoking any of DGL's distributed APIs. Therefore, this
dataloader uses the worker processes created in :func:`dgl.distributed.initialize`.
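
For instance, a minimal call order in a trainer script looks like the following
(``ip_config.txt`` is a placeholder for the user's server configuration file; ``nodes``
and ``sample`` are the seed IDs and collate function from the Examples section above):

>>> dgl.distributed.initialize('ip_config.txt')
>>> g = dgl.distributed.DistGraph('graph-name')
>>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
... collate_fn=sample, shuffle=True)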

Note
----
This dataloader does not guarantee the iteration order. For example,
if dataset = [1, 2, 3, 4], batch_size = 2 and shuffle = False, the order of [1, 2]
and [3, 4] is not guaranteed.
"""
def __init__(
self,
dataset,
batch_size,
shuffle=False,
collate_fn=None,
drop_last=False,
queue_size=None,
):
self.pool, self.num_workers = get_sampler_pool()
if queue_size is None:
queue_size = self.num_workers * 4 if self.num_workers > 0 else 4
self.queue_size = queue_size # prefetch size
self.batch_size = batch_size
self.num_pending = 0
self.collate_fn = collate_fn
self.current_pos = 0
self.queue = [] # Only used when pool is None
self.drop_last = drop_last
self.recv_idxs = 0
self.shuffle = shuffle
self.is_closed = False
self.dataset = dataset
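# data_idx holds the (optionally shuffled) positions into the dataset, and
# expected_idxs is the number of mini-batches per epoch; __next__ uses it to
# decide when to raise StopIteration.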
self.data_idx = F.arange(0, len(dataset))
self.expected_idxs = len(dataset) // self.batch_size
if not self.drop_last and len(dataset) % self.batch_size != 0:
self.expected_idxs += 1
# We need to have a unique ID for each data loader to identify itself
# in the sampler processes.
global DATALOADER_ID
self.name = "dataloader-" + str(DATALOADER_ID)
DATALOADER_ID += 1
if self.pool is not None:
self.pool.set_collate_fn(self.collate_fn, self.name)
def __del__(self):
# When the process exits, the worker pool may already have been closed. Try to
# get the pool again to check whether the collate function still needs to be
# cleaned up.
self.pool, self.num_workers = get_sampler_pool()
if self.pool is not None:
self.pool.delete_collate_fn(self.name)
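# Each call tops up the number of outstanding sampling requests (up to
# queue_size when a worker pool is used, or a single request otherwise) and then
# blocks until one collated mini-batch is available.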
def __next__(self):
if self.pool is None:
num_reqs = 1
else:
num_reqs = self.queue_size - self.num_pending
for _ in range(num_reqs):
self._request_next_batch()
if self.recv_idxs < self.expected_idxs:
result = self._get_data_from_result_queue()
self.recv_idxs += 1
self.num_pending -= 1
return result
else:
assert self.num_pending == 0
raise StopIteration
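# Fetch one collated result: from the local queue when there is no worker pool,
# otherwise from this dataloader's result queue in the pool, waiting at most
# `timeout` seconds.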
def _get_data_from_result_queue(self, timeout=1800):
if self.pool is None:
ret = self.queue.pop(0)
else:
ret = self.pool.get_result(self.name, timeout=timeout)
return ret
def __iter__(self):
if self.shuffle:
self.data_idx = F.rand_shuffle(self.data_idx)
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
return self
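# Submit the next slice of seed IDs to the worker pool, or run collate_fn inline
# when there is no pool, and count it as a pending batch.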
def _request_next_batch(self):
next_data = self._next_data()
if next_data is None:
return
elif self.pool is not None:
self.pool.submit_task(self.name, next_data)
else:
result = self.collate_fn(next_data)
self.queue.append(result)
self.num_pending += 1
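# Slice the (possibly shuffled) index tensor to produce the next batch of seeds.
# Returns None when the dataset is exhausted or when the last incomplete batch
# is dropped.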
def _next_data(self):
if self.current_pos == len(self.dataset):
return None
end_pos = 0
if self.current_pos + self.batch_size > len(self.dataset):
if self.drop_last:
return None
else:
end_pos = len(self.dataset)
else:
end_pos = self.current_pos + self.batch_size
idx = self.data_idx[self.current_pos : end_pos].tolist()
ret = [self.dataset[i] for i in idx]
# Sharing a large number of tensors between processes consumes too many file
# descriptors, so convert each tensor to a scalar value beforehand.
if isinstance(ret[0], tuple):
    ret = [(item_type, F.as_scalar(item_id)) for (item_type, item_id) in ret]
else:
    ret = [F.as_scalar(item_id) for item_id in ret]
self.current_pos = end_pos
return ret