Skip to content

Commit 33b02d4

Browse files
authored
Move inputs to right devices. (#2919)
* Move tensors to right devices * fix multi gpu for non mistral models * multi GPU RoPE for gemma2 * Finish up multi GPU inference * Make multiGPU rope a list * Remove unnecessary transfer to CPU * Remove unnecessary move to CPU * Donot move inputs to device yet will be handled separately in another PR * Move inputs to appropriate decoder device * Make device count global variable * Cleanup RoPE device code * Fixup num_gpu to device count * Cleanup device counts * Use device index for RoPE get_cache * Donot typecast * Use tuple instead of list for tensors. Use device index directly * fixup move to device logic
1 parent 12c78de commit 33b02d4

File tree

12 files changed

+337
-216
lines changed

12 files changed

+337
-216
lines changed

unsloth/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
already_imported = [mod for mod in critical_modules if mod in sys.modules]
2323

2424
# This check is critical because Unsloth optimizes these libraries by modifying
25-
# their code at import time. If they're imported first, the original (slower,
25+
# their code at import time. If they're imported first, the original (slower,
2626
# more memory-intensive) implementations will be used instead of Unsloth's
2727
# optimized versions, potentially causing OOM errors or slower training.
2828

@@ -73,6 +73,17 @@ def get_device_type():
7373
pass
7474
DEVICE_TYPE : str = get_device_type()
7575

76+
def get_device_count(device_type = None):
    """Return the number of available accelerator devices.

    Args:
        device_type: Accelerator backend to query: "cuda" or "xpu".
            Defaults to None, which uses the module-level DEVICE_TYPE
            (preserves the original zero-argument behavior).

    Returns:
        int: The device count for the backend, or 0 for any
        unrecognized backend (e.g. CPU-only environments).
    """
    if device_type is None:
        # Fall back to the backend detected at import time.
        device_type = DEVICE_TYPE
    if device_type == "cuda":
        return torch.cuda.device_count()
    elif device_type == "xpu":
        return torch.xpu.device_count()
    else:
        return 0
pass
84+
85+
DEVICE_COUNT : int = get_device_count()
86+
7687
# Reduce VRAM usage by reducing fragmentation
7788
# And optimize pinning of memory
7889
if DEVICE_TYPE == "cuda" and os.environ.get("UNSLOTH_VLLM_STANDBY", "0")=="0":
@@ -237,4 +248,4 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
237248
from .trainer import *
238249

239250
# Patch TRL trainers for backwards compatibility
240-
_patch_trl_trainer()
251+
_patch_trl_trainer()

unsloth/kernels/utils.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
next_power_of_2 = triton.next_power_of_2
1919
import functools
2020
from typing import Optional
21-
from unsloth import DEVICE_TYPE
21+
from unsloth import DEVICE_TYPE, DEVICE_COUNT
2222

2323
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
2424
import torch
@@ -89,18 +89,19 @@ def get_ptr(x: Optional[torch.Tensor]):
8989
get_ptr = bnb.functional.get_ptr
9090

9191

92-
if DEVICE_TYPE == "cuda" and torch.cuda.device_count() > 1:
93-
torch_gpu_device = torch.cuda.device
94-
elif DEVICE_TYPE == "xpu" and torch.xpu.device_count() > 1:
95-
torch_gpu_device = torch.xpu.device
92+
if DEVICE_COUNT > 1:
93+
if DEVICE_TYPE == "cuda":
94+
torch_gpu_device = torch.cuda.device
95+
elif DEVICE_TYPE == "xpu":
96+
torch_gpu_device = torch.xpu.device
9697
else:
9798
from contextlib import nullcontext
9899
def torch_gpu_device(device): return nullcontext()
99100
pass
100101

101102
# INTEL GPU Specific Logic
102103
if DEVICE_TYPE == "xpu":
103-
_gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
104+
_gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
104105
# NVIDIA GPU Default Logic
105106
else:
106107
_gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
@@ -121,20 +122,20 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
121122
if DEVICE_TYPE == "xpu":
122123
_XPU_STREAMS = {
123124
(index := torch.xpu.device(i).idx) : ctypes.c_void_p(torch._C._xpu_getCurrentRawStream(index))
124-
for i in range(torch.xpu.device_count())
125+
for i in range(DEVICE_COUNT)
125126
}
126127
XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1)
127128
WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
128129
ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
129-
for k, v in _XPU_STREAMS.items():
130+
for k, v in _XPU_STREAMS.items():
130131
XPU_STREAMS[k] = v
131132
XPU_STREAMS = tuple(XPU_STREAMS)
132133
del _XPU_STREAMS
133134
else:
134135
# NVIDIA GPU Default Logic
135136
_CUDA_STREAMS = {
136137
(index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index))
137-
for i in range(torch.cuda.device_count())
138+
for i in range(DEVICE_COUNT)
138139
}
139140
CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
140141
WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
@@ -152,16 +153,16 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
152153
# TODO: After adding XPU BNB support, this function should be implemented
153154
def cdequantize_blockwise_fp32(*args, **kwargs):
154155
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp32 should not be called now.")
155-
156+
156157
def cdequantize_blockwise_fp16_nf4(*args, **kwargs):
157158
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp16_nf4 should not be called now.")
158-
159+
159160
def cdequantize_blockwise_bf16_nf4(*args, **kwargs):
160161
raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_bf16_nf4 should not be called now.")
161-
162+
162163
def cgemm_4bit_inference_naive_fp16(*args, **kwargs):
163164
raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_fp16 should not be called now.")
164-
165+
165166
def cgemm_4bit_inference_naive_bf16(*args, **kwargs):
166167
raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_bf16 should not be called now.")
167168
else:
@@ -193,7 +194,7 @@ def get_lora_parameters(proj):
193194
adapter = getattr(proj, "active_adapters", None)
194195
if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
195196
adapter = adapter[0]
196-
197+
197198
return (
198199
W,
199200
getattr(W, "quant_state", None),
@@ -232,7 +233,7 @@ def get_lora_parameters_bias(proj):
232233
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
233234
@torch.inference_mode
234235
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
235-
# TODO: After adding XPU BNB support, check this function
236+
# TODO: After adding XPU BNB support, check this function
236237
if quant_state is None: return W
237238
if type(quant_state) is not list:
238239
# New quant_state as a class
@@ -535,7 +536,7 @@ def fast_gemv(X, W, quant_state, out = None):
535536
device = W.device
536537
device_index = device.index
537538
CUDA_STREAM = CUDA_STREAMS[device_index]
538-
539+
539540
# assert(dtype == X.dtype)
540541
bout = shape[0]
541542

@@ -669,7 +670,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
669670
lora_A._fast_lora = lora_A.to(dtype)
670671
lora_B._fast_lora = lora_B.to(dtype)
671672
pass
672-
673+
673674
if bsz == 1:
674675
out = out.view(out_dim)
675676
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
@@ -709,6 +710,6 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
709710
out.addmm_(XA, B.to(dtype), alpha = s)
710711
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
711712
pass
712-
713+
713714
return out.view(batch, seq_len, -1) if reshape else out
714715
pass

unsloth/models/_utils.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
import re
7878
import warnings, subprocess, re, inspect, psutil, os, math
7979
from unsloth_zoo.utils import Version
80-
from unsloth_zoo import DEVICE_TYPE
80+
from unsloth import DEVICE_TYPE, DEVICE_COUNT
8181

8282
from unsloth_zoo.tokenizer_utils import (
8383
patch_tokenizer as _patch_tokenizer,
@@ -142,12 +142,6 @@
142142
import logging
143143
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
144144

145-
def get_device_num():
146-
if DEVICE_TYPE == "xpu":
147-
return torch.xpu.device_count()
148-
else:
149-
return torch.cuda.device_count()
150-
151145
# Ignore logging messages
152146
class HideLoggingMessage(logging.Filter):
153147
__slots__ = "text",
@@ -746,8 +740,7 @@ def get_statistics():
746740
pass
747741
pass
748742
try:
749-
devices = get_device_num()
750-
_get_statistics(f"{devices if devices <= 8 else 9}")
743+
_get_statistics(f"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}")
751744
except:
752745
pass
753746
if disabled: enable_progress_bars()
@@ -773,14 +766,44 @@ def get_statistics():
773766
)
774767
exec(BitsAndBytesConfig__init__, globals())
775768

776-
if get_device_num() == 1:
769+
if DEVICE_COUNT == 1:
777770
from accelerate.utils.dataclasses import DistributedType
778771
def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO
779772
import accelerate.state
780773
accelerate.state.PartialState._prepare_backend = _prepare_backend
781774
accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO
782775
pass
783776

777+
# to move multiple tensors to the same device
778+
def move_to_device(target_device, *tensors):
    """
    Move multiple tensors to a target device if they're not already there.

    Args:
        target_device: The target device — an int device index, a device
            string like "cuda:0", or a torch.device.
        *tensors: Variable number of tensors to potentially move.

    Returns:
        A single tensor when exactly one tensor is passed, otherwise a
        tuple of tensors on the target device. Tensors already on the
        target device are returned unchanged (same objects); an empty
        call returns an empty tuple instead of raising IndexError.

    Raises:
        ValueError: If target_device is not an int, str, or torch.device.
    """
    if isinstance(target_device, (int, str)):
        # torch.device accepts both an integer index and a name like "cuda:0".
        target_device = torch.device(target_device)
    elif not isinstance(target_device, torch.device):
        raise ValueError(f"Invalid target device: {target_device}")
    pass
    moved_tensors = tuple(
        tensor if tensor.device == target_device else tensor.to(target_device)
        for tensor in tensors
    )
    # Preserve historical behavior: unwrap when a single tensor was given.
    return moved_tensors[0] if len(moved_tensors) == 1 else moved_tensors
806+
784807
import transformers.utils.quantization_config
785808
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
786809
# =============================================

unsloth/models/cohere.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def CohereAttention_fast_forward(
7878
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
7979
*args, **kwargs,
8080
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
81-
81+
8282
# Clear inference
8383
if hasattr(self, "paged_attention"):
8484
del self.paged_attention_K
@@ -254,6 +254,7 @@ def CohereAttention_fast_forward_inference(
254254
do_prefill = False,
255255
attention_mask = None,
256256
):
257+
257258
Xn = hidden_states
258259
bsz, _, hd = hidden_states.size()
259260
K1, V1 = past_key_value
@@ -281,14 +282,14 @@ def CohereAttention_fast_forward_inference(
281282
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
282283
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
283284
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
284-
285+
285286
# Mistral Nemo 12b has weird dimensions
286287
if attention_size != hidden_size:
287288
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
288289
else:
289290
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
290291
pass
291-
292+
292293
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
293294
self.scalar = 1.0 / math_sqrt(self.head_dim)
294295
self.half_head_dim = head_dim // 2
@@ -320,7 +321,7 @@ def CohereAttention_fast_forward_inference(
320321

321322
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
322323
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
323-
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
324+
cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)
324325
cos = cos[position_ids].unsqueeze(1)
325326
sin = sin[position_ids].unsqueeze(1)
326327
h = self.half_head_dim
@@ -338,7 +339,7 @@ def CohereAttention_fast_forward_inference(
338339
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
339340
Kn *= cos
340341
Kn.addcmul_(RH_K, sin)
341-
342+
342343
# New KV cache
343344
# Kn = torch.cat([K1, Kn], dim = 2)
344345
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -397,7 +398,7 @@ def CohereModel_fast_forward_inference(
397398
position_ids,
398399
attention_mask = None,
399400
):
400-
out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
401+
out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT))
401402
input_ids = input_ids[:,:self.max_seq_length]
402403
hidden_states = self.model.embed_tokens(input_ids)
403404
hidden_states = hidden_states.to(self.config.torch_dtype)
@@ -417,8 +418,12 @@ def CohereModel_fast_forward_inference(
417418

418419
next_decoder_cache = []
419420
for idx, decoder_layer in enumerate(self.model.layers):
421+
device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
422+
hidden_states, position_ids = move_to_device(
423+
device_index, hidden_states, position_ids
424+
)
420425
residual = hidden_states
421-
hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight)
426+
hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weights[device_index])
422427
hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
423428
decoder_layer.self_attn,
424429
hidden_states = hidden_states,
@@ -435,7 +440,7 @@ def CohereModel_fast_forward_inference(
435440

436441
next_decoder_cache.append(present_key_value)
437442
pass
438-
hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight)
443+
hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weights[device_index])
439444

