gradient accumulation #2653

Open · wants to merge 1 commit into base: main
97 changes: 92 additions & 5 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -169,6 +169,25 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:

    return ssd_tbe_params

enable_gradient_accumulation = False


def set_gradient_accumulation(enable: bool) -> None:
    # Module-level switch; must be set before BatchedFusedEmbedding instances
    # are constructed so that they pick the gradient-accumulation optimizer.
    global enable_gradient_accumulation
    enable_gradient_accumulation = enable


class GradientAccumulationQueue:
    """Buffers forward inputs and their detached outputs until the optimizer step."""

    def __init__(self) -> None:
        self.inputs: List[KeyedJaggedTensor] = []
        self.grads: List[torch.Tensor] = []

    def push(self, input: KeyedJaggedTensor, output: torch.Tensor) -> None:
        # `output` is the forward result with requires_grad=True; autograd will
        # populate output.grad during the backward pass of each micro-batch.
        self.inputs.append(input)
        self.grads.append(output)

    def pop(self) -> Tuple[List[KeyedJaggedTensor], List[torch.Tensor]]:
        # Harvest the gradients written into the buffered outputs and reset the queue.
        i, g = self.inputs, [x.grad for x in self.grads]
        self.inputs, self.grads = [], []
        return i, g

class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
def __init__(
@@ -625,6 +644,48 @@ def set_optimizer_step(self, step: int) -> None:
        self._emb_module.set_optimizer_step(step)


class GradAccEmbeddingFusedOptimizer(EmbeddingFusedOptimizer):
    def __init__(
        self, queue: GradientAccumulationQueue, *args: Any, **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        self._queue = queue
        # Cache of permutation indices keyed by (number of keys, number of
        # accumulated micro-batches), so each regrouping order is built once.
        self._permute_indices: Dict[Tuple[int, int], Tuple[List[int], torch.Tensor]] = {}

    def step(self, closure: Any = None) -> None:
        with torch.cuda.nvtx.range("grad acc optimizer step"):
            inputs, grads = self._queue.pop()
            if len(inputs) > 0:
                # All accumulated micro-batches must share the same feature keys.
                for i in range(1, len(inputs)):
                    assert inputs[i].keys() == inputs[0].keys()
                key_cnt, input_cnt = len(inputs[0].keys()), len(inputs)
                # Concatenate the buffered lookups along the key dimension and
                # their row-aligned gradients along dim 0.
                ids = KeyedJaggedTensor.concat(inputs)
                x = ids.values()
                grad = torch.cat(grads, dim=0)
                del inputs, grads

                # Replace the values with a running index so that permuting the
                # KJT yields a gather order for both the ids and the gradients.
                ids._values = torch.arange(0, x.numel(), dtype=torch.int32, device=x.device)
                indices_key = (key_cnt, input_cnt)
                if indices_key in self._permute_indices:
                    permute_indices, permute_indices_tensor = self._permute_indices[indices_key]
                else:
                    # Regroup from [batch0 keys..., batch1 keys..., ...] to
                    # [key0 of all batches, key1 of all batches, ...].
                    permute_indices = []
                    for i in range(key_cnt):
                        for j in range(input_cnt):
                            permute_indices.append(j * key_cnt + i)
                    permute_indices_tensor = torch.tensor(
                        permute_indices, dtype=torch.int, device=x.device
                    )
                    self._permute_indices[indices_key] = (permute_indices, permute_indices_tensor)

                ids = ids.permute(permute_indices, indices_tensor=permute_indices_tensor)
                x = torch.index_select(x, dim=0, index=ids.values())
                grad = torch.index_select(grad, dim=0, index=ids.values())

                # Replay a single fused forward over the accumulated ids and
                # backpropagate the accumulated gradients so the TBE applies
                # one combined update before the regular fused step below.
                output = self._emb_module(
                    indices=x.long(),
                    offsets=ids.offsets().long(),
                )
                torch.autograd.backward(output, grad_tensors=grad)

        super().step(closure)

def _gen_named_parameters_by_table_ssd(
    emb_module: SSDTableBatchedEmbeddingBags,
    table_name_to_count: Dict[str, int],
@@ -1034,11 +1095,23 @@ def __init__(
                **fused_params,
            )
        )
        self._optim: EmbeddingFusedOptimizer = EmbeddingFusedOptimizer(
            config,
            self._emb_module,
            pg,
        )
        global enable_gradient_accumulation
        if enable_gradient_accumulation:
            # Defer embedding updates: buffer each forward and replay the
            # accumulated lookups inside the optimizer step.
            self._grad_acc_queue: Optional[GradientAccumulationQueue] = GradientAccumulationQueue()
            self._optim: EmbeddingFusedOptimizer = GradAccEmbeddingFusedOptimizer(
                self._grad_acc_queue,
                config,
                self._emb_module,
                pg,
            )
        else:
            self._grad_acc_queue = None
            self._optim: EmbeddingFusedOptimizer = EmbeddingFusedOptimizer(
                config,
                self._emb_module,
                pg,
            )

        self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
            _gen_named_parameters_by_table_fused(
                emb_module=self._emb_module,
Expand Down Expand Up @@ -1088,6 +1161,20 @@ def flush(self) -> None:
    def purge(self) -> None:
        self._emb_module.reset_cache_states()

    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
        if self._grad_acc_queue is None:
            return super().forward(features)

        # Run the lookup without building an autograd graph; the backward pass
        # is replayed later in GradAccEmbeddingFusedOptimizer.step().
        with torch.no_grad():
            has_iter = hasattr(self._emb_module, "iter")
            if has_iter:
                # Preserve the TBE's internal iteration counter so the replay in
                # the deferred optimizer step does not advance it a second time.
                step = self._emb_module.iter.item()
            ret = super().forward(features)
            if has_iter:
                self._emb_module.iter[0] = step
            # Mark the detached output as requiring grad so autograd writes the
            # upstream gradient into ret.grad, which the queue later harvests.
            ret.requires_grad = True
            self._grad_acc_queue.push(features, ret)
            return ret


class BatchedDenseEmbedding(BaseBatchedEmbedding[torch.Tensor]):
    def __init__(
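
For context, here is a rough sketch of how the new switch might be driven from a training loop. Only set_gradient_accumulation and the replay-on-step behavior come from this diff; the model, dataloader, optimizer, and accumulation_steps names below are hypothetical placeholders, not part of the change.

from torchrec.distributed.batched_embedding_kernel import set_gradient_accumulation

# Enable the switch before the sharded model (and thus BatchedFusedEmbedding)
# is constructed, since the flag is read in BatchedFusedEmbedding.__init__.
set_gradient_accumulation(True)

accumulation_steps = 4  # hypothetical accumulation window
for step, batch in enumerate(dataloader):  # dataloader, model, optimizer are placeholders
    # Each embedding forward buffers (features, output) in the queue; backward
    # only fills output.grad because the lookup itself ran under no_grad.
    loss = model(batch) / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        # The fused optimizer replays the buffered lookups with the accumulated
        # gradients, applies a single embedding update, then runs its normal step.
        optimizer.step()
        optimizer.zero_grad()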