Skip to content

Commit 9e854a9

Browse files
author
Vinayyyy7
committed
Fresh Multi-GPU DDP fixes for Vision and RL
1 parent b036191 commit 9e854a9

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

unsloth-zoo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 8043908581dcd07bbd4e441459da918950dd43f3

unsloth/models/loader.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -846,9 +846,29 @@ def from_pretrained(
846846
# In multi-GPU (torchrun), each rank must load the model on its own device
847847
# to avoid Accelerate device relocation errors with quantized weights.
848848
is_quantized = load_in_4bit or load_in_8bit or load_in_fp8
849-
if is_quantized and isinstance(device_map, str):
849+
if isinstance(device_map, str):
850850
distributed_device_map, is_dist = prepare_device_map()
851-
if is_dist:
851+
if (is_dist or DEVICE_COUNT > 1) and device_map in ("auto", "balanced", "balanced_low_0"):
852+
import warnings
853+
if is_dist:
854+
raise ValueError(
855+
f"Unsloth: You are in a distributed training environment (multi-GPU) but used device_map='{device_map}'.\n"
856+
f"Model splitting across GPUs is not supported as it causes gradient device mismatches with Unsloth's fused kernels.\n"
857+
f"Please set `device_map = None` to enable standard Data Parallelism.\n"
858+
f"Note: This will load a full copy of the model on each GPU.\n"
859+
f"This uses more VRAM per GPU but provides equivalent training to single GPU."
860+
)
861+
else:
862+
# Non-distributed multi-GPU case
863+
warnings.warn(
864+
f"Unsloth: You have {DEVICE_COUNT} GPUs but used device_map='{device_map}'.\n"
865+
f"Model splitting across GPUs is not yet supported with Unsloth's fused kernels.\n"
866+
f"We will override this to use only the first GPU to prevent crashes.\n"
867+
f"To use both GPUs for training, please use `accelerate launch` or `torchrun` and set `device_map=None` for Data Parallelism.",
868+
stacklevel = 2,
869+
)
870+
device_map = {"" : "cuda:0"}
871+
elif is_dist and is_quantized:
852872
device_map = distributed_device_map
853873

854874
# Check if 4bit is allowed specifically for AMD

unsloth/models/rl_replacements.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,10 @@ def _get_per_token_logps_and_entropies(
675675
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
676676
self._autocast_dtype = torch.float16
677677

678+
# Fix for GRPO DDP Multi-GPU: Unwrap model to access attributes like .config
679+
if torch.distributed.is_initialized() and hasattr(model, "module"):
680+
model = model.module
681+
678682
pixel_values, image_grid_thw = (
679683
kwargs.get("pixel_values", None),
680684
kwargs.get("image_grid_thw", None),
@@ -923,6 +927,9 @@ def grpo_trainer_compute_loss(function_name, function):
923927
def compute_loss(
924928
self, model, inputs, return_outputs = False, num_items_in_batch = None
925929
):
930+
if torch.distributed.is_initialized() and hasattr(model, "module"):
931+
model = model.module
932+
926933
if return_outputs:
927934
raise ValueError("The GRPOTrainer does not support returning outputs")
928935
# Compute the per-token log probabilities for the model

unsloth/models/vision.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,34 @@ def from_pretrained(
477477
"Are you certain you want to do remote code execution?"
478478
)
479479
token = hf_login(token)
480+
481+
# Fix for multi-GPU distributed training
482+
# When using distributed training (e.g., 2x T4 on Kaggle), device_map="auto"/"balanced"
483+
# splits the model across GPUs which can cause gradient device mismatch errors.
484+
# Instead, use data-parallel approach where each GPU gets a full model copy.
485+
from .loader_utils import prepare_device_map, is_distributed
486+
487+
if (is_distributed() or DEVICE_COUNT > 1) and device_map in ("auto", "balanced", "balanced_low_0"):
488+
import warnings
489+
if is_distributed():
490+
raise ValueError(
491+
f"Unsloth: You are in a distributed training environment (multi-GPU) but used device_map='{device_map}'.\n"
492+
f"Model splitting across GPUs is not supported for Vision Models as it causes gradient device mismatches with Unsloth's fused kernels.\n"
493+
f"Please set `device_map = None` to enable standard Data Parallelism.\n"
494+
f"Note: This will load a full copy of the model on each GPU.\n"
495+
f"This uses more VRAM per GPU but provides equivalent training to single GPU."
496+
)
497+
else:
498+
# Non-distributed multi-GPU case
499+
warnings.warn(
500+
f"Unsloth: You have {DEVICE_COUNT} GPUs but used device_map='{device_map}'.\n"
501+
f"Model splitting across GPUs is not yet supported for Vision Models with Unsloth's fused kernels.\n"
502+
f"We will override this to use only the first GPU to prevent crashes.\n"
503+
f"To use both GPUs for training, please use `accelerate launch` or `torchrun` and set `device_map=None` for Data Parallelism.",
504+
stacklevel = 2,
505+
)
506+
device_map = {"" : "cuda:0"}
507+
480508
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
481509

482510
if DEVICE_TYPE == "cuda":

0 commit comments

Comments (0)