dgl.contrib.UnifiedTensor

UnifiedTensor enables direct CPU memory access from the GPU. This feature is especially useful when GPUs need to access sparse data structures stored in CPU memory (e.g., when node features do not fit in GPU memory). Without this feature, sparsely structured data located in CPU memory must be gathered (or packed) before being transferred to GPU memory, because GPU DMA engines can only transfer data at block granularity.

However, the gathering step wastes CPU cycles and increases the CPU-to-GPU data copy time. The goal of UnifiedTensor is to skip this CPU gathering step by letting GPUs directly access even irregular data in CPU memory. At the hardware level, this is enabled by NVIDIA GPUs’ unified virtual addressing (UVA) and zero-copy access capabilities. Those who wish to further extend the capability of UnifiedTensor may read the following paper (link), which explains the underlying mechanism of UnifiedTensor in detail.
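For comparison, the conventional gather-then-copy path described above can be sketched in plain PyTorch. This is only an illustration of the baseline that UnifiedTensor is meant to avoid, not DGL code: the helper name gather_then_copy is hypothetical, and the sketch falls back to the CPU when no GPU is present so that it runs anywhere.

```python
import torch

def gather_then_copy(feats, idx, device):
    """Hypothetical baseline: gather rows on the CPU, then copy one dense block."""
    # Step 1 (CPU work): pack the scattered rows into a new contiguous tensor.
    packed = feats[idx]
    # Step 2 (transfer): pin the packed buffer when the target is a GPU,
    # so the host-to-device copy can be issued asynchronously.
    if device.type == "cuda":
        packed = packed.pin_memory()
    return packed.to(device, non_blocking=True)

# Assumption: use the default GPU if one exists, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feats = torch.rand((128, 128))        # node features kept in CPU memory
idx = torch.tensor([5, 42, 7, 100])   # non-contiguous rows requested by the GPU
batch = gather_then_copy(feats, idx, device)
```

Step 1 is the CPU gathering cost discussed above; with UnifiedTensor the index tensor lives on the GPU and the rows are read directly, so no packed intermediate buffer is built on the CPU.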

UnifiedTensor Class

class dgl.contrib.UnifiedTensor(input, device)[source]

Class for storing a unified tensor. Creating a UnifiedTensor automatically pins the input tensor. Once a UnifiedTensor has been created successfully, the target GPU device holds the address mapping of the input CPU tensor for zero-copy (direct) access over external interconnects (e.g., PCIe).

Parameters
  • input (Tensor) – CPU tensor to convert into a unified tensor.

  • device (device) – GPU on which to create the address mapping of the input CPU tensor.

Examples

Given a CPU tensor feats, a new UnifiedTensor targeting the default GPU can be created as follows:

>>> feats = torch.rand((128,128))
>>> feats = dgl.contrib.UnifiedTensor(feats, device=torch.device('cuda'))

The elements of the new tensor feats can now be accessed with [] indexing. The device of the index tensor acts as the switch that triggers zero-copy access from the GPU. For example, to use ordinary CPU-based data access, keep the index tensor on the CPU:

>>> idx = torch.tensor([0,1,2])
>>> output = feats[idx]

To perform zero-copy access from the GPU instead, place the index tensor on the GPU:

>>> idx = torch.tensor([0,1,2]).to('cuda')
>>> output = feats[idx]

For multi-GPU operation, multiple GPUs can access the original CPU tensor feats by creating one UnifiedTensor per device:

>>> feats = torch.rand((128,128))
>>> feats_gpu0 = dgl.contrib.UnifiedTensor(feats, device=torch.device('cuda:0'))
>>> feats_gpu1 = dgl.contrib.UnifiedTensor(feats, device=torch.device('cuda:1'))
>>> feats_gpu2 = dgl.contrib.UnifiedTensor(feats, device=torch.device('cuda:2'))

Now, the cuda:0, cuda:1, and cuda:2 devices will be able to access the identical tensor located in the CPU memory using feats_gpu0, feats_gpu1, and feats_gpu2 tensors, respectively.

The following operations slice sub-tensors directly into the different GPU devices:

>>> feats_idx_gpu0 = torch.randint(128, (16,), device='cuda:0')
>>> feats_idx_gpu1 = torch.randint(128, (16,), device='cuda:1')
>>> feats_idx_gpu2 = torch.randint(128, (16,), device='cuda:2')
>>> sub_feat_gpu0 = feats_gpu0[feats_idx_gpu0]
>>> sub_feat_gpu1 = feats_gpu1[feats_idx_gpu1]
>>> sub_feat_gpu2 = feats_gpu2[feats_idx_gpu2]

__getitem__(key)[source]

Perform zero-copy access from the GPU if the key resides on a CUDA device. Otherwise, fall back safely to the backend-specific indexing scheme.

Parameters
  • key (Tensor) – Tensor that contains the indices to access.
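The dispatch rule stated above depends only on the device of the key. A minimal sketch of that rule follows; it is not DGL's implementation, and the helper name access_path is hypothetical, but it can help when checking which path a given index tensor would select.

```python
import torch

def access_path(key):
    """Hypothetical helper: name the path that an index tensor would select."""
    # A key on a CUDA device selects zero-copy access from the GPU.
    if key.device.type == "cuda":
        return "zero-copy access from GPU"
    # Any other key falls back to the backend's ordinary indexing.
    return "backend-specific indexing"

cpu_key = torch.tensor([0, 1, 2])
print(access_path(cpu_key))

# Only exercised when a GPU is actually present.
if torch.cuda.is_available():
    print(access_path(cpu_key.to("cuda")))
```

An index tensor created without an explicit device lives on the CPU, so it must be moved to the GPU (for example with .to('cuda')) before zero-copy access is used.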