Skip to content

Commit 530217f

Browse files
authored
Add Apple MPS GPU support
There is a Pytorch bug that occurs when an optimizer is created for a network which is on an MPS GPU (pytorch/pytorch#149184), thus a trick to avoid it is needed.
1 parent da2ec60 commit 530217f

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

deepxde/backend/pytorch/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
torch.set_default_device("cuda")
2121
else:
2222
torch.set_default_tensor_type(torch.cuda.FloatTensor)
23+
elif torch.backends.mps.is_available():
24+
torch.set_default_device("mps")
25+
torch._dynamo.disable() # A temporary trick to evade the Pytorch MPS bug (https://github.com/pytorch/pytorch/issues/149184)
2326

2427

2528
lib = torch

0 commit comments

Comments
 (0)