
Commit f63eaa4

Nicorgi authored and jessicazhongeee committed
Fix a bug in set float32 precision (#2271)
1 parent db6b45b commit f63eaa4

2 files changed: +5 -2 lines changed

tests/torchtune/training/test_precision.py

+1 -2
@@ -57,8 +57,7 @@ def test_error_bf16_unsupported(self, mock_verify):
         get_dtype(torch.bfloat16)

     @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
-    @mock.patch("torchtune.training.precision.is_npu_available", return_value=True)
-    def test_set_float32_precision(self, mock_npu_available) -> None:
+    def test_set_float32_precision(self) -> None:
         setattr( # noqa: B010
             torch.backends, "__allow_nonbracketed_mutation_flag", True
         )
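
Why the mock could be dropped: mock.patch("torchtune.training.precision.is_npu_available", return_value=True) swaps the module-level flag for a MagicMock, and a MagicMock is truthy when read as a bare attribute, so the decorator was presumably there to keep the old, faulty guard in _set_float32_precision from early-returning on CUDA-only machines. With the guard fixed (see the next file), the test can read the real flag. A standalone sketch of that patching behaviour, using a stand-in namespace called fakemod rather than repository code:

# Hypothetical illustration, not part of torchtune: patching an attribute with
# return_value=True leaves a truthy MagicMock bound to that attribute.
from types import SimpleNamespace
from unittest import mock

fakemod = SimpleNamespace(is_npu_available=False)

with mock.patch.object(fakemod, "is_npu_available", return_value=True):
    # The bare attribute is now a MagicMock, so `not fakemod.is_npu_available`
    # is always False inside the patched scope.
    print(bool(fakemod.is_npu_available))  # True
print(bool(fakemod.is_npu_available))      # False again once the patch is undone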

torchtune/training/precision.py

+4 -0
@@ -34,11 +34,15 @@ def _set_float32_precision(precision: str = "high") -> None:
         precision (str): The setting to determine which datatypes to use for matrix multiplication and convolution operations.
     """
     # Not relevant for non-CUDA or non-NPU devices
+<<<<<<< HEAD
 <<<<<<< HEAD
     if not (torch.cuda.is_available() or is_npu_available):
 =======
     if not torch.cuda.is_available() or not is_npu_available:
 >>>>>>> c5f20b96 (Add Ascend NPU as a backend for single device recipes (#2234))
+=======
+    if not (torch.cuda.is_available() or is_npu_available):
+>>>>>>> 890deab3 (Fix a bug in set float32 precision (#2271))
         return
     # set precision for matrix multiplications
     torch.set_float32_matmul_precision(precision)
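
Setting the displayed merge-conflict markers aside, the substantive change is the guard condition: the early return should fire only when neither CUDA nor an NPU is available. A standalone sketch (not the repository file; is_npu_available is modeled here as a plain boolean, which is an assumption) contrasting the fixed condition with the previous one:

# Hypothetical, self-contained comparison of the two guard conditions.
def should_skip_fixed(cuda_available: bool, is_npu_available: bool) -> bool:
    # Fixed form: skip only when *neither* backend is present.
    return not (cuda_available or is_npu_available)

def should_skip_buggy(cuda_available: bool, is_npu_available: bool) -> bool:
    # Previous form: skips when *either* backend is missing, so a CUDA-only
    # machine (no NPU) returned early and float32 precision was never set.
    return not cuda_available or not is_npu_available

assert should_skip_fixed(False, False) is True
assert should_skip_fixed(True, False) is False   # CUDA-only: precision gets set
assert should_skip_buggy(True, False) is True    # the bug: precision was skipped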
