-
Notifications
You must be signed in to change notification settings - Fork 4k
Expand file tree
/
Copy pathutils.py
More file actions
30 lines (23 loc) · 871 Bytes
/
utils.py
File metadata and controls
30 lines (23 loc) · 871 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import importlib.util
import torch
def get_device_initial(preferred_device=None):
    """
    Determine the appropriate device to use (cuda, hpu, or cpu).

    Devices are probed in the order HPU -> CUDA -> CPU. When
    ``preferred_device`` is ``None`` the first available accelerator wins.
    When it names a specific device, only that device is considered. If the
    named device is unavailable, the result is 'cpu', never a different
    accelerator.

    Args:
        preferred_device (str | None): User-preferred device ('cuda', 'hpu',
            or 'cpu'). Matching is case-insensitive and ignores surrounding
            whitespace. Unrecognized values emit a ``UserWarning`` and
            resolve to 'cpu'.

    Returns:
        str: Device string ('cuda', 'hpu', or 'cpu').
    """
    import warnings

    if preferred_device is not None:
        # Normalize so "CUDA" / " hpu " behave like the canonical spelling
        # instead of silently falling through to CPU.
        preferred_device = str(preferred_device).strip().lower()
        if preferred_device not in ("cuda", "hpu", "cpu"):
            warnings.warn(
                f"Unrecognized preferred_device {preferred_device!r}; "
                "falling back to 'cpu'.",
                stacklevel=2,
            )

    # Check for HPU support. The Habana library is loaded before torch.hpu is
    # queried (presumably required for torch.hpu to be registered - confirm).
    if importlib.util.find_spec("habana_frameworks") is not None:
        from habana_frameworks.torch.utils.library_loader import load_habana_module

        load_habana_module()
        # Guard the attribute: torch itself does not define torch.hpu, so a
        # missing registration should mean "no HPU", not an AttributeError.
        hpu = getattr(torch, "hpu", None)
        if hpu is not None and hpu.is_available():
            if preferred_device in (None, "hpu"):
                return "hpu"

    # Check for CUDA (GPU support)
    if torch.cuda.is_available() and preferred_device in (None, "cuda"):
        return "cuda"

    # Default to CPU
    return "cpu"