Skip to content

Commit 8e6ca60

Browse files
authored
Fix grad accumulation tests on TPUv5/v6 (#8628)
1 parent 8b24140 commit 8e6ca60

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

test/spmd/test_train_spmd_linear_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_gradient_accumulation_matches(self):
7474
# Verify that the model losses are not zero, and that the runs match.
7575
assert all(loss != 0 for loss in baseline_grad_acc_losses)
7676
assert all(
77-
torch.allclose(baseline_loss, checkpointing_loss)
77+
torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8)
7878
for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
7979
loop_grad_acc_losses))
8080

0 commit comments

Comments
 (0)