@@ -172,6 +172,15 @@ def forward(
172172 device = lm_head_weight .device
173173 if extra_kwargs is None : extra_kwargs = {}
174174
175+ # Fix for multi-GPU: ensure all tensors are on the same device for computation
176+ # torch.func.grad_and_value fails when tensors are on different devices
177+ # BUT we must return gradients on the ORIGINAL device of hidden_states
178+ original_hidden_states_device = hidden_states .device
179+ if hidden_states .device != device :
180+ hidden_states = hidden_states .to (device )
181+ if labels .device != device :
182+ labels = labels .to (device )
183+
175184 # Get shifted labels first
176185 if shift_labels :
177186 _labels = torch .empty_like (labels , device = device )
@@ -328,6 +337,7 @@ def accumulate_chunk(
328337 pass
329338 ctx .save_for_backward (grad_inputs , grad_lm_head , grad_lm_head_bias )
330339 ctx .scaling = scaling
340+ ctx .original_hidden_states_device = original_hidden_states_device
331341 return accumulated_loss
332342 pass
333343
@@ -338,6 +348,10 @@ def backward(ctx, grad_output,):
338348 scaling = ctx .scaling if ctx .scaling is not None else 1.0
339349 torch ._assert (torch .all (grad_output == scaling ), f"Fused losses expect grad_output to be all { scaling } , but got { grad_output .ravel ()[:10 ]} " )
340350 (grad_inputs , grad_lm_head , grad_lm_head_bias , ) = ctx .saved_tensors
351+ # Fix for multi-GPU: return gradients on the ORIGINAL device of hidden_states
352+ original_device = ctx .original_hidden_states_device
353+ if grad_inputs .device != original_device :
354+ grad_inputs = grad_inputs .to (original_device )
341355 return (None , grad_inputs , grad_lm_head , grad_lm_head_bias , None , None , None , None , None , None , None , None , None ,)
342356 pass
343357pass
0 commit comments