Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions deepxde/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,35 @@
torch.set_default_device("cuda")
else:
torch.set_default_tensor_type(torch.cuda.FloatTensor)
elif torch.backends.mps.is_available():
fallback_device = torch.get_default_device()
torch.set_default_device("mps")

# As of March 2025, the macOS X-based GitHub Actions building environment sees
# the MPS GPU, but cannot access it. So, a try-except workaround is applied.
try:
# A temporary trick to evade the Pytorch optimizer bug on MPS GPUs
# See https://github.com/pytorch/pytorch/issues/149184
torch._dynamo.disable()

# If the Pytorch optimizer bug is fixed and the line above is removed,
# the following code will perform a simple check of the MPS GPU
test_nn = torch.nn.Sequential(
torch.nn.Linear(1, 2),
torch.nn.Tanh(),
)
test_input = torch.randn(3, 1)
test_run = test_nn(test_input)
del test_nn, test_input, test_run
torch.mps.empty_cache()

except Exception as e:
import warnings
warnings.warn(
f'An MPS GPU has been detected, but cannot be used. '
f'Falling back to the CPU.\nThe exception message is:\n {e}'
)
torch.set_default_device(fallback_device)


lib = torch
Expand Down