GraphDataLoader

class dgl.dataloading.GraphDataLoader(dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs)

Bases: torch.utils.data.dataloader.DataLoader

Batched graph data loader.

PyTorch data loader for iterating over a set of graphs in minibatches. For each minibatch, it yields the batched graph and, if the dataset provides labels, the corresponding label tensor.
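Conceptually, batching merges the graphs of a minibatch into one larger graph made of disconnected components, shifting each graph's node IDs by the number of nodes that precede it. The plain-Python sketch below models that idea and the accompanying (graph, label) collation. It assumes a hypothetical `SimpleGraph` type (a node count plus an edge list) in place of a real `DGLGraph`, and returns labels as a list rather than a tensor; it illustrates the idea behind `dgl.batch` and is not DGL's implementation.

```python
from dataclasses import dataclass, field
from typing import List, Sequence, Tuple


@dataclass
class SimpleGraph:
    """Hypothetical stand-in for a graph: node count plus directed edges."""
    num_nodes: int
    edges: List[Tuple[int, int]] = field(default_factory=list)
    # Node counts of the component graphs; filled in by batch_graphs.
    batch_num_nodes: List[int] = field(default_factory=list)


def batch_graphs(graphs: Sequence[SimpleGraph]) -> SimpleGraph:
    """Disjoint union: offset every graph's node IDs by the nodes before it."""
    offset = 0
    edges: List[Tuple[int, int]] = []
    sizes: List[int] = []
    for g in graphs:
        edges.extend((u + offset, v + offset) for u, v in g.edges)
        sizes.append(g.num_nodes)
        offset += g.num_nodes
    return SimpleGraph(num_nodes=offset, edges=edges, batch_num_nodes=sizes)


def collate(samples: Sequence[Tuple[SimpleGraph, int]]) -> Tuple[SimpleGraph, List[int]]:
    """Turn a list of (graph, label) pairs into (batched graph, label list)."""
    graphs, labels = zip(*samples)
    return batch_graphs(graphs), list(labels)


# Two small graphs: a single edge on 2 nodes and a path on 3 nodes.
g1 = SimpleGraph(2, [(0, 1)])
g2 = SimpleGraph(3, [(0, 1), (1, 2)])
bg, labels = collate([(g1, 0), (g2, 1)])
print(bg.num_nodes, bg.edges, bg.batch_num_nodes, labels)
# 5 [(0, 1), (2, 3), (3, 4)] [2, 3] [0, 1]
```

Keeping the per-graph node counts is what allows a batched graph to be split back into its components after message passing.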

Parameters
  • dataset (torch.utils.data.Dataset) – The dataset to load graphs from.

  • collate_fn (Function, default is None) – A custom collate function. If not given, the default collate function is used.

  • use_ddp (boolean, optional) –

    If True, splits the dataset among the participating processes using torch.utils.data.distributed.DistributedSampler, so that each process loads its own share.

    Overrides the sampler argument of torch.utils.data.DataLoader.

  • ddp_seed (int, optional) –

    The seed for shuffling the dataset in torch.utils.data.distributed.DistributedSampler.

    Only effective when use_ddp is True.

  • kwargs (dict) –

    Keyword arguments passed to the parent torch.utils.data.DataLoader class. Common arguments are:

    • batch_size (int): The number of samples (graphs) in each batch.

    • drop_last (bool): Whether to drop the last incomplete batch.

    • shuffle (bool): Whether to randomly shuffle the indices at each epoch.
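The batch_size, drop_last and shuffle arguments are handled by the parent torch.utils.data.DataLoader. The plain-Python sketch below only approximates how they combine to turn dataset indices into minibatches. The helper name `make_batches` is hypothetical, and `random.Random` stands in for PyTorch's own random generator, so the concrete shuffled order will differ from what PyTorch produces.

```python
import random
from typing import List


def make_batches(num_samples: int, batch_size: int, shuffle: bool = False,
                 drop_last: bool = False, seed: int = 0) -> List[List[int]]:
    """Group dataset indices into minibatches of at most batch_size indices."""
    indices = list(range(num_samples))
    if shuffle:
        # Stand-in for PyTorch's RNG; a real DataLoader reshuffles every epoch.
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size]
               for i in range(0, num_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the trailing incomplete batch
    return batches


# 10 graphs with batch_size=4: ceil(10 / 4) = 3 batches, the last holding 2.
print([len(b) for b in make_batches(10, 4)])                  # [4, 4, 2]
# With drop_last=True only floor(10 / 4) = 2 full batches remain.
print([len(b) for b in make_batches(10, 4, drop_last=True)])  # [4, 4]
```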

Examples

To train a GNN for graph classification on a set of graphs in dataset:

>>> import dgl
>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
...     train_on(batched_graph, labels)  # train_on: your own training step

With Distributed Data Parallel

If you are using PyTorch’s distributed training (e.g. when using torch.nn.parallel.DistributedDataParallel), you can train the model by turning on the use_ddp option:

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     # Tell the distributed sampler the epoch so each epoch is shuffled differently.
...     dataloader.set_epoch(epoch)
...     for batched_graph, labels in dataloader:
...         train_on(batched_graph, labels)
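Calling set_epoch matters because of how distributed sampling is seeded. The plain-Python sketch below imitates the partitioning scheme of torch.utils.data.distributed.DistributedSampler: shuffle with a seed plus the epoch number, pad by repeating indices so the count divides evenly across processes, then give each process a strided slice. The function `shard_indices` is hypothetical and `random.Random` replaces PyTorch's generator, so the concrete orderings differ from what PyTorch produces; it assumes a non-empty dataset.

```python
import math
import random
from typing import List


def shard_indices(num_samples: int, num_replicas: int, rank: int,
                  epoch: int, seed: int = 0, shuffle: bool = True) -> List[int]:
    """Indices that process `rank` would see in a given epoch (sketch)."""
    indices = list(range(num_samples))
    if shuffle:
        # Seeding with seed + epoch gives every rank the same permutation
        # within an epoch, and a different permutation once the epoch changes.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices so the total divides evenly across ranks.
    per_rank = math.ceil(num_samples / num_replicas)
    total = per_rank * num_replicas
    indices += (indices * math.ceil(total / num_samples))[: total - num_samples]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total:num_replicas]


# 10 graphs over 4 processes: each rank gets ceil(10 / 4) = 3 indices,
# and 2 indices are repeated as padding.
shards = [shard_indices(10, 4, rank, epoch=0) for rank in range(4)]
print([len(s) for s in shards])                           # [3, 3, 3, 3]
print(sorted(set().union(*shards)) == list(range(10)))    # True
```

Because the permutation depends only on the seed and the epoch, all processes agree on it without communicating. If the epoch is never advanced, every epoch replays the same ordering, which is why the loop above calls set_epoch before iterating.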