File tree Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Original file line number Diff line number Diff line change 2020 torch .set_default_device ("cuda" )
2121 else :
2222 torch .set_default_tensor_type (torch .cuda .FloatTensor )
23+ elif torch .backends .mps .is_available ():
24+ fallback_device = torch .get_default_device ()
25+ torch .set_default_device ("mps" )
26+
27+ # As of March 2025, the macOS X-based GitHub Actions building environment sees
28+ # the MPS GPU, but cannot access it. So, a try-except workaround is applied.
29+ try :
30+ # A temporary trick to evade the Pytorch optimizer bug on MPS GPUs
31+ # See https://github.com/pytorch/pytorch/issues/149184
32+ torch ._dynamo .disable ()
33+
34+ # If the Pytorch optimizer bug is fixed and the line above is removed,
35+ # the following code will perform a simple check of the MPS GPU
36+ test_nn = torch .nn .Sequential (
37+ torch .nn .Linear (1 , 2 ),
38+ torch .nn .Tanh (),
39+ )
40+ test_input = torch .randn (3 , 1 )
41+ test_run = test_nn (test_input )
42+ del test_nn , test_input , test_run
43+ torch .mps .empty_cache ()
44+
45+ except Exception as e :
46+ import warnings
47+ warnings .warn (
48+ f'An MPS GPU has been detected, but cannot be used. '
49+ f'Falling back to the CPU.\n The exception message is:\n { e } '
50+ )
51+ torch .set_default_device (fallback_device )
2352
2453
2554lib = torch
You can’t perform that action at this time.
0 commit comments