Skip to content

Commit 28cb8f0

Browse files
authored
Backend Pytorch: Add Apple MPS GPU support (#1973)
1 parent 67ec746 commit 28cb8f0

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

deepxde/backend/pytorch/tensor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,35 @@
2020
torch.set_default_device("cuda")
2121
else:
2222
torch.set_default_tensor_type(torch.cuda.FloatTensor)
23+
elif torch.backends.mps.is_available():
24+
fallback_device = torch.get_default_device()
25+
torch.set_default_device("mps")
26+
27+
# As of March 2025, the macOS X-based GitHub Actions building environment sees
28+
# the MPS GPU, but cannot access it. So, a try-except workaround is applied.
29+
try:
30+
# A temporary trick to evade the Pytorch optimizer bug on MPS GPUs
31+
# See https://github.com/pytorch/pytorch/issues/149184
32+
torch._dynamo.disable()
33+
34+
# If the Pytorch optimizer bug is fixed and the line above is removed,
35+
# the following code will perform a simple check of the MPS GPU
36+
test_nn = torch.nn.Sequential(
37+
torch.nn.Linear(1, 2),
38+
torch.nn.Tanh(),
39+
)
40+
test_input = torch.randn(3, 1)
41+
test_run = test_nn(test_input)
42+
del test_nn, test_input, test_run
43+
torch.mps.empty_cache()
44+
45+
except Exception as e:
46+
import warnings
47+
warnings.warn(
48+
f'An MPS GPU has been detected, but cannot be used. '
49+
f'Falling back to the CPU.\nThe exception message is:\n {e}'
50+
)
51+
torch.set_default_device(fallback_device)
2352

2453

2554
lib = torch

0 commit comments

Comments
 (0)