Single Machine Multi-GPU Minibatch Node Classification

In this tutorial, you will learn how to use multiple GPUs to train a graph neural network (GNN) for node classification.

(Time estimate: 8 minutes)

This tutorial assumes that you have read the Training GNN with Neighbor Sampling for Node Classification tutorial. It also assumes that you know the basics of training general models on multiple GPUs with DistributedDataParallel.

Note

See this tutorial from PyTorch for general multi-GPU training with DistributedDataParallel. Also, see the first section of the multi-GPU graph classification tutorial for an overview of using DistributedDataParallel with DGL.

Loading Dataset

OGB already provides the data as a DGLGraph object. The following code is copied from the Training GNN with Neighbor Sampling for Node Classification tutorial.

import dgl
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import sklearn.metrics

dataset = DglNodePropPredDataset('ogbn-arxiv')

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is a directed graph.
graph = dgl.add_reverse_edges(graph)
graph.ndata['label'] = node_labels[:, 0]

node_features = graph.ndata['feat']
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()

idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']    # Test node IDs, not used in the tutorial though.
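In DGL, messages flow along the direction of the edges, so in a directed citation graph a node only receives information from nodes that point to it. The pure-Python sketch below illustrates, on a plain edge list, what adding reverse edges means conceptually. The function add_reverse_edges_sketch is hypothetical; the real dgl.add_reverse_edges operates on a DGLGraph and also handles node features.

```python
def add_reverse_edges_sketch(src, dst):
    """Return (src, dst) with a reversed copy of every edge appended.

    Hypothetical pure-Python stand-in for dgl.add_reverse_edges on an edge list.
    """
    new_src = list(src) + list(dst)
    new_dst = list(dst) + list(src)
    return new_src, new_dst


# A tiny directed "citation" graph: paper 0 cites paper 1, paper 1 cites paper 2.
src, dst = add_reverse_edges_sketch([0, 1], [1, 2])
print(list(zip(src, dst)))  # [(0, 1), (1, 2), (1, 0), (2, 1)]
```

After the reverse edges are added, paper 1 receives messages from both paper 0 and paper 2, not just from paper 0.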

Out:

Downloading https://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip
Downloaded 0.08 GB: 100%|##########| 81/81 [00:03<00:00, 25.87it/s]
Extracting dataset/arxiv.zip
Loading necessary files...
This might take a while.
Processing graphs...
100%|##########| 1/1 [00:00<00:00, 8256.50it/s]
Converting graphs into DGL objects...
100%|##########| 1/1 [00:00<00:00, 86.06it/s]
Saving...

Defining Model

The model is again identical to the one in the Training GNN with Neighbor Sampling for Node Classification tutorial.

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h_dst = x[:mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[:mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h
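The forward pass slices x[:mfgs[0].num_dst_nodes()] because, in the sampled MFGs, the destination nodes are listed first among the source nodes. The sketch below is a simplified, pure-Python illustration of one mean-aggregation layer under that convention. It assumes scalar features and replaces SAGEConv's two learned linear layers with two scalar weights (w_self, w_neigh) and no bias; the function name is hypothetical.

```python
def sage_mean_layer(src_feats, num_dst, in_neighbors, w_self, w_neigh):
    """One GraphSAGE-style mean-aggregation layer on scalar features.

    src_feats: feature of every source node of the MFG; by convention the
        first num_dst entries belong to the destination nodes themselves.
    in_neighbors: in_neighbors[i] lists source-node indices that feed dst node i.
    w_self, w_neigh: scalar stand-ins for the learned linear layers.
    """
    h_dst = src_feats[:num_dst]  # same idea as x[:mfgs[0].num_dst_nodes()]
    out = []
    for i in range(num_dst):
        nbrs = in_neighbors[i]
        # Mean over sampled in-neighbors; a node with no neighbors gets 0.
        h_neigh = sum(src_feats[j] for j in nbrs) / len(nbrs) if nbrs else 0.0
        out.append(w_self * h_dst[i] + w_neigh * h_neigh)
    return out


# Four source nodes, the first two of which are also the destination nodes.
print(sage_mean_layer([1.0, 2.0, 4.0, 8.0], 2, [[2, 3], [0]], 1.0, 0.5))  # [4.0, 2.5]
```

The output has one entry per destination node, which is why the second layer again slices its input with mfgs[1].num_dst_nodes().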

Defining Training Procedure

The training procedure differs slightly from what you saw previously. You will need to

  • Initialize a distributed training context with torch.distributed.

  • Wrap your model with torch.nn.parallel.DistributedDataParallel.

  • Add a use_ddp=True argument to the DGL dataloader you wish to run together with DDP.

You will also need to wrap the training loop inside a function so that you can spawn subprocesses to run it.
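With use_ddp=True, each trainer process iterates over only part of the training nodes, and calling set_epoch on the dataloader changes the shuffle order every epoch. This tutorial does not describe how DGL splits the node IDs, so the sketch below assumes a scheme similar to PyTorch's DistributedSampler: all ranks shuffle with the same seed, pad the list so that it divides evenly, and take interleaved slices. The helpers shard_node_ids and steps_per_epoch are hypothetical illustrations, not DGL APIs.

```python
import math
import random


def shard_node_ids(nids, rank, world_size, epoch, seed=0):
    """Return the node IDs that one trainer iterates over in a given epoch.

    Assumed scheme modeled on PyTorch's DistributedSampler; DGL may differ.
    """
    order = list(nids)
    # Same seed on every rank, so all ranks agree on one permutation per epoch.
    random.Random(seed + epoch).shuffle(order)
    # Pad by wrapping around so every rank gets the same number of IDs.
    per_rank = math.ceil(len(order) / world_size)
    total = per_rank * world_size
    while len(order) < total:
        order += order[: total - len(order)]
    # Interleaved split: rank r takes positions r, r + world_size, ...
    return order[rank:total:world_size]


def steps_per_epoch(num_nids, world_size, batch_size):
    """Minibatches per rank per epoch with drop_last=False."""
    return math.ceil(math.ceil(num_nids / world_size) / batch_size)


shards = [shard_node_ids(range(10), r, world_size=4, epoch=0) for r in range(4)]
print([len(s) for s in shards])  # [3, 3, 3, 3]
```

Under this scheme, forgetting to call set_epoch would make every epoch reuse the same permutation. Because each rank processes its own minibatch of batch_size nodes per step, one synchronized step covers batch_size times the number of processes.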

def run(proc_id, devices):
    # Initialize distributed training context.
    dev_id = devices[proc_id]
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(master_ip='127.0.0.1', master_port='12345')
    if torch.cuda.device_count() < 1:
        device = torch.device('cpu')
        torch.distributed.init_process_group(
            backend='gloo', init_method=dist_init_method, world_size=len(devices), rank=proc_id)
    else:
        torch.cuda.set_device(dev_id)
        device = torch.device('cuda:' + str(dev_id))
        torch.distributed.init_process_group(
            backend='nccl', init_method=dist_init_method, world_size=len(devices), rank=proc_id)

    # Define training and validation dataloader, copied from the previous tutorial
    # but with one line of difference: use_ddp to enable distributed data parallel
    # data loading.
    sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
    train_dataloader = dgl.dataloading.NodeDataLoader(
        # The following arguments are specific to NodeDataLoader.
        graph,              # The graph
        train_nids,         # The node IDs to iterate over in minibatches
        sampler,            # The neighbor sampler
        device=device,      # Put the sampled MFGs on CPU or GPU
        use_ddp=True,       # Make it work with distributed data parallel
        # The following arguments are inherited from PyTorch DataLoader.
        batch_size=1024,    # Per-device batch size.
                            # The effective batch size is this number times the number of GPUs.
        shuffle=True,       # Whether to shuffle the nodes for every epoch
        drop_last=False,    # Whether to drop the last incomplete batch
        num_workers=0       # Number of sampler processes
    )
    valid_dataloader = dgl.dataloading.NodeDataLoader(
        graph, valid_nids, sampler,
        device=device,
        use_ddp=False,
        batch_size=1024,
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )

    model = Model(num_features, 128, num_classes).to(device)
    # Wrap the model with distributed data parallel module.
    if device == torch.device('cpu'):
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=None, output_device=None)
    else:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)

    # Define optimizer
    opt = torch.optim.Adam(model.parameters())

    best_accuracy = 0
    best_model_path = './model.pt'

    # Copied from previous tutorial with changes highlighted.
    for epoch in range(10):
        train_dataloader.set_epoch(epoch)    # <--- necessary for dataloader with DDP.
        model.train()

        with tqdm.tqdm(train_dataloader) as tq:
            for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
                # feature copy from CPU to GPU takes place here
                inputs = mfgs[0].srcdata['feat']
                labels = mfgs[-1].dstdata['label']

                predictions = model(mfgs, inputs)

                loss = F.cross_entropy(predictions, labels)
                opt.zero_grad()
                loss.backward()
                opt.step()

                accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())

                tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)

        model.eval()

        # Evaluate on only the first GPU.
        if proc_id == 0:
            predictions = []
            labels = []
            with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
                for input_nodes, output_nodes, mfgs in tq:
                    inputs = mfgs[0].srcdata['feat']
                    labels.append(mfgs[-1].dstdata['label'].cpu().numpy())
                    predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
                predictions = np.concatenate(predictions)
                labels = np.concatenate(labels)
                accuracy = sklearn.metrics.accuracy_score(labels, predictions)
                print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
                if best_accuracy < accuracy:
                    best_accuracy = accuracy
                    torch.save(model.state_dict(), best_model_path)

        # Note that this tutorial stops after the first epoch instead of training the model to the end.
        break
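Evaluating only on the first process is sufficient because DDP averages the gradients across processes during the backward pass, so every replica applies the same update and the replicas stay identical. The pure-Python simulation below illustrates this with plain SGD on a list of parameters. It is a conceptual sketch: ddp_step is hypothetical, and the tutorial itself uses Adam, whose per-rank state also stays identical because every rank sees the same averaged gradients.

```python
def ddp_step(replicas, per_rank_grads, lr):
    """Simulate one synchronized SGD step across DDP replicas.

    replicas: one parameter list per rank (all identical before the step).
    per_rank_grads: gradients each rank computed on its own minibatch.
    """
    world_size = len(per_rank_grads)
    num_params = len(per_rank_grads[0])
    # All-reduce: average each gradient entry over all ranks.
    avg = [sum(g[k] for g in per_rank_grads) / world_size for k in range(num_params)]
    # Every rank applies the same averaged gradient to its own replica.
    return [[p - lr * a for p, a in zip(params, avg)] for params in replicas]


new_replicas = ddp_step([[1.0, 2.0], [1.0, 2.0]], [[0.5, 1.0], [1.5, -1.0]], lr=0.5)
print(new_replicas)  # [[0.5, 2.0], [0.5, 2.0]]
```

Different minibatches produce different local gradients, yet both replicas end up with the same parameters.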

Spawning Trainer Processes

A typical scenario for multi-GPU training with DDP is to replicate the model once per GPU, and spawn one trainer process per GPU.

PyTorch tutorials recommend using multiprocessing.spawn to spawn multiple processes. However, this is undesirable when training node classification or link prediction models on a single large graph, especially on Linux. A single large graph may take a lot of memory by itself, and mp.spawn duplicates all objects in the program, including the large graph, in every process it starts. Consequently, the large graph is duplicated as many times as there are GPUs.

To alleviate the problem, we recommend using multiprocessing.Process, which forks from the main process and lets the trainer processes share the same graph object via copy-on-write. This can greatly reduce memory consumption.
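The following sketch, which uses Python's standard multiprocessing module rather than DGL, illustrates the idea: with the fork start method, a child process reads a large module-level object that was never passed as an argument or pickled. The bytes object BIG_BUFFER is a hypothetical stand-in for a graph. The fork start method is POSIX only, and one caveat applies: CPython's reference counting writes to object headers, so pages full of small Python objects get copied once touched, whereas large contiguous buffers (like tensor storage or this bytes object) stay shared.

```python
import multiprocessing as mp

# Hypothetical stand-in for a large, read-only object such as a graph (~10 MB).
BIG_BUFFER = bytes(range(256)) * 40_000


def worker(rank, queue):
    # BIG_BUFFER is inherited from the parent through fork: it was not pickled
    # or passed as an argument, and its pages are copied only if written to.
    queue.put((rank, len(BIG_BUFFER), BIG_BUFFER[-1]))


def run_forked_workers(num_workers=2):
    ctx = mp.get_context('fork')  # POSIX only; not available on Windows
    queue = ctx.Queue()
    procs = [ctx.Process(target=worker, args=(rank, queue)) for rank in range(num_workers)]
    for p in procs:
        p.start()
    # Drain the queue before joining to avoid blocking on a full pipe.
    results = sorted(queue.get() for _ in procs)
    for p in procs:
        p.join()
    return results


if __name__ == '__main__':
    print(run_forked_workers())  # [(0, 10240000, 255), (1, 10240000, 255)]
```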

Normally, DGL maintains only one sparse matrix representation (usually COO) for each graph and creates other formats on demand when certain APIs need them for efficiency. For instance, calling in_degrees creates a CSC representation of the graph, and calling out_degrees creates a CSR representation. Consequently, if a graph is shared with trainer processes via copy-on-write before its CSC/CSR representations are created, each trainer creates its own CSC/CSR replica once in_degrees or out_degrees is called. To avoid this, create all sparse matrix representations beforehand with the create_formats_ method:

graph.create_formats_()
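To make the lazy-format behavior concrete, here is a toy pure-Python graph that stores only COO at first, builds CSC or CSR the first time in_degrees or out_degrees needs it, and builds both eagerly in create_formats_. The class LazyFormatGraph is hypothetical and only loosely mirrors the behavior described above; it is not DGL's implementation.

```python
class LazyFormatGraph:
    """Toy graph: COO is stored up front, CSR/CSC are built on first use."""

    def __init__(self, num_nodes, src, dst):
        self.num_nodes = num_nodes
        self.formats = {'coo': (list(src), list(dst))}

    def _build(self, fmt):
        if fmt not in self.formats:
            src, dst = self.formats['coo']
            # CSR groups edges by source node; CSC groups them by destination node.
            rows, cols = (src, dst) if fmt == 'csr' else (dst, src)
            indptr = [0] * (self.num_nodes + 1)
            for r in rows:
                indptr[r + 1] += 1
            for i in range(self.num_nodes):
                indptr[i + 1] += indptr[i]
            # Stable sort of edge IDs by row gives the column indices in order.
            order = sorted(range(len(rows)), key=lambda e: rows[e])
            self.formats[fmt] = (indptr, [cols[e] for e in order])
        return self.formats[fmt]

    def in_degrees(self):
        indptr, _ = self._build('csc')
        return [indptr[i + 1] - indptr[i] for i in range(self.num_nodes)]

    def out_degrees(self):
        indptr, _ = self._build('csr')
        return [indptr[i + 1] - indptr[i] for i in range(self.num_nodes)]

    def create_formats_(self):
        # Materialize every format up front, e.g. before forking trainers.
        self._build('csr')
        self._build('csc')


g = LazyFormatGraph(3, src=[0, 0, 1], dst=[1, 2, 2])
print(sorted(g.formats))  # ['coo']
print(g.in_degrees())     # [0, 1, 2], and CSC now exists
g.create_formats_()
print(sorted(g.formats))  # ['coo', 'csc', 'csr']
```

If such an object were forked before create_formats_ was called, each child that calls in_degrees would build and hold its own CSC copy; building the formats first means the children share one copy.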

Then you can spawn the subprocesses to train with multiple GPUs.

Note

You will need to use dgl.multiprocessing instead of the Python multiprocessing package. dgl.multiprocessing is identical to Python’s built-in multiprocessing except that it handles the subtleties between forking and multithreading in Python.

import dgl.multiprocessing as mp

# Say you have four GPUs.
num_gpus = 4
devices = list(range(num_gpus))
procs = []
for proc_id in range(num_gpus):
    p = mp.Process(target=run, args=(proc_id, devices))
    p.start()
    procs.append(p)
for p in procs:
    p.join()


Total running time of the script: ( 0 minutes 8.231 seconds)
