

def is_torch_npu_available() -> bool:
-    """Check the availability of NPU"""
+    """Check if Ascend NPU is available for PyTorch operations.
+
+    Attempts to detect NPU availability by checking for the torch.npu module
+    and its is_available() function.
+
+    Returns:
+        bool: True if NPU is available, False otherwise.
+    """
    try:
        if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
            return torch.npu.is_available()
@@ -30,18 +37,26 @@ def is_torch_npu_available() -> bool:


def get_visible_devices_keyword() -> str:
-    """Function that gets visible devices keyword name.
+    """Get the environment variable name for visible device selection.
+
+    Returns the appropriate environment variable name based on the available
+    accelerator type (CUDA or Ascend NPU).
+
     Returns:
-        'CUDA_VISIBLE_DEVICES' or `ASCEND_RT_VISIBLE_DEVICES`
+        str: 'CUDA_VISIBLE_DEVICES' if CUDA is available,
+        'ASCEND_RT_VISIBLE_DEVICES' otherwise.
     """
    return "CUDA_VISIBLE_DEVICES" if is_cuda_available else "ASCEND_RT_VISIBLE_DEVICES"


def get_device_name() -> str:
-    """Function that gets the torch.device based on the current machine.
-    This currently only supports CPU, CUDA, NPU.
+    """Get the device type string based on available accelerators.
+
+    Detects the available accelerator and returns the corresponding PyTorch
+    device type string. Currently supports CUDA, Ascend NPU, and CPU.
+
     Returns:
-        device
+        str: Device type string ('cuda', 'npu', or 'cpu').
     """
    if is_cuda_available:
        device = "cuda"
@@ -52,10 +67,15 @@ def get_device_name() -> str:
    return device


-def get_torch_device() -> any:
-    """Return the corresponding torch attribute based on the device type string.
+def get_torch_device():
+    """Get the PyTorch device module for the current accelerator.
+
+    Returns the torch device namespace (e.g., torch.cuda, torch.npu) based on
+    the detected accelerator type. Falls back to torch.cuda if the namespace
+    is not found.
+
     Returns:
-        module: The corresponding torch device namespace, or torch.cuda if not found.
+        module: The PyTorch device module (torch.cuda, torch.npu, etc.).
     """
    device_name = get_device_name()
    try:
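A minimal usage sketch for get_torch_device(), assuming the returned namespace exposes the usual per-backend calls (torch.cuda does, and torch_npu mirrors them on Ascend); the call sites below are illustrative, not taken from this PR:

    device_mod = get_torch_device()  # torch.cuda on GPU machines, torch.npu on Ascend
    device_mod.synchronize()         # wait for all queued kernels on the current device
    device_mod.empty_cache()         # release cached allocator blocks back to the driver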
@@ -66,17 +86,22 @@ def get_torch_device() -> any:


def get_device_id() -> int:
-    """Return current device id based on the device type.
+    """Get the index of the current accelerator device.
+
     Returns:
-        device index
+        int: The current device index (e.g., 0 for 'cuda:0').
     """
    return get_torch_device().current_device()


def get_nccl_backend() -> str:
-    """Return nccl backend type based on the device type.
+    """Get the distributed communication backend based on device type.
+
+    Returns the appropriate collective communication backend for the
+    detected accelerator (HCCL for Ascend NPU, NCCL for CUDA).
+
     Returns:
-        nccl backend type string.
+        str: Backend name ('hccl' for NPU, 'nccl' for CUDA/default).
     """
    if is_npu_available:
        return "hccl"
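Together, get_device_name(), get_device_id(), and get_nccl_backend() cover a typical torch.distributed setup. A minimal sketch, assuming the launcher has already exported RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT:

    import torch
    import torch.distributed as dist

    # Pick the right collective backend: "nccl" on CUDA, "hccl" on Ascend NPU.
    dist.init_process_group(backend=get_nccl_backend())

    # Build a fully qualified device string such as "cuda:0" or "npu:0".
    device = torch.device(f"{get_device_name()}:{get_device_id()}")
    tensor = torch.ones(1, device=device)
    dist.all_reduce(tensor)  # sums the tensor across ranks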
@@ -86,15 +111,32 @@ def get_nccl_backend() -> str:


def set_expandable_segments(enable: bool) -> None:
-    """Enable or disable expandable segments for cuda.
+    """Configure CUDA memory allocator expandable segments setting.
+
+    Expandable segments can help avoid out-of-memory (OOM) errors by allowing
+    the memory allocator to expand existing memory segments rather than
+    allocating new ones.
+
     Args:
-        enable (bool): Whether to enable expandable segments. Used to avoid OOM.
+        enable: If True, enable expandable segments. If False, disable them.
+
+    Note:
+        This function only has an effect when CUDA is available.
     """
    if is_cuda_available:
        torch.cuda.memory._set_allocator_settings(f"expandable_segments:{enable}")


-def auto_set_ascend_device_name(config):
+def auto_set_ascend_device_name(config) -> None:
+    """Automatically configure device name for Ascend NPU environments.
+
+    If running on an Ascend NPU system, this function ensures the trainer
+    device configuration is set to 'npu'. Logs a warning if the config
+    was set to a different device type.
+
+    Args:
+        config: Configuration object with trainer.device attribute.
+    """
    if config and config.trainer and config.trainer.device:
        if is_torch_npu_available():
            if config.trainer.device != "npu":
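For context, set_expandable_segments(True) passes the string "expandable_segments:True" to the allocator; the same behaviour can be requested before launch through the documented PYTORCH_CUDA_ALLOC_CONF environment variable. The enable/restore pattern below is an assumption for illustration, not something this PR prescribes:

    # Equivalent ahead-of-launch configuration:
    #   PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python train.py
    set_expandable_segments(True)       # let segments grow instead of allocating new ones
    try:
        run_memory_heavy_step()         # hypothetical workload, named only for illustration
    finally:
        set_expandable_segments(False)  # restore the default allocator behaviour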
@@ -106,7 +148,16 @@ def auto_set_ascend_device_name(config):
                config.trainer.device = "npu"


-def get_device_capability(device_id: int = 0) -> tuple[int, int]:
+def get_device_capability(device_id: int = 0) -> tuple[int | None, int | None]:
+    """Get the compute capability of a CUDA device.
+
+    Args:
+        device_id: The CUDA device index to query. Defaults to 0.
+
+    Returns:
+        tuple: A tuple of (major, minor) compute capability version,
+        or (None, None) if CUDA is not available.
+    """
    major, minor = None, None
    if is_cuda_available:
        major, minor = torch.cuda.get_device_capability(device_id)
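A small sketch of how the (major, minor) pair might be consumed, for example gating a bf16 path on Ampere-or-newer GPUs (compute capability 8.0+); the threshold and flag name are illustrative, not part of this change:

    major, _minor = get_device_capability()
    use_bf16 = major is not None and major >= 8   # None on machines without CUDA
    dtype = torch.bfloat16 if use_bf16 else torch.float16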