Skip to content

Commit 1b15204

Browse files
authored
Add manual check of access to MPS GPU
If the Pytorch optimizer bug is fixed and the line `torch._dynamo.disalbe()` (which raises an exception now) is removed, the added code will perform a simple check whether the MPS GPU can be used.
1 parent b7d930e commit 1b15204

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

deepxde/backend/pytorch/tensor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,18 @@
3030
# A temporary trick to evade the Pytorch optimizer bug on MPS
3131
# See https://github.com/pytorch/pytorch/issues/149184
3232
torch._dynamo.disable()
33-
33+
34+
# If the Pytorch optimizer bug is fixed and the line above is removed,
35+
# the following code will perform a simple check of the MPS GPU
36+
test_nn = torch.nn.Sequential(
37+
torch.nn.Linear(1, 2),
38+
torch.nn.Tanh(),
39+
)
40+
test_input = torch.randn(3, 1)
41+
test_run = test_nn(test_input)
42+
del test_nn, test_input, test_run
43+
torch.mps.empty_cache()
44+
3445
except Exception as e:
3546
import warnings
3647
warnings.warn(

0 commit comments

Comments
 (0)