fix: multi-GPU training support for vision models #3809

Closed

Vinayyyy7 wants to merge 2 commits into unslothai:main from Vinayyyy7:fix/multi-gpu-vision-model-training

Conversation

Vinayyyy7 commented Dec 31, 2025

This PR fixes multi-GPU training for vision models when using device_map="auto" or device_map="balanced".

When running on multiple GPUs, setting device_map="auto" or device_map="balanced" with the FastVisionModel class causes the model to be split across devices. During training, this results in hidden_states being on one GPU (e.g., cuda:1) while lm_head is on another (e.g., cuda:0). The fused cross-entropy loss then computes gradients on the lm_head device, but PyTorch expects them back on the original hidden_states device, causing a RuntimeError.

Errors we see:

Unsupported: NotImplementedError/UnsupportedFakeTensorException when running FX node
  Explanation: Dynamo failed to run FX node with fake tensors: call_function <function _autograd_grad at 0x7adc2d2d8180>(*((GradTrackingTensor(lvl=1, value=
        FakeTensor(..., device='cuda:0', size=())
    ),), [GradTrackingTensor(lvl=1, value=
        FakeTensor(..., device='cuda:1', size=(s97, 2048), dtype=torch.float16,
                   requires_grad=True)
    )]), **{'create_graph': True}): got NotImplementedError('Cannot access storage of TensorWrapper')
  Hint: If the op is a PyTorch op, please file an issue to PyTorch.

  Developer debug context: 

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0087.html

from user code:
   File "/usr/local/lib/python3.11/dist-packages/unsloth_zoo/fused_losses/cross_entropy_loss.py", line 276, in accumulate_chunk
    (chunk_loss, (unscaled_loss,)) = torch.func.grad_and_value(
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/apis.py", line 449, in wrapper
    return eager_transforms.grad_and_value_impl(
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/eager_transforms.py", line 1390, in grad_and_value_impl
    flat_grad_input = _autograd_grad(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

OR

[rank1]: ValueError: You can't train a model that has been loaded in 8-bit or 4-bit precision on a different device than the one you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`

This fix adds distributed-training detection to FastBaseModel.from_pretrained(). When distributed training is detected and device_map is set to auto/balanced, it raises a ValueError telling the user to use data-parallel mode, where each GPU loads a full copy of the model.
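
A minimal sketch of that check, written as a simplified standalone helper (the helper name and exact message are illustrative, not the actual unsloth code):

import os
import torch

def _validate_device_map_for_distributed(device_map):
    # Hypothetical helper: treat the run as distributed if torch.distributed is
    # initialized or a launcher has set WORLD_SIZE > 1.
    is_distributed = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    ) or int(os.environ.get("WORLD_SIZE", "1")) > 1
    if is_distributed and device_map in ("auto", "balanced", "balanced_low_0"):
        raise ValueError(
            f"Unsloth: device_map='{device_map}' splits vision models across GPUs and "
            f"breaks fused-loss gradients. For multi-GPU training, set device_map=None "
            f"so each GPU loads a full copy of the model (data parallel)."
        )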

Note: This PR works together with a corresponding fix in unsloth-zoo that handles the gradient device mismatch in the fused CE loss.

unslothai/unsloth-zoo#423

Tested on Kaggle with 2x T4 GPUs using Qwen/Qwen3-VL-2B-Instruct
Result: IT WORKED

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 2
   \\   /|    Num examples = 10,000 | Num Epochs = 1 | Total steps = 1,250
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 139,460,608 of 2,266,992,640 (6.15% trained)
Step | Training Loss
-- | --
1 | 46.795200
2 | 109.610000
3 | 40.122600
4 | 44.093400
5 | 32.455600
6 | 47.650400
7 | 60.632800
8 | 24.140800
9 | 63.971800
10 | 19.096700
11 | 20.428400
12 | 17.989900
13 | 10.218600
14 | 20.911900
15 | 16.450800
16 | 6.303300
17 | 4.037200
18 | 4.878000
19 | 2.447200
20 | 2.050200

gemini-code-assist bot (Contributor)

Summary of Changes

Hello @Vinayyyy7, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves critical issues preventing successful multi-GPU training for vision models when using automatic or balanced device mapping. By intelligently detecting distributed training environments, the system now defaults to a data-parallel strategy, ensuring each GPU receives a complete model copy. This prevents model splitting across devices, which previously led to gradient calculation errors and RuntimeError exceptions, thereby enabling robust and efficient multi-GPU training for vision models.

Highlights

  • Multi-GPU Training Fix: Addresses RuntimeError and ValueError encountered during multi-GPU training of vision models when device_map is set to "auto" or "balanced", which caused model splitting and gradient device mismatches.
  • Device Map Override: Implements a mechanism to detect distributed training and, if device_map is "auto", "balanced", or "balanced_low_0", automatically switches to a data-parallel approach where each GPU loads a full model copy, preventing splitting issues.
  • Warning for Device Map: Adds a warning to inform users when the device_map is overridden for distributed vision model training, explaining why data-parallel mode is being used instead of device splitting.
  • Integration with unsloth-zoo: Notes that this PR works in conjunction with a corresponding fix in unsloth-zoo to handle gradient device mismatches in the fused cross-entropy loss.

gemini-code-assist bot left a comment

Code Review

This pull request addresses a critical issue with multi-GPU training for vision models by correctly handling device_map in distributed environments. The fix prevents model splitting across GPUs, which was causing runtime errors. The overall approach is sound. However, the implementation introduces code duplication between FastModel.from_pretrained and FastBaseModel.from_pretrained. My review includes suggestions to refactor this by removing the redundant code and centralizing the logic in the base model, which will improve maintainability and efficiency.

chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Datta0 (Collaborator) commented Jan 1, 2026

Hey @Vinayyyy7, thanks for the contribution.
I'm not 100% sure if switching to DataParallel is what users would want when they set device_map=balanced.
So maybe we should error out instead of warning, so that they can retry with an appropriate device_map? We could say something like:

"device map balanced/auto is not supported for FastVisionModel and we'd recommend using Data Parallel by setting device_map=xyz. Note that this uses higher memory than balanced/auto on each gpu but equivalent to running single gpu training"

thoughts @danielhanchen ?

Vinayyyy7 (Author)

Agree that erroring out is cleaner than silently overriding. Users should explicitly know what's happening.

Clarifications:

Splitting the model across GPUs always crashed due to gradient device-mismatch errors. The companion PR in unsloth-zoo (unslothai/unsloth-zoo#423) handles this by tracking the original device and moving gradients back. But even with that fix, model splitting has other issues with distributed training and quantization.
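
A hedged sketch of that device-tracking idea, using a simplified chunked loss helper (this is not the actual unsloth-zoo code; names are illustrative):

import torch
import torch.nn.functional as F

def chunked_ce_grad(hidden_states, lm_head_weight, labels):
    original_device = hidden_states.device          # e.g. cuda:1, where the hidden states live
    hs = hidden_states.to(lm_head_weight.device)    # e.g. cuda:0, where lm_head lives
    def loss_fn(hs_chunk):
        logits = hs_chunk @ lm_head_weight.t()
        return F.cross_entropy(logits.float(), labels.to(logits.device))
    grad, loss = torch.func.grad_and_value(loss_fn)(hs)
    # Hand the gradient back on the device autograd expects (the original hidden_states device).
    return grad.to(original_device), loss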

What device_map can we suggest to users?
For data-parallel training, where each GPU loads a full model copy and which actually works, users should use:

  • device_map=None  # for multi-GPU
  • or explicitly device_map={"": "cuda:0"}  # for single GPU

Proposed error message:

raise ValueError(
    f"Unsloth: device_map='{device_map}' is not supported for FastVisionModel in multi-GPU training. "
    f"Model splitting across GPUs causes gradient device mismatch errors. "
    f"For multi-GPU training, remove device_map or set device_map=None to use data-parallel mode "
    f"where each GPU loads a full copy of the model. "
    f"Note: This uses more VRAM per GPU but provides equivalent training to single GPU."
)
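
For illustration, the data-parallel path the message points to would look roughly like this (model name taken from my test above; other values are placeholders):

from unsloth import FastVisionModel

model, tokenizer = FastVisionModel.from_pretrained(
    "Qwen/Qwen3-VL-2B-Instruct",
    load_in_4bit = True,
    device_map = None,  # no sharding: every GPU/rank loads a full copy, i.e. data parallel
)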

Is this error message plus the device_map=None suggestion good? That way, if users have enough VRAM they can train 2x faster on 2 GPUs instead of keeping one idle. Most users might prefer Kaggle, since it provides 12 hours of continuous GPU time, while Colab gives only 4-5 hours and only 1 GPU.

@Datta0

Datta0 (Collaborator) commented Jan 1, 2026

Hey @Vinayyyy7
Yeah, the error looks fine. Though I too am not 100% sure what device_map=None ends up doing; I can't seem to readily find it either.

Vinayyyy7 (Author)

It's just an if condition: we warn users that auto/balanced is not supported in FastVisionModel (like the example error above), telling them they can use data parallel via device_map=None. This way they know it requires more VRAM but will work, and is faster compared to 1 GPU.

Note: right now this is not implemented in the PR; currently it just switches directly to data parallel if it detects auto/balanced.

Vinayyyy7 force-pushed the fix/multi-gpu-vision-model-training branch from 3a47492 to 2550f94 on January 6, 2026 09:46
Vinayyyy7 (Author)

Some updates: another user tested this, and their issues were addressed.

Fixes

  • GRPO multi-GPU fix (unsloth/models/rl_replacements.py): I added an if torch.distributed.is_initialized(): model = model.module check (see the sketch after this list). This resolves the GRPO crash:
AttributeError: 'DistributedDataParallel' ... has no attribute 'config'
  • DDP safety enforced (loader.py & vision.py): changed the warning to a strict ValueError. If a user tries device_map="auto"/"balanced" in a distributed environment (which causes the model-splitting/gradient-mismatch crash), it now explicitly errors out and instructs them to use device_map=None for data parallelism. (Multi-GPU fix)
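
A minimal sketch of the DDP-unwrap idea from the first item above (simplified; the actual patch lives in unsloth/models/rl_replacements.py and the helper name here is illustrative):

import torch

def unwrap_if_ddp(model):
    # When the trainer wraps the model in DistributedDataParallel, attributes like
    # .config live on the inner module, so unwrap before accessing them.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model = model.module
    return model

# config = unwrap_if_ddp(model).config  # avoids the AttributeError shown above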

Vinayyyy7 (Author)

/gemini review

gemini-code-assist bot left a comment

Code Review

This pull request introduces fixes to support multi-GPU training for vision models. The change in unsloth/models/rl_replacements.py correctly unwraps the model in a distributed setting, which is a good fix. The main change in unsloth/models/vision.py adds a check to prevent model splitting across GPUs, which is known to cause issues. My primary feedback is on the implementation of this check. Instead of raising an error, which requires manual intervention from the user, I recommend automatically overriding the device_map and issuing a warning. This would be more user-friendly and aligns better with the pull request's description of overriding the setting.

Vinayyyy7 (Author)

/gemini review

gemini-code-assist bot left a comment

Code Review

This pull request introduces a fix for multi-GPU training with vision models by preventing the use of device_map settings that lead to model splitting, which is unsupported. The changes correctly raise an informative error to guide users. Additionally, it includes patches to unwrap DDP models for GRPO training. My review focuses on ensuring the robustness of these DDP patches. I've identified one potential issue where a model is unwrapped without a necessary safety check, which could lead to runtime errors.

Datta0 (Collaborator) left a comment

Sorry for the delay. I somehow forgot about this PR

alien087

I have a similar problem: I want to train Qwen3 VL 30B on 2x L40S and split the model to fit on my 2 GPUs, but the exact error described here appears. Does this solve the problem?

Vinayyyy7 (Author)

Yes, I have personally tested it with Qwen3-VL and InternVL models; it worked on multi-GPU.

Another user reported some issues with GRPO, I remember, so those were also fixed; they did some detailed tests which showed DDP worked.

alien087

So it'll work for model sharding too? I want to split the model across 2x L40S because it can't fit on 1x L40S. Sorry for asking a lot of questions, I just want to make sure before trying it.

Vinayyyy7 (Author) commented Feb 16, 2026

If you are trying to do model parallelism, i.e. sharding the model across multiple GPUs (which is what Unsloth does for normal text-generation models, and it works nicely), that is not supported for vision models like Qwen etc. (I might do another PR to support that in the future.)

So this implementation makes it work by using data parallelism instead. As a little trade-off it splits the training examples from the dataset across GPUs, and although it loads the model entirely on both GPUs, training is faster compared to a single GPU.

Setting device_map=None enables this (accelerate library behavior).

I would recommend having GPUs that can fit the model fully. I know this approach is kinda dumb, but it's good at the same time.
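
A small per-rank sketch of what that looks like at runtime, assuming a launch like torchrun --nproc_per_node=2 train.py (the script name is a placeholder):

import os
import torch

# Each process spawned by the launcher drives one GPU; with device_map=None the full
# (unsplit) model is loaded on that GPU and the trainer/accelerate wraps it in DDP.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
print(f"rank {local_rank} / world size {os.environ.get('WORLD_SIZE', '1')}: "
      f"full model copy on {torch.cuda.get_device_name(local_rank)}")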

mmathew23 (Collaborator)

@Vinayyyy7 I just tried testing this PR with qwen3 vl gspo, and was still getting an error. Can you show me the script you used to test?

Vinayyyy7 force-pushed the fix/multi-gpu-vision-model-training branch from fec27be to 9e854a9 on February 18, 2026 12:59
Vinayyyy7 (Author)

Yes, this PR was made a few days before Unsloth released GSPO training support, and since then a lot of changes were made, plus various migrations and dependency updates, so the code likely broke; not even my old training code is working right now.

So I've decided to start fresh and apply the fix again.

mmathew23 (Collaborator)

Thanks! Should we close this one and start a fresh one?

Vinayyyy7 (Author)

Sure, a NEW PR would be better

Vinayyyy7 closed this on Feb 19, 2026