"""Node embedding optimizers"""
import abc
from abc import abstractmethod
import torch as th
from ...cuda import nccl
from ...nn.pytorch import NodeEmbedding
from ...partition import NDArrayPartition
from ...utils import (
create_shared_mem_array,
gather_pinned_tensor_rows,
get_shared_mem_array,
pin_memory_inplace,
scatter_pinned_tensor_rows,
)
class SparseGradOptimizer(abc.ABC):
r"""The abstract sparse optimizer.
Note: DGL's sparse optimizers only work with dgl.nn.NodeEmbedding.
Parameters
----------
params : list of NodeEmbedding
The list of NodeEmbeddings.
lr : float
The learning rate.
"""
def __init__(self, params, lr):
self._params = params
self._lr = lr
self._rank = None
self._world_size = None
self._shared_cache = {}
self._clean_grad = False
self._opt_meta = {}
self._comm = None
self._first_step = True
self._device = None
# hold released shared memory so other processes can munmap it first;
# otherwise the training will crash
self.shmem_buffer_holder = []
assert len(params) > 0, "Empty parameters"
# check that all embeddings are NodeEmbeddings and share the same rank/world_size
for emb in params:
assert isinstance(
emb, NodeEmbedding
), "DGL SparseOptimizer only supports dgl.nn.NodeEmbedding"
if self._rank is None:
self._rank = emb.rank
self._world_size = emb.world_size
else:
assert (
self._rank == emb.rank
), "MultiGPU rank for each embedding should be the same."
assert (
self._world_size == emb.world_size
), "MultiGPU world_size for each embedding should be the same."
assert self._rank is not None
assert self._world_size is not None
def step(self):
"""The step function.
The step function is invoked at the end of every batch to update the embeddings.
"""
# on the first step, check to see if the grads are on the GPU
if self._first_step:
for emb in self._params:
for _, data in emb._trace:
if data.grad.device.type == "cuda":
# create a communicator
if self._device:
assert (
self._device == data.grad.device
), "All gradients must be on the same device"
else:
self._device = data.grad.device
else:
assert (
not self._device
), "All gradients must be on the same device"
# the distributed backend uses NCCL
if self._device and (
not th.distributed.is_initialized()
or th.distributed.get_backend() == "nccl"
):
# device is only set if the grads are on a GPU
self._comm_setup()
else:
self._shared_setup()
self._first_step = False
if self._comm:
self._comm_step()
else:
self._shared_step()
@abstractmethod
def setup(self, params):
"""This is function where subclasses can perform any setup they need
to. It will be called during the first step, and communicators or
shared memory will have been setup before this call.
Parameters
----------
params : list of NodeEmbedding
The list of NodeEmbeddings.
"""
def _comm_setup(self):
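# Mark that GPU-side (NCCL) communication is used; gradients are then
# exchanged in _comm_step() via nccl.sparse_all_to_all_push.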
self._comm = True
def _shared_setup(self):
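# Create a (world_size x world_size) int32 matrix in shared memory for
# each embedding; entry [i][j] later records how many gradient rows
# rank i has written into the buffer destined for rank j (see _shared_step).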
for emb in self._params:
emb_name = emb.name
if self._rank == 0: # the master gpu process
opt_meta = create_shared_mem_array(
emb_name + "_opt_meta",
(self._world_size, self._world_size),
th.int32,
).zero_()
if self._rank == 0:
emb.store.set(emb_name + "_opt_meta", emb_name)
self._opt_meta[emb_name] = opt_meta
elif self._rank > 0:
# receive
emb.store.wait([emb_name + "_opt_meta"])
opt_meta = get_shared_mem_array(
emb_name + "_opt_meta",
(self._world_size, self._world_size),
th.int32,
)
self._opt_meta[emb_name] = opt_meta
def _comm_step(self):
with th.no_grad():
idx_in = {}
grad_in = {}
for emb in self._params: # pylint: disable=too-many-nested-blocks
emb_name = emb.name
partition = emb.partition
if not partition:
# use default partitioning
partition = NDArrayPartition(
emb.num_embeddings,
self._world_size if self._world_size > 0 else 1,
mode="remainder",
)
# we need to combine gradients from multiple forward paths
if len(emb._trace) == 0:
idx = th.zeros((0,), dtype=th.long, device=self._device)
grad = th.zeros(
(0, emb.embedding_dim),
dtype=th.float32,
device=self._device,
)
elif len(emb._trace) == 1:
# the special case where we can use the tensors as is
# without any memcpy's
idx, grad = emb._trace[0]
grad = grad.grad.data
else:
idx = []
grad = []
for i, data in emb._trace:
idx.append(i)
grad.append(data.grad.data)
idx = th.cat(idx, dim=0)
grad = th.cat(grad, dim=0)
(
idx_in[emb_name],
grad_in[emb_name],
) = nccl.sparse_all_to_all_push(idx, grad, partition=partition)
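# idx_in/grad_in now hold the indices and gradients of the embedding
# rows owned by this rank, gathered from all GPUs.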
if emb.partition:
# if the embedding is partitioned, map back to indexes
# into the local tensor
idx_in[emb_name] = partition.map_to_local(idx_in[emb_name])
if self._clean_grad:
# clean gradient track
for emb in self._params:
emb.reset_trace()
self._clean_grad = False
for emb in self._params:
emb_name = emb.name
idx = idx_in[emb_name]
grad = grad_in[emb_name]
self.update(idx, grad, emb)
def _shared_step(self):
with th.no_grad():
# Frequently allocating and freeing shared memory to hold intermediate tensors is expensive,
# so we cache the shared memory buffers in shared_emb.
shared_emb = {emb.name: ([], []) for emb in self._params}
# Go through all sparse embeddings
for emb in self._params: # pylint: disable=too-many-nested-blocks
emb_name = emb.name
# we need to combine gradients from multiple forward paths
idx = []
grad = []
for i, data in emb._trace:
idx.append(i)
grad.append(data.grad.data)
# If the sparse embedding was not used in the previous forward pass,
# idx and grad will be empty. Initialize them as empty tensors to
# avoid crashing the optimizer step logic.
#
# Note: we cannot skip the gradient exchange and update steps, as other
# worker processes may send gradient update requests for
# embeddings owned by this process.
idx = (
th.cat(idx, dim=0)
if len(idx) != 0
else th.zeros((0,), dtype=th.long, device=th.device("cpu"))
)
grad = (
th.cat(grad, dim=0)
if len(grad) != 0
else th.zeros(
(0, emb.embedding_dim),
dtype=th.float32,
device=th.device("cpu"),
)
)
device = grad.device
idx_dtype = idx.dtype
grad_dtype = grad.dtype
grad_dim = grad.shape[1]
if self._world_size > 1:
if emb_name not in self._shared_cache:
self._shared_cache[emb_name] = {}
# Each training process takes responsibility for updating a range
# of node embeddings, so the gradient updates can run in parallel.
# The overall procedure is:
# 1. In each training process:
# 1.a Decide which process a node embedding belongs to according
# to the formula: process_id = node_idx mod num_of_process(N)
# 1.b Split the node index tensor and gradient tensor into N parts
# according to step 1.a
# 1.c Write each node index sub-tensor and gradient sub-tensor into
# a different DGL shared memory buffer.
# 2. Cross-training-process synchronization
# 3. In each training process:
# 3.a Collect the node index sub-tensors and gradient sub-tensors
# 3.b Apply the gradient updates
# 4. Done
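# For example (illustrative only): with N = 2 processes and local
# gradients for node indices [0, 1, 2, 3], rank 0 keeps nodes 0 and 2
# (0 mod 2 == 2 mod 2 == 0) and writes the rows for nodes 1 and 3 into
# the shared memory buffers read by rank 1.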
idx_split = th.remainder(idx, self._world_size).long()
for i in range(self._world_size):
mask = idx_split == i
idx_i = idx[mask]
grad_i = grad[mask]
if i == self._rank:
shared_emb[emb_name][0].append(idx_i)
shared_emb[emb_name][1].append(grad_i)
else:
# currently NCCL does not support the Alltoallv operation,
# so we use CPU shared memory to share gradients
# across processes
idx_i = idx_i.to(th.device("cpu"))
grad_i = grad_i.to(th.device("cpu"))
idx_shmem_name = "idx_{}_{}_{}".format(
emb_name, self._rank, i
)
grad_shmem_name = "grad_{}_{}_{}".format(
emb_name, self._rank, i
)
# Create shared memory to hold temporary index and gradient tensor for
# cross-process send and recv.
if (
idx_shmem_name
not in self._shared_cache[emb_name]
or self._shared_cache[emb_name][
idx_shmem_name
].shape[0]
< idx_i.shape[0]
):
if (
idx_shmem_name
in self._shared_cache[emb_name]
):
self.shmem_buffer_holder.append(
self._shared_cache[emb_name][
idx_shmem_name
]
)
self.shmem_buffer_holder.append(
self._shared_cache[emb_name][
grad_shmem_name
]
)
# The total number of buffers is the number of NodeEmbeddings *
# world_size * (world_size - 1). The minimum buffer size is 128.
#
# We extend the buffer by idx_i.shape[0] * 2 to avoid
# frequent shared memory allocation.
# The overall buffer cost will be smaller than three times
# the maximum memory requirement for sharing gradients.
buffer_size = (
128
if idx_i.shape[0] < 128
else idx_i.shape[0] * 2
)
idx_shmem = create_shared_mem_array(
"{}_{}".format(idx_shmem_name, buffer_size),
(buffer_size,),
idx_dtype,
)
grad_shmem = create_shared_mem_array(
"{}_{}".format(
grad_shmem_name, buffer_size
),
(buffer_size, grad_dim),
grad_dtype,
)
self._shared_cache[emb_name][
idx_shmem_name
] = idx_shmem
self._shared_cache[emb_name][
grad_shmem_name
] = grad_shmem
# Fill the shared memory with the temporary index and gradient tensors
self._shared_cache[emb_name][idx_shmem_name][
: idx_i.shape[0]
] = idx_i
self._shared_cache[emb_name][grad_shmem_name][
: idx_i.shape[0]
] = grad_i
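# record how many valid rows were written so that rank i knows how
# much of the buffer to read after the barrier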
self._opt_meta[emb_name][self._rank][
i
] = idx_i.shape[0]
else:
shared_emb[emb_name][0].append(idx)
shared_emb[emb_name][1].append(grad)
# make sure the idx shape is passed to each process through opt_meta
if self._world_size > 1:
th.distributed.barrier()
for emb in self._params: # pylint: disable=too-many-nested-blocks
emb_name = emb.name
if self._world_size > 1:
# The first element in shared_emb[emb_name][0] is the local idx
device = shared_emb[emb_name][0][0].device
# gather gradients from all other processes
for i in range(self._world_size):
if i != self._rank:
idx_shmem_name = "idx_{}_{}_{}".format(
emb_name, i, self._rank
)
grad_shmem_name = "grad_{}_{}_{}".format(
emb_name, i, self._rank
)
size = self._opt_meta[emb_name][i][self._rank]
# Retrieve the shared memory holding the temporary index and gradient
# tensors sent to the current training process
if (
idx_shmem_name
not in self._shared_cache[emb_name]
or self._shared_cache[emb_name][
idx_shmem_name
].shape[0]
< size
):
buffer_size = 128 if size < 128 else size * 2
idx_shmem = get_shared_mem_array(
"{}_{}".format(idx_shmem_name, buffer_size),
(buffer_size,),
idx_dtype,
)
grad_shmem = get_shared_mem_array(
"{}_{}".format(
grad_shmem_name, buffer_size
),
(buffer_size, grad_dim),
grad_dtype,
)
self._shared_cache[emb_name][
idx_shmem_name
] = idx_shmem
self._shared_cache[emb_name][
grad_shmem_name
] = grad_shmem
idx_i = self._shared_cache[emb_name][
idx_shmem_name
][:size]
grad_i = self._shared_cache[emb_name][
grad_shmem_name
][:size]
shared_emb[emb_name][0].append(
idx_i.to(device, non_blocking=True)
)
shared_emb[emb_name][1].append(
grad_i.to(device, non_blocking=True)
)
if self._clean_grad:
# clean gradient track
for emb in self._params:
emb.reset_trace()
self._clean_grad = False
for emb in self._params:
emb_name = emb.name
idx = th.cat(shared_emb[emb_name][0], dim=0)
grad = th.cat(shared_emb[emb_name][1], dim=0)
self.update(idx, grad, emb)
# synchronized gradient update
if self._world_size > 1:
th.distributed.barrier()
@abstractmethod
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.nn.NodeEmbedding
Sparse node embedding to update.
"""
def zero_grad(self):
"""clean grad cache"""
self._clean_grad = True
def state_dict(self, **kwargs): # pylint: disable=unused-argument
"""Return a copy of the whole optimizer states stored in CPU memory.
If this is a multi-processing instance, the states will be returned in
shared memory. If the underlying embedding is currently stored on
multiple GPUs, all processes must call this method in the same order.
NOTE: This method must be called by all processes sharing the
underlying embedding, or it may result in a deadlock.
Returns
-------
dictionary of optimizer states
The optimizer states stored in CPU memory.
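Examples
--------
A minimal checkpointing sketch (assuming ``optimizer`` is a sparse
optimizer instance and every training process runs the same code):
>>> th.save(optimizer.state_dict(), "optimizer.pt")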
"""
return {
"state": {
emb.name: emb._all_get_optm_state() for emb in self._params
},
"param_groups": self.param_groups,
}
def load_state_dict(
self, state_dict, **kwargs
): # pylint: disable=unused-argument
"""Load the optimizer states. This method must be called by all
processes sharing the underlying embedding with identical
:attr:`state_dict`.
NOTE: This method must be called by all processes sharing the
underlying embedding, or it may result in a deadlock.
Parameters
----------
state_dict : dictionary of optimizer states
The global states to pull values from.
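Examples
--------
A minimal restore sketch (assuming ``"optimizer.pt"`` was produced by
``state_dict()`` of an identically configured optimizer):
>>> optimizer.load_state_dict(th.load("optimizer.pt"))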
"""
for emb in self._params:
emb._all_set_optm_state(state_dict["state"][emb.name])
self._set_param_groups(state_dict["param_groups"])
@property
@abstractmethod
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Different from that, the returned 'param_groups' doesn't contain
parameters because getting the whole embedding is very expensive.
It contains other attributes, e.g., lr, eps, for debugging.
"""
@abstractmethod
def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""
class SparseAdagrad(SparseGradOptimizer):
r"""Node embedding optimizer using the Adagrad algorithm.
This optimizer implements a sparse version of the Adagrad algorithm for
optimizing :class:`dgl.nn.NodeEmbedding`. Being sparse means it only updates
the embeddings whose gradients have updates, which are usually a very
small portion of the total embeddings.
Adagrad maintains a state :math:`G_{t,i,j}` for every parameter in the embeddings,
where :math:`G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2` and :math:`g_{t,i,j}` is the gradient of
the dimension :math:`j` of embedding :math:`i` at step :math:`t`. The embedding is
then updated by :math:`w_{t,i,j} = w_{t-1,i,j} - lr * g_{t,i,j} / \sqrt{G_{t,i,j} + eps}`.
NOTE: Support for the sparse Adagrad optimizer is experimental.
Parameters
----------
params : list[dgl.nn.NodeEmbedding]
The list of dgl.nn.NodeEmbedding.
lr : float
The learning rate.
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-10
Examples
--------
>>> def initializer(emb):
...     th.nn.init.xavier_uniform_(emb)
...     return emb
>>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
>>> optimizer = dgl.optim.SparseAdagrad([emb], lr=0.001)
>>> for blocks in dataloader:
... ...
... feats = emb(nids, gpu_0)
... loss = F.sum(feats + 1, 0)
... loss.backward()
... optimizer.step()
"""
def __init__(self, params, lr, eps=1e-10):
super(SparseAdagrad, self).__init__(params, lr)
self._eps = eps
# setup tensors for optimizer states
self.setup(self._params)
def setup(self, params):
# We need to register a state sum for each embedding in the kvstore.
for emb in params:
assert isinstance(
emb, NodeEmbedding
), "SparseAdagrad only supports dgl.nn.NodeEmbedding"
emb_name = emb.name
if th.device(emb.weight.device) == th.device("cpu"):
# if our embedding is on the CPU, our state also has to be
if self._rank < 0:
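# a negative rank means single-process training, so no shared memory
# is needed for the optimizer state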
state = th.empty(
emb.weight.shape,
dtype=th.float32,
device=th.device("cpu"),
).zero_()
elif self._rank == 0:
state = create_shared_mem_array(
emb_name + "_state", emb.weight.shape, th.float32
).zero_()
if self._world_size > 1:
emb.store.set(emb_name + "_opt", emb_name)
elif self._rank > 0:
# receive
emb.store.wait([emb_name + "_opt"])
state = get_shared_mem_array(
emb_name + "_state", emb.weight.shape, th.float32
)
else:
# distributed state on gpu
state = th.empty(
emb.weight.shape,
dtype=th.float32,
device=emb.weight.device,
).zero_()
emb.set_optm_state((state,))
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.nn.NodeEmbedding
Sparse embedding to update.
"""
eps = self._eps
clr = self._lr
# the update is non-linear so indices must be unique
grad_indices, inverse, cnt = th.unique(
idx, return_inverse=True, return_counts=True
)
grad_values = th.zeros(
(grad_indices.shape[0], grad.shape[1]), device=grad.device
)
grad_values.index_add_(0, inverse, grad)
grad_values = grad_values / cnt.unsqueeze(1)
grad_sum = grad_values * grad_values
(state,) = emb.optm_state
state_dev = state.device
state_idx = grad_indices.to(state_dev)
grad_state = state[state_idx].to(grad.device)
grad_state += grad_sum
state[state_idx] = grad_state.to(state_dev)
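# Adagrad update with the de-duplicated (mean) gradients:
# w[idx] -= lr * g / sqrt(G[idx] + eps)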
std_values = grad_state.add_(eps).sqrt_()
tmp = clr * grad_values / std_values
emb.weight[state_idx] -= tmp.to(state_dev)
@property
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Different from that, the returned 'param_groups' doesn't contain
parameters because getting the whole embedding is very expensive.
It contains other attributes, e.g., lr, eps, for debugging.
"""
return [{"lr": self._lr, "eps": self._eps}]
def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""
self._lr = groups[0]["lr"]
self._eps = groups[0]["eps"]
class SparseAdam(SparseGradOptimizer):
r"""Node embedding optimizer using the Adam algorithm.
This optimizer implements a sparse version of the Adam algorithm for
optimizing :class:`dgl.nn.NodeEmbedding`. Being sparse means it only
updates the embeddings whose gradients have updates, which are usually
a very small portion of the total embeddings.
Adam maintains a first moment :math:`Gm_{t,i,j}` and a second moment
:math:`Gp_{t,i,j}` for every parameter in the embeddings, where
:math:`Gm_{t,i,j}=beta1 * Gm_{t-1,i,j} + (1-beta1) * g_{t,i,j}`,
:math:`Gp_{t,i,j}=beta2 * Gp_{t-1,i,j} + (1-beta2) * g_{t,i,j}^2`, and
:math:`g_{t,i,j}` is the gradient of the dimension :math:`j` of embedding :math:`i`
at step :math:`t`. The embedding is then updated by
:math:`w_{t,i,j} = w_{t-1,i,j} - lr * (Gm_{t,i,j} / (1 - beta1^t)) / (\sqrt{Gp_{t,i,j} / (1 - beta2^t)} + eps)`.
NOTE: Support for the sparse Adam optimizer is experimental.
Parameters
----------
params : list[dgl.nn.NodeEmbedding]
The list of dgl.nn.NodeEmbeddings.
lr : float
The learning rate.
betas : tuple[float, float], Optional
Coefficients used for computing running averages of gradient and its square.
Default: (0.9, 0.999)
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-8
use_uva : bool, Optional
Whether to use pinned memory for storing 'mem' and 'power' parameters,
when the embedding is stored on the CPU. This will improve training
speed, but will require locking a large number of virtual memory pages.
For embeddings which are stored in GPU memory, this setting will have
no effect.
Default: True if the gradients are generated on the GPU, and False
if the gradients are on the CPU.
dtype : torch.dtype, Optional
The type to store optimizer state with. Default: th.float32.
Examples
--------
>>> def initializer(emb):
...     th.nn.init.xavier_uniform_(emb)
...     return emb
>>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
>>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)
>>> for blocks in dataloader:
... ...
... feats = emb(nids, gpu_0)
... loss = F.sum(feats + 1, 0)
... loss.backward()
... optimizer.step()
"""
def __init__(
self,
params,
lr,
betas=(0.9, 0.999),
eps=1e-08,
use_uva=None,
dtype=th.float32,
):
super(SparseAdam, self).__init__(params, lr)
self._lr = lr
self._beta1 = betas[0]
self._beta2 = betas[1]
self._eps = eps
self._use_uva = use_uva
self._nd_handle = {}
self._is_using_uva = {}
assert dtype in [th.float16, th.float32], (
"Unsupported dtype {}. Valid choices are th.float32 "
"and th.float32".format(dtype)
)
self._dtype = dtype
# setup tensors for optimizer states
self.setup(self._params)
def _setup_uva(self, name, mem, power):
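# Pin the optimizer state tensors in host memory so the GPU can access
# them directly (UVA), and keep the returned handles alive so the
# memory stays pinned.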
self._is_using_uva[name] = True
mem_nd = pin_memory_inplace(mem)
power_nd = pin_memory_inplace(power)
self._nd_handle[name] = [mem_nd, power_nd]
def setup(self, params):
# We need to register a state sum for each embedding in the kvstore.
for emb in params:
assert isinstance(
emb, NodeEmbedding
), "SparseAdam only supports dgl.nn.NodeEmbedding"
emb_name = emb.name
self._is_using_uva[emb_name] = self._use_uva
if th.device(emb.weight.device) == th.device("cpu"):
# if our embedding is on the CPU, our state also has to be
if self._rank < 0:
state_step = th.empty(
(emb.weight.shape[0],),
dtype=th.int32,
device=th.device("cpu"),
).zero_()
state_mem = th.empty(
emb.weight.shape,
dtype=self._dtype,
device=th.device("cpu"),
).zero_()
state_power = th.empty(
emb.weight.shape,
dtype=self._dtype,
device=th.device("cpu"),
).zero_()
elif self._rank == 0:
state_step = create_shared_mem_array(
emb_name + "_step", (emb.weight.shape[0],), th.int32
).zero_()
state_mem = create_shared_mem_array(
emb_name + "_mem", emb.weight.shape, self._dtype
).zero_()
state_power = create_shared_mem_array(
emb_name + "_power", emb.weight.shape, self._dtype
).zero_()
if self._world_size > 1:
emb.store.set(emb_name + "_opt", emb_name)
elif self._rank > 0:
# receive
emb.store.wait([emb_name + "_opt"])
state_step = get_shared_mem_array(
emb_name + "_step", (emb.weight.shape[0],), th.int32
)
state_mem = get_shared_mem_array(
emb_name + "_mem", emb.weight.shape, self._dtype
)
state_power = get_shared_mem_array(
emb_name + "_power", emb.weight.shape, self._dtype
)
if self._is_using_uva[emb_name]:
# only set up UVA now if use_uva was explicitly set to True;
# otherwise wait until the first step to decide
self._setup_uva(emb_name, state_mem, state_power)
else:
# make sure we don't use UVA when data is on the GPU
self._is_using_uva[emb_name] = False
# distributed state on gpu
state_step = th.empty(
[emb.weight.shape[0]],
dtype=th.int32,
device=emb.weight.device,
).zero_()
state_mem = th.empty(
emb.weight.shape,
dtype=self._dtype,
device=emb.weight.device,
).zero_()
state_power = th.empty(
emb.weight.shape,
dtype=self._dtype,
device=emb.weight.device,
).zero_()
state = (state_step, state_mem, state_power)
emb.set_optm_state(state)
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.nn.NodeEmbedding
Sparse embedding to update.
"""
with th.no_grad():
state_step, state_mem, state_power = emb.optm_state
exec_dtype = grad.dtype
exec_dev = grad.device
state_dev = state_step.device
# whether or not we need to transfer data from the GPU to the CPU
# while updating the weights
is_d2h = state_dev.type == "cpu" and exec_dev.type == "cuda"
# only perform async copies cpu -> gpu, or gpu -> gpu, but block
# when copying to the cpu, so as to ensure the copy is finished
# before operating on the data on the cpu
state_block = is_d2h
if self._is_using_uva[emb.name] is None and is_d2h:
# we should use UVA going forward
self._setup_uva(emb.name, state_mem, state_power)
elif self._is_using_uva[emb.name] is None:
# we shouldn't use UVA going forward
self._is_using_uva[emb.name] = False
use_uva = self._is_using_uva[emb.name]
beta1 = self._beta1
beta2 = self._beta2
eps = self._eps
clr = self._lr
# There can be duplicated indices due to sampling.
# Thus we deduplicate them here and average their gradients.
grad_indices, inverse, cnt = th.unique(
idx, return_inverse=True, return_counts=True
)
state_idx = grad_indices.to(state_dev)
state_step[state_idx] += 1
state_step = state_step[state_idx].to(exec_dev)
if use_uva:
orig_mem = gather_pinned_tensor_rows(state_mem, grad_indices)
orig_power = gather_pinned_tensor_rows(
state_power, grad_indices
)
else:
orig_mem = state_mem[state_idx].to(exec_dev)
orig_power = state_power[state_idx].to(exec_dev)
# convert to exec dtype
orig_mem = orig_mem.to(dtype=exec_dtype)
orig_power = orig_power.to(dtype=exec_dtype)
grad_values = th.zeros(
(grad_indices.shape[0], grad.shape[1]), device=exec_dev
)
grad_values.index_add_(0, inverse, grad)
grad_values = grad_values / cnt.unsqueeze(1)
grad_mem = grad_values
grad_power = grad_values * grad_values
update_mem = beta1 * orig_mem + (1.0 - beta1) * grad_mem
update_power = beta2 * orig_power + (1.0 - beta2) * grad_power
if use_uva:
scatter_pinned_tensor_rows(
state_mem, grad_indices, update_mem.to(dtype=self._dtype)
)
scatter_pinned_tensor_rows(
state_power,
grad_indices,
update_power.to(dtype=self._dtype),
)
else:
update_mem_dst = update_mem.to(dtype=self._dtype).to(
state_dev, non_blocking=True
)
update_power_dst = update_power.to(dtype=self._dtype).to(
state_dev, non_blocking=True
)
if state_block:
# use events to try and overlap CPU and GPU as much as possible
update_event = th.cuda.Event()
update_event.record()
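# bias-corrected first and second moment estimates, as in Adam:
# m_hat = m_t / (1 - beta1^t), v_hat = v_t / (1 - beta2^t)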
update_mem_corr = update_mem / (
1.0 - th.pow(th.tensor(beta1, device=exec_dev), state_step)
).unsqueeze(1)
update_power_corr = update_power / (
1.0 - th.pow(th.tensor(beta2, device=exec_dev), state_step)
).unsqueeze(1)
std_values = (
clr * update_mem_corr / (th.sqrt(update_power_corr) + eps)
)
std_values_dst = std_values.to(state_dev, non_blocking=True)
if state_block:
std_event = th.cuda.Event()
std_event.record()
if not use_uva:
if state_block:
# wait for our transfers from exec_dev to state_dev to finish
# before we can use them
update_event.wait()
state_mem[state_idx] = update_mem_dst
state_power[state_idx] = update_power_dst
if state_block:
# wait for the transfer of std_values to finish before we
# can use it
std_event.wait()
emb.weight[state_idx] -= std_values_dst
@property
def param_groups(self):
"""Emulate 'param_groups' of torch.optim.Optimizer.
Different from that, the returned 'param_groups' doesn't contain
parameters because getting the whole embedding is very expensive.
It contains other attributes, e.g., lr, betas, eps, for debugging.
"""
return [
{
"lr": self._lr,
"betas": (self._beta1, self._beta2),
"eps": self._eps,
}
]
def _set_param_groups(self, groups):
"""A helper method to load param_groups from saved state_dict."""
self._lr = groups[0]["lr"]
self._beta1, self._beta2 = groups[0]["betas"]
self._eps = groups[0]["eps"]