Skip to content

Commit dd5313a

Browse files
authored
Fix for PyTorch MPS support
Some tensor operations occur during "torch._dynamo.disable()", which may lead to crash if the default device has already been set to MPS.
1 parent b79d2fd commit dd5313a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

deepxde/backend/pytorch/tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@
2222
torch.set_default_tensor_type(torch.cuda.FloatTensor)
2323
elif torch.backends.mps.is_available():
2424
fallback_device = torch.get_default_device()
25-
torch.set_default_device("mps")
2625

2726
# As of March 2025, the macOS X-based GitHub Actions building environment sees
2827
# the MPS GPU, but cannot access it. So, a try-except workaround is applied.
2928
try:
3029
# A temporary trick to evade the Pytorch optimizer bug on MPS GPUs
3130
# See https://github.com/pytorch/pytorch/issues/149184
31+
# As for May 2025, it must go before the default device change
3232
torch._dynamo.disable()
3333

34+
torch.set_default_device("mps")
35+
3436
# If the Pytorch optimizer bug is fixed and the line above is removed,
3537
# the following code will perform a simple check of the MPS GPU
3638
test_nn = torch.nn.Sequential(

0 commit comments

Comments
 (0)