diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 562dcfc9cd744..4950bed292dc1 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -34,8 +34,8 @@ def setup_device(self, device: torch.device) -> None: """ if device.type != "cuda": raise ValueError(f"Device should be CUDA, got {device} instead.") - _check_cuda_matmul_precision(device) torch.cuda.set_device(device) + _check_cuda_matmul_precision(device) # may initialize CUDA, set_device first @override def teardown(self) -> None: diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 63a0d8adba8ea..008d47fe256ca 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -43,8 +43,8 @@ def setup_device(self, device: torch.device) -> None: """ if device.type != "cuda": raise MisconfigurationException(f"Device should be GPU, got {device} instead") - _check_cuda_matmul_precision(device) torch.cuda.set_device(device) + _check_cuda_matmul_precision(device) # may initialize CUDA, set_device first @override def setup(self, trainer: "pl.Trainer") -> None: