dgl.optim

DGL sparse optimizers for PyTorch.

Node embedding optimizer

class dgl.optim.pytorch.SparseAdagrad(params, lr, eps=1e-10)

Node embedding optimizer using the Adagrad algorithm.

This optimizer implements a sparse version of the Adagrad algorithm for optimizing dgl.nn.NodeEmbedding. Being sparse means that it only updates the embeddings that received gradients in the current step, which are usually a very small portion of all embeddings.

Adagrad maintains a \(G_{t,i,j}\) for every parameter in the embeddings, where \(G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2\) and \(g_{t,i,j}\) is the gradient of the dimension \(j\) of embedding \(i\) at step \(t\).
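The accumulation rule, and the sparsity that comes with it, can be sketched in plain Python. This is a minimal illustration, not DGL's implementation: the embedding table and the \(G\) accumulator are assumed to be lists of rows, the per-step gradients are a dict from row index to gradient row (so only rows that appear in it are read or written), and the helper name `sparse_adagrad_step` is hypothetical. The weight step \(lr \cdot g_{t,i,j} / (\sqrt{G_{t,i,j}} + \epsilon)\) is the textbook Adagrad rule; exactly where DGL adds `eps` is an assumption.

```python
import math


def sparse_adagrad_step(weights, state, grads, lr, eps=1e-10):
    """Apply one sparse Adagrad step in place (hypothetical helper).

    weights: embedding table as a list of rows (lists of floats).
    state:   accumulator G with the same shape as `weights`.
    grads:   dict mapping row index i -> gradient row g_{t,i} for this step.
             Rows missing from `grads` are neither read nor written.
    """
    for i, g_row in grads.items():
        for j, g in enumerate(g_row):
            # G_{t,i,j} = G_{t-1,i,j} + g_{t,i,j}^2
            state[i][j] += g * g
            # Textbook Adagrad step; eps outside the sqrt is an assumption.
            weights[i][j] -= lr * g / (math.sqrt(state[i][j]) + eps)


# Three 2-dimensional embeddings; only row 1 receives a gradient this step.
weights = [[0.5, -0.5], [1.0, 1.0], [0.0, 2.0]]
state = [[0.0, 0.0] for _ in weights]
sparse_adagrad_step(weights, state, {1: [2.0, -1.0]}, lr=0.1)
print(weights)  # rows 0 and 2 are untouched; row 1 is roughly [0.9, 1.1]
```

Because the update is non-linear in the gradient, a mini-batch that contains the same node ID more than once needs the gradients for that row combined before the step; passing a dict sidesteps that question in this sketch.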

NOTE: Support for the sparse Adagrad optimizer is experimental.

Parameters:
  • params (list[dgl.nn.NodeEmbedding]) – The list of dgl.nn.NodeEmbedding objects to optimize.

  • lr (float) – The learning rate.

  • eps (float, Optional) – The term added to the denominator to improve numerical stability. Default: 1e-10

Examples

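>>> import torch as th
>>> import dgl
>>> g = dgl.rand_graph(100, 500)  # a toy graph with 100 nodes
>>> def initializer(emb):
...     th.nn.init.xavier_uniform_(emb)
...     return emb
>>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
>>> optimizer = dgl.optim.SparseAdagrad([emb], lr=0.001)
>>> # In practice the node IDs come from a DGL data loader; random
>>> # mini-batches of node IDs stand in for it here.
>>> for nids in th.randperm(g.num_nodes()).split(32):
...     feats = emb(nids, th.device('cpu'))  # look up embeddings on the CPU
...     loss = (feats + 1).sum()              # backward() needs a scalar loss
...     loss.backward()
...     optimizer.step()                      # updates only the rows in nids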
class dgl.optim.pytorch.SparseAdam(params, lr, betas=(0.9, 0.999), eps=1e-08, use_uva=None, dtype=torch.float32)

Node embedding optimizer using the Adam algorithm.

This optimizer implements a sparse version of the Adam algorithm for optimizing dgl.nn.NodeEmbedding. Being sparse means that it only updates the embeddings that received gradients in the current step, which are usually a very small portion of all embeddings.

Adam maintains \(Gm_{t,i,j}\) and \(Gp_{t,i,j}\) for every parameter in the embeddings, where \(Gm_{t,i,j} = \beta_1 Gm_{t-1,i,j} + (1 - \beta_1) g_{t,i,j}\), \(Gp_{t,i,j} = \beta_2 Gp_{t-1,i,j} + (1 - \beta_2) g_{t,i,j}^2\), and \(g_{t,i,j}\) is the gradient of the dimension \(j\) of embedding \(i\) at step \(t\). The parameter is then decreased by \(lr \cdot \frac{Gm_{t,i,j} / (1 - \beta_1^t)}{\sqrt{Gp_{t,i,j} / (1 - \beta_2^t)} + \epsilon}\), where \((\beta_1, \beta_2)\) are betas and \(\epsilon\) is eps.
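The Adam rule can be sketched in plain Python as a minimal illustration, not DGL's implementation. The storage layout (lists of rows for the table and for the \(Gm\)/\(Gp\) states, here called `mem` and `power`), the dict of per-row gradients, and the helper name `sparse_adam_step` are all hypothetical. The sketch also assumes that \(t\) is counted per embedding row: a row's step counter only advances when that row receives a gradient, so bias correction is applied per row.

```python
import math


def sparse_adam_step(weights, mem, power, steps, grads, lr,
                     betas=(0.9, 0.999), eps=1e-8):
    """Apply one sparse Adam step in place (hypothetical helper).

    mem:   running average of gradients (Gm), same shape as `weights`.
    power: running average of squared gradients (Gp), same shape.
    steps: steps[i] counts the updates row i has received (per-row t,
           an assumption of this sketch).
    grads: dict mapping row index i -> gradient row for this step;
           rows missing from it are neither read nor written.
    """
    beta1, beta2 = betas
    for i, g_row in grads.items():
        steps[i] += 1
        t = steps[i]
        for j, g in enumerate(g_row):
            mem[i][j] = beta1 * mem[i][j] + (1 - beta1) * g
            power[i][j] = beta2 * power[i][j] + (1 - beta2) * g * g
            # Bias-corrected estimates.
            m_hat = mem[i][j] / (1 - beta1 ** t)
            p_hat = power[i][j] / (1 - beta2 ** t)
            weights[i][j] -= lr * m_hat / (math.sqrt(p_hat) + eps)


# Two 2-dimensional embeddings; only row 0 receives a gradient this step.
weights = [[1.0, 1.0], [0.0, 0.0]]
mem = [[0.0, 0.0] for _ in weights]
power = [[0.0, 0.0] for _ in weights]
steps = [0 for _ in weights]
sparse_adam_step(weights, mem, power, steps, {0: [0.5, -2.0]}, lr=0.01)
print(weights)  # row 0 is roughly [0.99, 1.01]; row 1 is untouched
```

On a row's first update the bias-corrected ratio reduces to \(g / |g|\) (up to \(\epsilon\)), so each coordinate moves by about lr against the sign of its gradient regardless of the gradient's magnitude, which is what the example shows.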

NOTE: Support for the sparse Adam optimizer is experimental.

Parameters:
  • params (list[dgl.nn.NodeEmbedding]) – The list of dgl.nn.NodeEmbedding objects to optimize.

  • lr (float) – The learning rate.

  • betas (tuple[float, float], Optional) – Coefficients used for computing running averages of the gradient and its square. Default: (0.9, 0.999)

  • eps (float, Optional) – The term added to the denominator to improve numerical stability. Default: 1e-8

  • use_uva (bool, Optional) – Whether to use pinned memory for storing the ‘mem’ and ‘power’ optimizer states when the embedding is stored on the CPU. This improves training speed but requires locking a large number of virtual memory pages. For embeddings stored in GPU memory, this setting has no effect. Default: True if the gradients are generated on the GPU, and False if they are on the CPU.

  • dtype (torch.dtype, Optional) – The data type used to store the optimizer state. Default: torch.float32.

Examples

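>>> import torch as th
>>> import dgl
>>> g = dgl.rand_graph(100, 500)  # a toy graph with 100 nodes
>>> def initializer(emb):
...     th.nn.init.xavier_uniform_(emb)
...     return emb
>>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
>>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)
>>> # In practice the node IDs come from a DGL data loader; random
>>> # mini-batches of node IDs stand in for it here.
>>> for nids in th.randperm(g.num_nodes()).split(32):
...     feats = emb(nids, th.device('cpu'))  # look up embeddings on the CPU
...     loss = (feats + 1).sum()              # backward() needs a scalar loss
...     loss.backward()
...     optimizer.step()                      # updates only the rows in nids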