Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def setup_device(self, device: torch.device) -> None:
"""
if device.type != "cuda":
raise ValueError(f"Device should be CUDA, got {device} instead.")
_check_cuda_matmul_precision(device)
torch.cuda.set_device(device)
_check_cuda_matmul_precision(device) # may initialize CUDA, set_device first

@override
def teardown(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def setup_device(self, device: torch.device) -> None:
"""
if device.type != "cuda":
raise MisconfigurationException(f"Device should be GPU, got {device} instead")
_check_cuda_matmul_precision(device)
torch.cuda.set_device(device)
_check_cuda_matmul_precision(device) # may initialize CUDA, set_device first

@override
def setup(self, trainer: "pl.Trainer") -> None:
Expand Down
Loading