Source code for dgl.distributed.nn.pytorch.sparse_emb

"""Define sparse embedding and optimizer."""

import torch as th
from .... import backend as F
from .... import utils
from ...dist_tensor import DistTensor

class DistEmbedding:
    '''Distributed node embeddings.

    DGL provides a distributed embedding to support models that require learnable embeddings.
    DGL's distributed embeddings are mainly used for learning node embeddings of graph models.
    Because distributed embeddings are part of a model, they are updated by mini-batches.
    The distributed embeddings have to be updated by DGL's optimizers instead of
    the optimizers provided by the deep learning frameworks (e.g., Pytorch and MXNet).

    To support efficient training on a graph with many nodes, the embeddings support sparse
    updates. That is, only the embeddings involved in a mini-batch computation are updated.
    Please refer to `Distributed Optimizers
    <https://docs.dgl.ai/api/python/dgl.distributed.html#distributed-embedding-optimizer>`__
    for available optimizers in DGL.

    Distributed embeddings are sharded and stored in a cluster of machines in the same way as
    :class:`dgl.distributed.DistTensor`, except that distributed embeddings are trainable.
    Because distributed embeddings are sharded
    in the same way as nodes and edges of a distributed graph, it is usually much more
    efficient to access than the sparse embeddings provided by the deep learning frameworks.

    Parameters
    ----------
    num_embeddings : int
        The number of embeddings. Currently, the number of embeddings has to be the same as
        the number of nodes or the number of edges.
    embedding_dim : int
        The dimension size of embeddings.
    name : str, optional
        The name of the embeddings. The name can uniquely identify embeddings in a system
        so that another DistEmbedding object can refer to the same embeddings.
    init_func : callable, optional
        The function to create the initial data. If the init function is not provided,
        the values of the embeddings are initialized to zero.
    part_policy : PartitionPolicy, optional
        The partition policy that assigns embeddings to different machines in the cluster.
        Currently, it only supports node partition policy or edge partition policy.
        The system determines the right partition policy automatically.

    Examples
    --------
    >>> def initializer(shape, dtype):
            arr = th.zeros(shape, dtype=dtype)
            arr.uniform_(-1, 1)
            return arr
    >>> emb = dgl.distributed.DistEmbedding(g.number_of_nodes(), 10, init_func=initializer)
    >>> optimizer = dgl.distributed.optim.SparseAdagrad([emb], lr=0.001)
    >>> for blocks in dataloader:
    ...     feats = emb(nids)
    ...     loss = F.sum(feats + 1, 0)
    ...     loss.backward()
    ...     optimizer.step()

    Note
    ----
    When a ``DistEmbedding`` object is used when the deep learning framework is recording
    the forward computation, users have to invoke
    :py:meth:`~dgl.distributed.optim.SparseAdagrad.step` afterwards. Otherwise, there will be
    some memory leak.
    '''

    def __init__(self, num_embeddings, embedding_dim, name=None,
                 init_func=None, part_policy=None):
        # The embedding values live in a float32 DistTensor of shape
        # (num_embeddings, embedding_dim).
        self._tensor = DistTensor((num_embeddings, embedding_dim), F.float32, name,
                                  init_func=init_func, part_policy=part_policy)
        # (index, embedding) pairs recorded by __call__ while the framework is
        # recording the forward computation; cleared by reset_trace().
        self._trace = []
        self._name = name
        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim

        # Always define the rank/world-size attributes so the instance is fully
        # initialized even when torch.distributed has not been set up; None
        # means "not available".
        self._rank = None
        self._world_size = None
        # Check whether it is multi-gpu/distributed training or not
        if th.distributed.is_initialized():
            self._rank = th.distributed.get_rank()
            self._world_size = th.distributed.get_world_size()
        # TODO: an uninitialized torch.distributed is silently tolerated here.
        # Turning this case into a raised DGLError was reported to fail the
        # unit tests, so no error is raised.

        self._optm_state = None  # track optimizer state
        self._part_policy = part_policy

    def __call__(self, idx, device=th.device('cpu')):
        """Collect the embeddings of the given indices.

        While the deep learning framework is recording the forward computation,
        the returned embeddings have gradients attached and the
        ``(idx, embeddings)`` pair is appended to the internal trace.

        Parameters
        ----------
        idx : th.tensor
            Index of the embeddings to collect.
        device : th.device
            Target device to put the collected embeddings.

        Returns
        -------
        Tensor
            The requested node embeddings
        """
        idx = utils.toindex(idx).tousertensor()
        emb = self._tensor[idx].to(device, non_blocking=True)
        if F.is_recording():
            emb = F.attach_grad(emb)
            # The traced index is moved to the same target device as the embeddings.
            self._trace.append((idx.to(device, non_blocking=True), emb))
        return emb

    def reset_trace(self):
        '''Reset the traced data.
        '''
        self._trace = []

    @property
    def part_policy(self):
        """Return the partition policy

        This is the value passed to the constructor.

        Returns
        -------
        PartitionPolicy
            partition policy
        """
        return self._part_policy

    @property
    def name(self):
        """Return the name of the embeddings

        Returns
        -------
        str
            The name of the embeddings
        """
        return self._tensor.tensor_name

    @property
    def data_name(self):
        """Return the data name of the embeddings

        Returns
        -------
        str
            The data name of the embeddings
        """
        return self._tensor._name

    @property
    def kvstore(self):
        """Return the kvstore client

        Returns
        -------
        KVClient
            The kvstore client
        """
        return self._tensor.kvstore

    @property
    def num_embeddings(self):
        """Return the number of embeddings

        Returns
        -------
        int
            The number of embeddings
        """
        return self._num_embeddings

    @property
    def embedding_dim(self):
        """Return the dimension of embeddings

        Returns
        -------
        int
            The dimension of embeddings
        """
        return self._embedding_dim

    @property
    def optm_state(self):
        """Return the optimizer related state tensor.

        Returns
        -------
        tuple of torch.Tensor
            The optimizer related state.
        """
        return self._optm_state

    @property
    def weight(self):
        """Return the tensor storing the node embeddings

        Returns
        -------
        DistTensor
            The distributed tensor storing the node embeddings
        """
        return self._tensor