Skip to content

Commit 74f3bc0

Browse files
author
sambit-giri
committed
ViteBetti with torch can use Apple's AMD GPUs
1 parent 7dd3ca5 commit 74f3bc0

2 files changed

Lines changed: 7 additions & 2 deletions

File tree

src/tools21cm/ViteBetti.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ def jit(func, *args, **kwargs):
3434
try:
3535
import torch
3636
torch_available = True
37-
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
37+
if torch.cuda.is_available():
38+
torch_device = torch.device('cuda')
39+
elif torch.backends.mps.is_available():
40+
torch_device = torch.device('mps')
41+
else:
42+
torch_device = torch.device('cpu')
3843
except ImportError:
3944
torch_available = False
4045
torch = None

src/tools21cm/topology.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def EulerCharacteristic(data, thres=0.5, neighbors=6, speed_up='cython', verbose
2121
speed_up: str
2222
Method used to speed up calculation (Default: cython).
2323
Options are cython, numba, torch and numpy.
24-
The caclulation with torch uses GPUs if available on the device.
24+
The calculation with torch uses GPUs if available on the device.
2525
verbose: bool
2626
If True, verbose is printed.
2727

0 commit comments

Comments
 (0)