|
20 | 20 | "DEVICE_COUNT", |
21 | 21 | "ALLOW_PREQUANTIZED_MODELS", |
22 | 22 | "ALLOW_BITSANDBYTES", |
| 23 | + "DeviceContext", |
| 24 | + "device_context", |
| 25 | + "clean_gpu_cache", |
| 26 | + "get_current_device", |
23 | 27 | ] |
24 | 28 |
|
25 | 29 | import torch |
@@ -125,3 +129,56 @@ def get_device_count(): |
125 | 129 | Params4bit |
126 | 130 | ): |
127 | 131 | ALLOW_PREQUANTIZED_MODELS = False |
| 132 | + |
| 133 | + |
class DeviceContext:
    """Encapsulates device-specific operations for XPU/HIP/CUDA.

    Caches the matching torch backend module once at construction
    (``torch.xpu`` for XPU; ``torch.cuda`` for both CUDA and HIP, since
    ROCm builds of PyTorch expose their API through ``torch.cuda``), so
    callers never branch on the device type themselves.
    """

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        """Create a context for *device_type*.

        Args:
            device_type: One of ``"cuda"``, ``"hip"`` or ``"xpu"``.
                Defaults to the module-level ``DEVICE_TYPE`` detected at
                import time.

        Raises:
            ValueError: If *device_type* is not a supported device type.
        """
        if device_type not in ("cuda", "hip", "xpu"):
            raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
        self.device_type = device_type
        # Cache the torch module for this device. HIP shares torch.cuda,
        # so only XPU needs a different backend module.
        self.torch_module = torch.xpu if device_type == "xpu" else torch.cuda

    def get_stats(self, device_index: int = 0) -> tuple[str, str, float]:
        """Return (name, stats_snippet, max_memory_gb) for *device_index*.

        Generalized from a hard-coded device 0: *device_index* selects
        which device to query; the default of 0 preserves the previous
        behavior for existing callers.
        """
        gpu_stats = self.torch_module.get_device_properties(device_index)
        # total_memory is in bytes; report GiB rounded to 3 decimals.
        max_mem = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

        # Device name — fall back to a generic label when the driver
        # reports an empty name string.
        name = gpu_stats.name + ". " if gpu_stats.name else self._get_default_name()

        # Toolkit snippet
        snippet = self._get_toolkit_snippet(gpu_stats)

        return name, snippet, max_mem

    def _get_default_name(self) -> str:
        """Get default device name when props.name is empty."""
        names = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}
        return names[self.device_type] + " Device. "

    def _get_toolkit_snippet(self, props) -> str:
        """Get toolkit version snippet for the active backend.

        *props* is the object returned by ``get_device_properties``; only
        CUDA uses its ``major``/``minor`` compute-capability fields.
        """
        if self.device_type == "cuda":
            return f"CUDA: {props.major}.{props.minor}. CUDA Toolkit: {torch.version.cuda}."
        elif self.device_type == "hip":
            return f"ROCm Toolkit: {torch.version.hip}."
        else:  # xpu
            return f"Intel Toolkit: {torch.version.xpu}."
| 170 | + |
| 171 | + |
# Singleton instance shared module-wide; constructed with the default
# device type auto-detected at import time (DEVICE_TYPE).
device_context = DeviceContext()
| 174 | + |
| 175 | + |
| 176 | +# Module-level functions for backward compatibility |
def clean_gpu_cache() -> None:
    """Release cached memory held by the active device backend.

    Thin module-level wrapper kept for backward compatibility; delegates
    to the backend module cached on the singleton ``device_context``.
    """
    backend = device_context.torch_module
    backend.empty_cache()
| 180 | + |
| 181 | + |
def get_current_device() -> int:
    """Return the index of the currently selected device.

    Module-level convenience wrapper around the singleton
    ``device_context``'s backend module.
    """
    backend = device_context.torch_module
    return backend.current_device()
0 commit comments