Skip to content

Commit 0443278

Browse files
committed
lower precision setting
1 parent 6b5c12a commit 0443278

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/unit/model_parallelism/test_autotp_training.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def testRowParallel(self, tp_size: int, overlap_comm: bool):
216216

217217
torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()]
218218
assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
219-
assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-3)
219+
assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-2)
220220

221221
def testColumnParallel(self, tp_size: int, overlap_comm: bool):
222222
skip_on_device()
@@ -269,7 +269,7 @@ def testColumnParallel(self, tp_size: int, overlap_comm: bool):
269269
assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
270270
assert torch.allclose(cur_device_out.to(get_accelerator().current_device()).contiguous(),
271271
out.contiguous(),
272-
atol=1e-3)
272+
atol=1e-2)
273273

274274

275275
@pytest.mark.sequential

0 commit comments

Comments
 (0)