Source code for dgl.nn.pytorch.sparse_emb

"""Torch NodeEmbedding."""
from datetime import timedelta

import torch as th

from ...backend import pytorch as F
from ...cuda import nccl
from ...partition import NDArrayPartition
from ...utils import create_shared_mem_array, get_shared_mem_array

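# Process-local cache for the ``torch.distributed.TCPStore`` created in
# ``NodeEmbedding.__init__``, so that multiple NodeEmbedding instances in the
# same process reuse a single store instead of creating a new one each time.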
_STORE = None


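# When the embedding is placed on a GPU device and no partition is supplied,
# ``__init__`` builds a remainder-based NDArrayPartition over the world size.
# A caller can also pass one explicitly, e.g. (sketch with hypothetical sizes):
#
#     part = NDArrayPartition(num_nodes, world_size, mode="remainder")
#     emb = NodeEmbedding(num_nodes, 128, "emb", device=dev, partition=part)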
class NodeEmbedding:  # NodeEmbedding
    """Class for storing node embeddings.

    The class is optimized for training large-scale node embeddings. It updates
    the embedding in a sparse way and can scale to graphs with millions of
    nodes. It also supports partitioning to multiple GPUs (on a single machine)
    for more acceleration. It does not support partitioning across machines.

    Currently, DGL provides two optimizers that work with this NodeEmbedding
    class: ``SparseAdagrad`` and ``SparseAdam``.

    The implementation is based on the torch.distributed package. It depends on
    the pytorch default distributed process group to collect multi-process
    information and uses ``torch.distributed.TCPStore`` to share meta-data
    information across multiple GPU processes. It uses the local address
    '127.0.0.1:12346' to initialize the TCPStore.

    NOTE: The support of NodeEmbedding is experimental.

    Parameters
    ----------
    num_embeddings : int
        The number of embeddings. Currently, the number of embeddings has to be
        the same as the number of nodes.
    embedding_dim : int
        The dimension size of embeddings.
    name : str
        The name of the embeddings. The name should uniquely identify the
        embeddings in the system.
    init_func : callable, optional
        The function to create the initial data. If the init function is not
        provided, the values of the embeddings are initialized to zero.
    device : th.device, optional
        Device to store the embeddings on. Defaults to CPU.
    partition : NDArrayPartition, optional
        The partition used to distribute the embeddings between processes.

    Examples
    --------
    Before launching multiple GPU processes

    >>> def initializer(emb):
    ...     th.nn.init.xavier_uniform_(emb)
    ...     return emb

    In each training process

    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
    >>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)
    >>> for blocks in dataloader:
    ...     ...
    ...     feats = emb(nids, gpu_0)
    ...     loss = F.sum(feats + 1, 0)
    ...     loss.backward()
    ...     optimizer.step()
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        name,
        init_func=None,
        device=None,
        partition=None,
    ):
        global _STORE

        if device is None:
            device = th.device("cpu")

        # Check whether it is multi-gpu training or not.
        if th.distributed.is_initialized():
            rank = th.distributed.get_rank()
            world_size = th.distributed.get_world_size()
        else:
            rank = -1
            world_size = 0
        self._rank = rank
        self._world_size = world_size
        self._store = None
        self._comm = None
        self._partition = partition

        host_name = "127.0.0.1"
        port = 12346

        if rank >= 0:
            # for multi-gpu training, set up a TCPStore for
            # embedding status synchronization across GPU processes
            if _STORE is None:
                _STORE = th.distributed.TCPStore(
                    host_name,
                    port,
                    world_size,
                    rank == 0,
                    timedelta(seconds=10 * 60),
                )
            self._store = _STORE

        # embeddings are stored in CPU memory.
        if th.device(device) == th.device("cpu"):
            if rank <= 0:
                emb = create_shared_mem_array(
                    name, (num_embeddings, embedding_dim), th.float32
                )
                if init_func is not None:
                    emb = init_func(emb)
            if rank == 0:  # the master gpu process
                for _ in range(1, world_size):
                    # send embs
                    self._store.set(name, name)
            elif rank > 0:  # receive
                self._store.wait([name])
                emb = get_shared_mem_array(
                    name, (num_embeddings, embedding_dim), th.float32
                )
            self._tensor = emb
        else:  # embeddings are stored in GPU memory.
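            # Each process keeps only the rows assigned to it by
            # ``self._partition``; rows owned by other processes are fetched
            # on demand in ``__call__`` via ``nccl.sparse_all_to_all_pull``.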
            self._comm = True

            if not self._partition:
                # for communication we need a partition
                self._partition = NDArrayPartition(
                    num_embeddings,
                    self._world_size if self._world_size > 0 else 1,
                    mode="remainder",
                )

            # create local tensors for the weights
            local_size = self._partition.local_size(max(self._rank, 0))

            # TODO(dlasalle): support 16-bit/half embeddings
            emb = th.empty(
                [local_size, embedding_dim],
                dtype=th.float32,
                requires_grad=False,
                device=device,
            )
            if init_func:
                emb = init_func(emb)
            self._tensor = emb

        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._name = name
        self._optm_state = None  # track optimizer state
        self._trace = []  # track minibatch

    def __call__(self, node_ids, device=th.device("cpu")):
        """Return the embeddings of the given node IDs.

        Parameters
        ----------
        node_ids : th.tensor
            Index of the embeddings to collect.
        device : th.device
            Target device to put the collected embeddings.

        Returns
        -------
        th.tensor
            The collected embeddings on :attr:`device`.
        """
        if not self._comm:
            # For embeddings stored on the CPU.
            emb = self._tensor[node_ids].to(device)
        else:
            # For embeddings stored on the GPU.
            # The following method performs communication across multiple
            # GPUs and gracefully handles the case where only one GPU is
            # present, i.e. self._world_size == 1 or 0 (when
            # th.distributed.is_initialized() is false).
            emb = nccl.sparse_all_to_all_pull(
                node_ids, self._tensor, self._partition
            )
            emb = emb.to(device)
        if F.is_recording():
            emb = F.attach_grad(emb)
            self._trace.append((node_ids.to(device), emb))

        return emb

    @property
    def store(self):
        """Return the torch.distributed.TCPStore used for
        meta-data sharing across processes.

        Returns
        -------
        torch.distributed.TCPStore
            KVStore used for meta-data sharing.
        """
        return self._store

    @property
    def partition(self):
        """Return the partition identifying how the tensor is split across
        processes.

        Returns
        -------
        NDArrayPartition
            The partition of the embeddings.
        """
        return self._partition

    @property
    def rank(self):
        """Return the rank of the current process.

        Returns
        -------
        int
            The rank of the current process.
        """
        return self._rank

    @property
    def world_size(self):
        """Return the world size of the pytorch distributed training env.

        Returns
        -------
        int
            The world size of the pytorch distributed training env.
        """
        return self._world_size

    @property
    def name(self):
        """Return the name of the NodeEmbedding.

        Returns
        -------
        str
            The name of the NodeEmbedding.
        """
        return self._name

    @property
    def num_embeddings(self):
        """Return the number of embeddings.

        Returns
        -------
        int
            The number of embeddings.
        """
        return self._num_embeddings

    @property
    def embedding_dim(self):
        """Return the dimension of embeddings.

        Returns
        -------
        int
            The dimension of embeddings.
        """
        return self._embedding_dim

    def set_optm_state(self, state):
        """Store the optimizer related state tensor.

        Parameters
        ----------
        state : tuple of torch.Tensor
            Optimizer related state.
        """
        self._optm_state = state

    @property
    def optm_state(self):
        """Return the optimizer related state tensor.

        Returns
        -------
        tuple of torch.Tensor
            The optimizer related state.
        """
        return self._optm_state

    @property
    def trace(self):
        """Return the trace of the indices of embeddings
        used in the training step(s).

        Returns
        -------
        [torch.Tensor]
            The indices of embeddings used in the training step(s).
        """
        return self._trace

    def reset_trace(self):
        """Clean up the trace of the indices of embeddings
        used in the training step(s).
        """
        self._trace = []

    @property
    def weight(self):
        """Return the tensor storing the node embeddings.

        Returns
        -------
        torch.Tensor
            The tensor storing the node embeddings.
        """
        return self._tensor

    def all_set_embedding(self, values):
        """Set the values of the embedding.

        This method must be called by all processes sharing the embedding with
        identical tensors for :attr:`values`.

        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.

        Parameters
        ----------
        values : Tensor
            The global tensor to pull values from.
        """
        if self._partition:
            idxs = F.copy_to(
                self._partition.get_local_indices(
                    max(self._rank, 0),
                    ctx=F.context(self._tensor),
                ),
                F.context(values),
            )
            self._tensor[:] = F.copy_to(
                F.gather_row(values, idxs), ctx=F.context(self._tensor)
            )[:]
        else:
            if self._rank == 0:
                self._tensor[:] = F.copy_to(
                    values, ctx=F.context(self._tensor)
                )[:]
        if th.distributed.is_initialized():
            th.distributed.barrier()

    def _all_get_tensor(self, shared_name, tensor, shape):
        """A helper function to get model-parallel tensors.

        This method must be called, and only needs to be called, in multi-GPU
        DDP training. For now, it is only used in ``all_get_embedding`` and
        ``_all_get_optm_state``.
        """
        # create a shared memory tensor
        if self._rank == 0:
            # root process creates shared memory
            val = create_shared_mem_array(
                shared_name,
                shape,
                tensor.dtype,
            )
            self._store.set(shared_name, shared_name)
        else:
            self._store.wait([shared_name])
            val = get_shared_mem_array(
                shared_name,
                shape,
                tensor.dtype,
            )

        # need to map indices and slice into existing tensor
        idxs = self._partition.map_to_global(
            F.arange(0, tensor.shape[0], ctx=F.context(tensor)),
            self._rank,
        ).to(val.device)
        val[idxs] = tensor.to(val.device)

        self._store.delete_key(shared_name)
        # wait for all processes to finish
        th.distributed.barrier()
        return val

    def all_get_embedding(self):
        """Return a copy of the embedding stored in CPU memory. If this is a
        multi-processing instance, the tensor will be returned in shared
        memory. If the embedding is currently stored on multiple GPUs, all
        processes must call this method in the same order.

        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.

        Returns
        -------
        torch.Tensor
            The tensor storing the node embeddings.
        """
        if self._partition:
            if self._world_size == 0:
                # non-multiprocessing
                return self._tensor.to(th.device("cpu"))
            else:
                return self._all_get_tensor(
                    f"{self._name}_gather",
                    self._tensor,
                    (self._num_embeddings, self._embedding_dim),
                )
        else:
            # already stored in CPU memory
            return self._tensor

    def _all_get_optm_state(self):
        """Return a copy of the whole optimizer states stored in CPU memory.
        If this is a multi-processing instance, the states will be returned in
        shared memory. If the embedding is currently stored on multiple GPUs,
        all processes must call this method in the same order.

        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.

        Returns
        -------
        tuple of torch.Tensor
            The optimizer states stored in CPU memory.
        """
        if self._partition:
            if self._world_size == 0:
                # non-multiprocessing
                return tuple(
                    state.to(th.device("cpu")) for state in self._optm_state
                )
            else:
                return tuple(
                    self._all_get_tensor(
                        f"state_gather_{self._name}_{i}",
                        state,
                        (self._num_embeddings, *state.shape[1:]),
                    )
                    for i, state in enumerate(self._optm_state)
                )
        else:
            # already stored in CPU memory
            return self._optm_state

    def _all_set_optm_state(self, states):
        """Set the optimizer states of the embedding.

        This method must be called by all processes sharing the embedding with
        identical :attr:`states`.

        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.

        Parameters
        ----------
        states : tuple of torch.Tensor
            The global states to pull values from.
""" if self._partition: idxs = F.copy_to( self._partition.get_local_indices( max(self._rank, 0), ctx=F.context(self._tensor) ), F.context(states[0]), ) for state, new_state in zip(self._optm_state, states): state[:] = F.copy_to( F.gather_row(new_state, idxs), ctx=F.context(self._tensor) )[:] else: # stored in CPU memory if self._rank <= 0: for state, new_state in zip(self._optm_state, states): state[:] = F.copy_to( new_state, ctx=F.context(self._tensor) )[:] if th.distributed.is_initialized(): th.distributed.barrier()