NodeEmbedding

class dgl.nn.pytorch.sparse_emb.NodeEmbedding(num_embeddings, embedding_dim, name, init_func=None, device=None, partition=None)

Bases: object

Class for storing node embeddings.

The class is optimized for training large-scale node embeddings. It updates the embeddings sparsely and can scale to graphs with millions of nodes. It also supports partitioning the embeddings across multiple GPUs on a single machine for further acceleration. Partitioning across machines is not supported.

Currently, DGL provides two optimizers that work with this NodeEmbedding class: SparseAdagrad and SparseAdam.
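To make "updates the embeddings sparsely" concrete, the plain-Python sketch below applies one element-wise Adagrad step only to the rows whose node IDs appear in a minibatch. Rows that were not looked up, and their optimizer state, are left untouched. This is a textbook Adagrad formulation written for illustration, not DGL's SparseAdagrad code; the helper name `sparse_adagrad_step` is hypothetical, and DGL's internal state layout may differ.

```python
import math


def sparse_adagrad_step(weights, state, node_ids, grads, lr=0.01, eps=1e-10):
    """Hypothetical helper: one Adagrad update restricted to the rows in node_ids.

    weights: list of embedding rows (lists of floats), one row per node.
    state:   running sum of squared gradients, same shape as weights.
    grads:   one gradient row per entry of node_ids (IDs may repeat).
    """
    # Sum gradients of repeated IDs so each touched row is updated exactly once.
    merged = {}
    for nid, g in zip(node_ids, grads):
        acc = merged.setdefault(nid, [0.0] * len(g))
        for j, v in enumerate(g):
            acc[j] += v

    # Only the touched rows are read and written; all other rows stay as-is.
    for nid, g in merged.items():
        for j, v in enumerate(g):
            state[nid][j] += v * v
            weights[nid][j] -= lr * v / (math.sqrt(state[nid][j]) + eps)


weights = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
state = [[0.0, 0.0] for _ in weights]
# Node 0 appears twice in this minibatch; node 1 is not touched at all.
sparse_adagrad_step(weights, state, [0, 2, 0],
                    [[1.0, 0.0], [0.5, 0.5], [1.0, 0.0]], lr=0.1)
```

Because a step only visits the rows present in the minibatch, its cost grows with the batch size rather than with the total number of nodes, which is what makes this style of optimizer practical for very large embedding tables.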

The implementation is based on the torch.distributed package. It relies on PyTorch's default distributed process group to collect multi-process information and uses torch.distributed.TCPStore to share metadata across the GPU processes. The TCPStore is initialized at the local address ‘127.0.0.1:12346’.

NOTE: Support for NodeEmbedding is experimental.

Parameters
  • num_embeddings (int) – The number of embeddings. Currently, the number of embeddings has to be the same as the number of nodes.

  • embedding_dim (int) – The dimension size of embeddings.

  • name (str) – The name of the embeddings. The name should uniquely identify the embeddings in the system.

  • init_func (callable, optional) – The function to create the initial data. If the init function is not provided, the values of the embeddings are initialized to zero.

  • device (th.device, optional) – The device to store the embeddings on.

  • partition (NDArrayPartition, optional) – The partition to use to distribute the embeddings between processes.
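One simple way to split a node-indexed embedding table between processes is by remainder: node `i` is owned by part `i % num_parts` and stored at local row `i // num_parts`. The sketch below illustrates that idea with hypothetical helpers (`remainder_partition`, `to_owner_and_local`, `group_by_owner`); it is an illustration of the concept only and does not reproduce the NDArrayPartition API.

```python
def remainder_partition(num_nodes, num_parts):
    """Hypothetical helper: assign node i to part i % num_parts.

    Returns, for each part, the global node IDs it owns in local-row order.
    """
    parts = [[] for _ in range(num_parts)]
    for nid in range(num_nodes):
        parts[nid % num_parts].append(nid)
    return parts


def to_owner_and_local(nid, num_parts):
    """Map a global node ID to (owning part, row index inside that part)."""
    return nid % num_parts, nid // num_parts


def group_by_owner(node_ids, num_parts):
    """Group a minibatch's node IDs into per-part lists of local row indices."""
    requests = {}
    for nid in node_ids:
        owner, local = to_owner_and_local(nid, num_parts)
        requests.setdefault(owner, []).append(local)
    return requests


parts = remainder_partition(10, 3)
requests = group_by_owner([0, 4, 5, 9], 3)
```

Grouping requests by owner is what lets each process fetch or update only the rows it stores, so a minibatch lookup turns into one small request per part instead of a scan of the full table.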

Examples

Before launching multiple GPU processes

>>> import torch as th
>>> def initializer(emb):
...     # Fill the tensor in place with Xavier-uniform values.
...     th.nn.init.xavier_uniform_(emb)
...     return emb

In each training process

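>>> import dgl
>>> emb = dgl.nn.NodeEmbedding(g.number_of_nodes(), 10, 'emb', init_func=initializer)
>>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     # Look up the embeddings of this minibatch's input nodes on GPU 0.
...     feats = emb(input_nodes, th.device('cuda:0'))
...     # Toy loss for illustration; backward() requires a scalar.
...     loss = th.sum(feats + 1)
...     loss.backward()
...     optimizer.step()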