Skip to content

Commit d9c1b3e

Browse files
author
Vinayyyy7
committed
Fresh Multi-GPU device consistency fixes for fused kernels
1 parent 8043908 commit d9c1b3e

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

unsloth_zoo/fused_losses/cross_entropy_loss.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,15 @@ def forward(
172172
device = lm_head_weight.device
173173
if extra_kwargs is None: extra_kwargs = {}
174174

175+
# Fix for multi-GPU: ensure all tensors are on the same device for computation
176+
# torch.func.grad_and_value fails when tensors are on different devices
177+
# BUT we must return gradients on the ORIGINAL device of hidden_states
178+
original_hidden_states_device = hidden_states.device
179+
if hidden_states.device != device:
180+
hidden_states = hidden_states.to(device)
181+
if labels.device != device:
182+
labels = labels.to(device)
183+
175184
# Get shifted labels first
176185
if shift_labels:
177186
_labels = torch.empty_like(labels, device = device)
@@ -328,6 +337,7 @@ def accumulate_chunk(
328337
pass
329338
ctx.save_for_backward(grad_inputs, grad_lm_head, grad_lm_head_bias)
330339
ctx.scaling = scaling
340+
ctx.original_hidden_states_device = original_hidden_states_device
331341
return accumulated_loss
332342
pass
333343

@@ -338,6 +348,10 @@ def backward(ctx, grad_output,):
338348
scaling = ctx.scaling if ctx.scaling is not None else 1.0
339349
torch._assert(torch.all(grad_output == scaling), f"Fused losses expect grad_output to be all {scaling}, but got {grad_output.ravel()[:10]}")
340350
(grad_inputs, grad_lm_head, grad_lm_head_bias, ) = ctx.saved_tensors
351+
# Fix for multi-GPU: return gradients on the ORIGINAL device of hidden_states
352+
original_device = ctx.original_hidden_states_device
353+
if grad_inputs.device != original_device:
354+
grad_inputs = grad_inputs.to(original_device)
341355
return (None, grad_inputs, grad_lm_head, grad_lm_head_bias, None, None, None, None, None, None, None, None, None,)
342356
pass
343357
pass

unsloth_zoo/rl_replacements.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def chunked_hidden_states_selective_log_softmax(
8383
all_per_token_logps = []
8484

8585
for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
86-
chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
86+
# Fix for multi-GPU: ensure all tensors are on the same device
87+
chunk_logits = chunk_hidden_states.to(lm_head.device).to(lm_head.dtype) @ lm_head.t()
8788

8889
if logit_scale_multiply != 0.0:
8990
chunk_logits = chunk_logits * logit_scale_multiply

0 commit comments

Comments (0)