440445
return BaseModelOutputWithPast(
441446
last_hidden_state = hidden_states,
@@ -468,7 +473,7 @@ def pre_patch():
468473
CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
469474
PeftModelForCausalLM .forward = PeftModel_fast_forward
470475
fix_prepare_inputs_for_generation(CohereForCausalLM)
471-
476+
472477
import transformers.models.cohere.modeling_cohere
473478
transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
474479
return

unsloth/models/falcon_h1.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def FalconH1Attention_fast_forward(
6969
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
7070
*args, **kwargs,
7171
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
72-
72+
7373
# Clear inference
7474
if hasattr(self, "paged_attention"):
7575
del self.paged_attention_K
@@ -110,12 +110,13 @@ def FalconH1Attention_fast_forward(
110110
# Extend RoPE dynamically to fit in VRAM
111111
rotary_emb = self.rotary_emb
112112
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
113+
device_index = Q.device.index
113114

114115
if position_ids is None:
115116
# Useful for LongRoPE
116-
cos, sin = rotary_emb.get_cached(kv_seq_len)
117+
cos, sin = rotary_emb.get_cached(kv_seq_len, device_index)
117118
else:
118-
cos, sin = rotary_emb(V, seq_len = kv_seq_len)
119+
cos, sin = rotary_emb.get_cached(kv_seq_len, device_index)
119120
Q, K = fast_rope_embedding(Q, K, cos, sin)
120121

121122
if past_key_value is not None:
@@ -245,14 +246,14 @@ def FalconH1Attention_fast_forward_inference(
245246
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
246247
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
247248
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
248-
249+
249250
# Mistral Nemo 12b has weird dimensions
250251
if attention_size != hidden_size:
251252
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
252253
else:
253254
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
254255
pass
255-
256+
256257
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
257258
self.scalar = 1.0 / math_sqrt(self.head_dim)
258259
self.half_head_dim = head_dim // 2
@@ -280,7 +281,7 @@ def FalconH1Attention_fast_forward_inference(
280281
# Need to do it prior 2 steps before hitting full on short KV cache
281282
# or else error
282283
self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
283-
cos, sin = self.rotary_emb.get_cached(kv_seq_len)
284+
cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)
284285
cos = cos[position_ids].unsqueeze(1)
285286
sin = sin[position_ids].unsqueeze(1)
286287
h = self.half_head_dim
@@ -298,7 +299,7 @@ def FalconH1Attention_fast_forward_inference(
298299
RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
299300
Kn *= cos
300301
Kn.addcmul_(RH_K, sin)
301-
302+
302303
# New KV cache
303304
# Kn = torch.cat([K1, Kn], dim = 2)
304305
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -580,7 +581,7 @@ def _fast_prepare_inputs_for_generation(
580581
**kwargs,):
581582
# Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
582583
empty_past_kv = past_key_values is None
583-
584+
584585
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
585586
# Exception 1: when passing input_embeds, input_ids may be missing entries
586587
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here

0 commit comments

Comments
 (0)