We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8b24140 commit 8e6ca60Copy full SHA for 8e6ca60
test/spmd/test_train_spmd_linear_model.py
@@ -74,7 +74,7 @@ def test_gradient_accumulation_matches(self):
74
# Verify that the model losses are not zero, and that the runs match.
75
assert all(loss != 0 for loss in baseline_grad_acc_losses)
76
assert all(
77
- torch.allclose(baseline_loss, checkpointing_loss)
+ torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8)
78
for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
79
loop_grad_acc_losses))
80
0 commit comments