
Commit c8a4746 (parent: ec1757c)

introduce device_context to simplify code.

2 files changed, 65 insertions(+), 43 deletions(-)


unsloth/device_type.py
Lines changed: 57 additions & 0 deletions

@@ -20,6 +20,10 @@
     "DEVICE_COUNT",
     "ALLOW_PREQUANTIZED_MODELS",
     "ALLOW_BITSANDBYTES",
+    "DeviceContext",
+    "device_context",
+    "clean_gpu_cache",
+    "get_current_device",
 ]

 import torch
@@ -125,3 +129,56 @@ def get_device_count():
     Params4bit
 ):
     ALLOW_PREQUANTIZED_MODELS = False
+
+
+class DeviceContext:
+    """Encapsulates device-specific operations for XPU/HIP/CUDA."""
+
+    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
+        if device_type not in ("cuda", "hip", "xpu"):
+            raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
+        self.device_type = device_type
+        # Cache the torch module for this device
+        self.torch_module = torch.xpu if device_type == "xpu" else torch.cuda
+
+    def get_stats(self) -> tuple[str, str, float]:
+        """Return (name, stats_snippet, max_memory_gb)."""
+        gpu_stats = self.torch_module.get_device_properties(0)
+        max_mem = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
+
+        # Device name
+        name = gpu_stats.name + ". " if gpu_stats.name else self._get_default_name()
+
+        # Toolkit snippet
+        snippet = self._get_toolkit_snippet(gpu_stats)
+
+        return name, snippet, max_mem
+
+    def _get_default_name(self) -> str:
+        """Get default device name when props.name is empty."""
+        names = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}
+        return names[self.device_type] + " Device. "
+
+    def _get_toolkit_snippet(self, props) -> str:
+        """Get toolkit version snippet."""
+        if self.device_type == "cuda":
+            return f"CUDA: {props.major}.{props.minor}. CUDA Toolkit: {torch.version.cuda}."
+        elif self.device_type == "hip":
+            return f"ROCm Toolkit: {torch.version.hip}."
+        else:  # xpu
+            return f"Intel Toolkit: {torch.version.xpu}."
+
+
+# Singleton instance
+device_context = DeviceContext()
+
+
+# Module-level functions for backward compatibility
+def clean_gpu_cache() -> None:
+    """Clear GPU cache for current device type."""
+    device_context.torch_module.empty_cache()
+
+
+def get_current_device() -> int:
+    """Get current device index."""
+    return device_context.torch_module.current_device()
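
For illustration, a minimal usage sketch of the new surface area (my example, not part of the commit; assumes a torch build with a working CUDA, HIP, or XPU backend):

    # Sketch: exercising the new helpers from unsloth/device_type.py.
    from unsloth.device_type import device_context, clean_gpu_cache, get_current_device

    # One call replaces the per-backend branching: display name,
    # toolkit version snippet, and total memory of device 0 in GB.
    name, snippet, max_memory = device_context.get_stats()
    print(f"GPU: {name}{snippet} Max memory: {max_memory} GB.")

    print(get_current_device())  # index of the active device, e.g. 0
    clean_gpu_cache()            # empty_cache() on the matching torch backend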

unsloth/models/llama.py
Lines changed: 8 additions & 43 deletions

@@ -52,6 +52,9 @@
     DEVICE_TYPE_TORCH,
     DEVICE_COUNT,
     ALLOW_PREQUANTIZED_MODELS,
+    device_context,
+    clean_gpu_cache,
+    get_current_device,
 )

 transformers_version = Version(transformers_version)
@@ -119,13 +122,6 @@
     xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None
 )

-if DEVICE_TYPE == "xpu":
-    clean_gpu_cache = torch.xpu.empty_cache
-    get_current_device = torch.xpu.current_device
-else:
-    clean_gpu_cache = torch.cuda.empty_cache
-    get_current_device = torch.cuda.current_device
-

 def original_apply_qkv(self, X):
     Q = self.q_proj(X)
@@ -2188,43 +2184,12 @@ def from_pretrained(
         model_patcher = FastLlamaModel
         SUPPORTS_BFLOAT16 = is_bfloat16_supported()

-        if DEVICE_TYPE == "cuda":
-            gpu_stats = torch.cuda.get_device_properties(0)
-            gpu_stats_name = (
-                gpu_stats.name + ". " if gpu_stats.name != "" else "NVIDIA GPU Device. "
-            )
-            gpu_version = torch.version.cuda
-            gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
-            try:
-                vllm_version = f" vLLM: {importlib_version('vllm')}."
-            except:
-                vllm_version = ""
-        elif DEVICE_TYPE == "hip":
-            gpu_stats = torch.cuda.get_device_properties(0)
-            gpu_stats_name = (
-                gpu_stats.name + ". " if gpu_stats.name != "" else "AMD GPU Device. "
-            )
-            gpu_version = torch.version.hip
-            gpu_stats_snippet = f"ROCm Toolkit: {gpu_version}."
-            try:
-                vllm_version = f" vLLM: {importlib_version('vllm')}."
-            except:
-                vllm_version = ""
-        elif DEVICE_TYPE == "xpu":
-            gpu_stats = torch.xpu.get_device_properties(0)
-            gpu_stats_name = (
-                gpu_stats.name + ". " if gpu_stats.name != "" else "Intel XPU Device. "
-            )
-            gpu_version = torch.version.xpu
-            gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."
-            try:
-                vllm_version = f" vLLM: {importlib_version('vllm')}."
-            except:
-                vllm_version = ""
-        else:
-            raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}")
+        gpu_stats_name, gpu_stats_snippet, max_memory = device_context.get_stats()

-        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
+        try:
+            vllm_version = f" vLLM: {importlib_version('vllm')}."
+        except:
+            vllm_version = ""

         statistics = (
             f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"
