1 change: 1 addition & 0 deletions unsloth-zoo
Submodule unsloth-zoo added at 804390
29 changes: 27 additions & 2 deletions unsloth/models/loader.py
@@ -846,9 +846,34 @@ def from_pretrained(
         # In multi-GPU (torchrun), each rank must load the model on its own device
         # to avoid Accelerate device relocation errors with quantized weights.
         is_quantized = load_in_4bit or load_in_8bit or load_in_fp8
-        if is_quantized and isinstance(device_map, str):
+        if isinstance(device_map, str):
             distributed_device_map, is_dist = prepare_device_map()
-            if is_dist:
+            if (is_dist or DEVICE_COUNT > 1) and device_map in (
+                "auto",
+                "balanced",
+                "balanced_low_0",
+            ):
+                import warnings
+
+                if is_dist:
+                    raise ValueError(
+                        f"Unsloth: You are in a distributed training environment (multi-GPU) but used device_map='{device_map}'.\n"
+                        f"Model splitting across GPUs is not supported as it causes gradient device mismatches with Unsloth's fused kernels.\n"
+                        f"Please set `device_map = None` to enable standard Data Parallelism.\n"
+                        f"Note: This will load a full copy of the model on each GPU.\n"
+                        f"This uses more VRAM per GPU but provides equivalent training to single GPU."
+                    )
+                else:
+                    # Non-distributed multi-GPU case
+                    warnings.warn(
+                        f"Unsloth: You have {DEVICE_COUNT} GPUs but used device_map='{device_map}'.\n"
+                        f"Model splitting across GPUs is not yet supported with Unsloth's fused kernels.\n"
+                        f"We will override this to use only the first GPU to prevent crashes.\n"
+                        f"To use both GPUs for training, please use `accelerate launch` or `torchrun` and set `device_map=None` for Data Parallelism.",
+                        stacklevel = 2,
+                    )
+                    device_map = {"": "cuda:0"}
+            elif is_dist and is_quantized:
                 device_map = distributed_device_map

         # Check if 4bit is allowed specifically for AMD
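For reference, the data parallelism that the error message above asks for amounts to launching one process per GPU and leaving `device_map` unset. A minimal sketch, assuming a typical `FastLanguageModel` fine-tuning script (the checkpoint name and hyperparameters are illustrative, not taken from this PR):

```python
# Hypothetical launch command: torchrun --nproc_per_node=2 train.py
# Each rank then loads its own full copy of the model onto its own GPU.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/llama-3-8b-bnb-4bit",  # illustrative checkpoint
    max_seq_length = 2048,
    load_in_4bit   = True,
    device_map     = None,  # do NOT use "auto"/"balanced" under torchrun
)
```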
7 changes: 7 additions & 0 deletions unsloth/models/rl_replacements.py
@@ -675,6 +675,10 @@ def _get_per_token_logps_and_entropies(
         if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
             self._autocast_dtype = torch.float16

+        # Fix for GRPO DDP Multi-GPU: Unwrap model to access attributes like .config
+        if torch.distributed.is_initialized() and hasattr(model, "module"):
+            model = model.module
+
         pixel_values, image_grid_thw = (
             kwargs.get("pixel_values", None),
             kwargs.get("image_grid_thw", None),
@@ -923,6 +927,9 @@ def grpo_trainer_compute_loss(function_name, function):
     def compute_loss(
         self, model, inputs, return_outputs = False, num_items_in_batch = None
     ):
+        if torch.distributed.is_initialized() and hasattr(model, "module"):
+            model = model.module
+
         if return_outputs:
             raise ValueError("The GRPOTrainer does not support returning outputs")
         # Compute the per-token log probabilities for the model
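Background for the `model.module` unwrapping added in both hunks: `torch.nn.parallel.DistributedDataParallel` stores the original model on its `.module` attribute and does not forward plain instance attributes such as `.config`, so the trainer must unwrap before touching them. A minimal sketch with a toy module (not part of Unsloth):

```python
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.config = {"hidden_size": 4}  # stands in for a HF model.config

model = ToyModel()

# After wrapping (requires an initialized process group, e.g. under torchrun):
#   ddp = nn.parallel.DistributedDataParallel(model)
#   ddp.config         -> AttributeError (not forwarded by the wrapper)
#   ddp.module.config  -> {"hidden_size": 4}
# which is why the patched GRPO methods unwrap first:
if hasattr(model, "module"):
    model = model.module
```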
33 changes: 33 additions & 0 deletions unsloth/models/vision.py
@@ -477,6 +477,39 @@ def from_pretrained(
"Are you certain you want to do remote code execution?"
)
token = hf_login(token)

# Fix for multi-GPU distributed training
# When using distributed training (e.g., 2x T4 on Kaggle), device_map="auto"/"balanced"
# splits the model across GPUs which can cause gradient device mismatch errors.
# Instead, use data-parallel approach where each GPU gets a full model copy.
from .loader_utils import prepare_device_map, is_distributed

if (is_distributed() or DEVICE_COUNT > 1) and device_map in (
"auto",
"balanced",
"balanced_low_0",
):
import warnings

if is_distributed():
raise ValueError(
f"Unsloth: You are in a distributed training environment (multi-GPU) but used device_map='{device_map}'.\n"
f"Model splitting across GPUs is not supported for Vision Models as it causes gradient device mismatches with Unsloth's fused kernels.\n"
f"Please set `device_map = None` to enable standard Data Parallelism.\n"
f"Note: This will load a full copy of the model on each GPU.\n"
f"This uses more VRAM per GPU but provides equivalent training to single GPU."
)
else:
# Non-distributed multi-GPU case
warnings.warn(
f"Unsloth: You have {DEVICE_COUNT} GPUs but used device_map='{device_map}'.\n"
f"Model splitting across GPUs is not yet supported for Vision Models with Unsloth's fused kernels.\n"
f"We will override this to use only the first GPU to prevent crashes.\n"
f"To use both GPUs for training, please use `accelerate launch` or `torchrun` and set `device_map=None` for Data Parallelism.",
stacklevel = 2,
)
device_map = {"": "cuda:0"}

SUPPORTS_BFLOAT16 = is_bfloat16_supported()

if DEVICE_TYPE == "cuda":
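`is_distributed` and `prepare_device_map` are imported from `loader_utils`, which is not shown in this diff. A rough sketch of what such helpers conventionally do under `torchrun`, relying on the standard `WORLD_SIZE`/`LOCAL_RANK` environment variables (the PR's actual implementation may differ):

```python
import os
import torch

def is_distributed():
    # torchrun / accelerate launch spawn one process per GPU and set WORLD_SIZE > 1.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size() > 1
    return int(os.environ.get("WORLD_SIZE", "1")) > 1

def prepare_device_map():
    # Pin the entire model to this rank's GPU (e.g. {"": "cuda:1"} on rank 1)
    # so Accelerate never relocates quantized weights to another device.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return {"": f"cuda:{local_rank}"}, is_distributed()
```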