From ded3a1ea4c5b1007be02b9cd18e874c7023d35e3 Mon Sep 17 00:00:00 2001
From: imh966
Date: Fri, 20 Dec 2024 17:57:28 +0800
Subject: [PATCH] gradient accumulation

---
 .../distributed/batched_embedding_kernel.py | 97 ++++++++++++++++++-
 1 file changed, 92 insertions(+), 5 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 861e6e868..18ec009c9 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -169,6 +169,25 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
 
     return ssd_tbe_params
 
+enable_gradient_accumulation = False
+
+def set_gradient_accumulation(enable: bool) -> None:
+    global enable_gradient_accumulation
+    enable_gradient_accumulation = enable
+
+class GradientAccumulationQueue:
+    def __init__(self) -> None:
+        self.inputs: List[KeyedJaggedTensor] = []
+        self.grads: List[torch.Tensor] = []
+
+    def push(self, input: KeyedJaggedTensor, grad: torch.Tensor) -> None:
+        self.inputs.append(input)
+        self.grads.append(grad)
+
+    def pop(self) -> Tuple[List[KeyedJaggedTensor], List[torch.Tensor]]:
+        i, g = self.inputs, [x.grad for x in self.grads]
+        self.inputs, self.grads = [], []
+        return i, g
 
 class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(
@@ -625,6 +644,48 @@ def set_optimizer_step(self, step: int) -> None:
         self._emb_module.set_optimizer_step(step)
 
 
+class GradAccEmbeddingFusedOptimizer(EmbeddingFusedOptimizer):
+    def __init__(self, queue: GradientAccumulationQueue, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._queue = queue
+        self._permute_indices: Dict[Tuple[int, int], Tuple[List[int], torch.Tensor]] = {}
+
+    def step(self, closure: Any = None) -> None:
+        with torch.cuda.nvtx.range("grad acc optimizer step"):
+            inputs, grads = self._queue.pop()
+            if len(inputs) > 0:
+                for i in range(1, len(inputs)):
+                    assert inputs[i].keys() == inputs[0].keys()
+                key_cnt, input_cnt = len(inputs[0].keys()), len(inputs)
+                ids = KeyedJaggedTensor.concat(inputs)
+                x = ids.values()
+                grad = torch.cat(grads, dim=0)
+                del inputs, grads
+
+                ids._values = torch.arange(0, x.numel(), dtype=torch.int32, device=x.device)
+                indices_key = (key_cnt, input_cnt)
+                if indices_key in self._permute_indices:
+                    permute_indices, permute_indices_tensor = self._permute_indices[indices_key]
+                else:
+                    permute_indices = []
+                    for i in range(key_cnt):
+                        for j in range(input_cnt):
+                            permute_indices.append(j * key_cnt + i)
+                    permute_indices_tensor = torch.tensor(permute_indices, dtype=torch.int, device=x.device)
+                    self._permute_indices[indices_key] = (permute_indices, permute_indices_tensor)
+
+                ids = ids.permute(permute_indices, indices_tensor=permute_indices_tensor)
+                x = torch.index_select(x, dim=0, index=ids.values())
+                grad = torch.index_select(grad, dim=0, index=ids.values())
+
+                output = self._emb_module(
+                    indices=x.long(),
+                    offsets=ids.offsets().long(),
+                )
+                torch.autograd.backward(output, grad_tensors=grad)
+
+        super().step(closure)
+
 def _gen_named_parameters_by_table_ssd(
     emb_module: SSDTableBatchedEmbeddingBags,
     table_name_to_count: Dict[str, int],
@@ -1034,11 +1095,23 @@ def __init__(
                 **fused_params,
             )
         )
-        self._optim: EmbeddingFusedOptimizer = EmbeddingFusedOptimizer(
-            config,
-            self._emb_module,
-            pg,
-        )
+        global enable_gradient_accumulation
+        if enable_gradient_accumulation:
+            self._grad_acc_queue: Optional[GradientAccumulationQueue] = GradientAccumulationQueue()
+            self._optim: EmbeddingFusedOptimizer = GradAccEmbeddingFusedOptimizer(
+                self._grad_acc_queue,
+                config,
+                self._emb_module,
+                pg,
+            )
+        else:
+            self._grad_acc_queue = None
+            self._optim = EmbeddingFusedOptimizer(
+                config,
+                self._emb_module,
+                pg,
+            )
+
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
             _gen_named_parameters_by_table_fused(
                 emb_module=self._emb_module,
@@ -1088,6 +1161,20 @@ def flush(self) -> None:
     def purge(self) -> None:
         self._emb_module.reset_cache_states()
 
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        if self._grad_acc_queue is None:
+            return super().forward(features)
+
+        with torch.no_grad():
+            if hasattr(self._emb_module, 'iter'):
+                step = self._emb_module.iter.item()
+            ret = super().forward(features)
+            if hasattr(self._emb_module, 'iter'):
+                self._emb_module.iter[0] = step
+        ret.requires_grad = True
+        self._grad_acc_queue.push(features, ret)
+        return ret
+
 
 class BatchedDenseEmbedding(BaseBatchedEmbedding[torch.Tensor]):
     def __init__(
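
Usage sketch (not part of the patch): one plausible way to drive the accumulation path added above, assuming the toggle is flipped before the sharded embedding kernels (and therefore their fused optimizers) are constructed. The names build_model, build_optimizer, dataloader, and compute_loss are placeholders for illustration only; only set_gradient_accumulation comes from this patch.

from torchrec.distributed.batched_embedding_kernel import set_gradient_accumulation

# Must run before BatchedFusedEmbedding.__init__ executes, because the
# module-level flag is read there to choose GradAccEmbeddingFusedOptimizer.
set_gradient_accumulation(True)

model = build_model()            # placeholder: sharded model containing BatchedFusedEmbedding kernels
optimizer = build_optimizer(model)  # placeholder: combined optimizer whose step() reaches the fused embedding optimizer

for micro_batches in dataloader:    # placeholder: yields a list of (features, labels) micro-batches per step
    for features, labels in micro_batches:
        # Each forward enqueues (features, detached lookup output) on the
        # GradientAccumulationQueue; backward only fills .grad on that output.
        loss = compute_loss(model(features), labels)  # placeholder loss
        loss.backward()
    # GradAccEmbeddingFusedOptimizer.step() pops the queue, replays the
    # concatenated lookups, runs the embedding backward with the accumulated
    # grads, and then applies the fused update via super().step().
    optimizer.step()
    optimizer.zero_grad()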