Commit 39dd795

Fix the retention for existing gradients in the Grad Acc API (#8658)
1 parent e583c2c commit 39dd795

File tree: 1 file changed (+5, -2 lines)


torch_xla/experimental/gradient_accumulation.py

@@ -39,10 +39,13 @@ def gradient_accumulation(

   Notes:

-    The model tracing will happen entirely within the loop. Hence, it is
+    * The model tracing will happen entirely within the loop. Hence, it is
       assumed that `train_step` is purposefully encapsulated inside of the
       loop. Hence, it is not recommended to have any operation involving the
       model parameters outside of `train_step`.
+    * Note that zeroing the gradients to zero instead of None, (e.g.
+      `.zero_grad(set_to_none=False)) will avoid the device transfer of the
+      initial gradients in every call.

   Args:
     train_step: Training function that takes iterable tensors and carried
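To illustrate the second note, here is a minimal sketch of how a caller could keep gradient buffers allocated between calls. It assumes the call pattern suggested by the docstring's Args section, roughly gradient_accumulation(train_step, iterable_tensors, model); the model, optimizer, loss, data shapes, and the four-step leading dimension are hypothetical placeholders, not part of this commit.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

device = xm.xla_device()
model = torch.nn.Linear(16, 4).to(device)                 # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder optimizer

# Per-step batches stacked along the leading dimension (assumed layout:
# 4 accumulation steps of 8 samples each).
batches = torch.randn(4, 8, 16, device=device)
labels = torch.randn(4, 8, 4, device=device)

def train_step(batch, label):
  # Traced entirely inside the accumulation loop, as the first note requires.
  return torch.nn.functional.mse_loss(model(batch), label)

def train_loop_fn():
  # set_to_none=False keeps each param.grad as a zero tensor instead of None,
  # so the loop can reuse the existing device buffer rather than transferring
  # freshly created zero gradients on every call (the second note above).
  optimizer.zero_grad(set_to_none=False)
  loss = gradient_accumulation(train_step, (batches, labels), model)
  optimizer.step()
  return loss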
@@ -380,7 +383,7 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
     for param in model_parameters:
       if not param.requires_grad:
         continue
-      if param.grad:
+      if param.grad is not None:
         grad = param.grad
       else:
         grad = torch.zeros(param.size()).to(param.device).requires_grad_(False)
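The condition change above matters because tensor truthiness is not a reliable presence check. A minimal plain-PyTorch illustration, independent of this API; the parameters and gradients below are made up for the example:

import torch

p = torch.nn.Parameter(torch.randn(3))
p.grad = torch.zeros(3)      # e.g. a buffer retained by zero_grad(set_to_none=False)

print(p.grad is not None)    # True: the existing buffer is detected and reused.
# `if p.grad:` would instead call bool() on a multi-element tensor and raise
# "Boolean value of Tensor with more than one element is ambiguous".

p1 = torch.nn.Parameter(torch.randn(1))
p1.grad = torch.zeros(1)
print(bool(p1.grad))         # False: a single-element zero grad is falsy, so the
                             # old check would fall through to the else branch and
                             # replace the retained buffer with fresh zeros.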
