Add vLLM fallback and GRPO completion normalization #3958
danielhanchen wants to merge 2 commits into main from
Conversation
Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the stability and compatibility of the Unsloth library, particularly around vLLM integration and the Reinforcement Learning (RL) trainers. It introduces mechanisms to gracefully handle failures in vLLM inference and makes reward-function processing in GRPO more flexible. It also addresses specific model compatibility issues, ensuring a smoother user experience across a wider range of models and configurations.

Highlights
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 697b5005c1
def _vllm_fallback_to_hf_generate(model, args, kwargs):
    tokenizer = getattr(model, "_saved_temp_tokenizer", None)
    hf_kwargs = dict(kwargs)
    sampling_params = hf_kwargs.pop("sampling_params", None)
    hf_kwargs.pop("lora_request", None)
Map/remove vLLM `prompts` kw in HF fallback
When vLLM falls back to HF, _vllm_fallback_to_hf_generate copies kwargs verbatim and only removes sampling_params/lora_request. If the caller used the common vLLM signature fast_generate(prompts=...), the fallback passes an unexpected prompts kw into model.generate, which raises a TypeError because HF generate doesn’t accept that argument. This only appears when vLLM fails or inference is disabled, but it will reliably break those fallback calls unless prompts is converted into tokenized input_ids or removed.
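One way to handle that, shown here only as a minimal sketch rather than the PR's actual fix (the helper names and the accepted call shapes are assumptions): pop `prompts` out of the kwargs, or take the first positional argument, then tokenize into `input_ids` before calling HF `generate`.

```python
import torch

def _extract_prompts(args, kwargs):
    # Accept either fast_generate(prompts, ...) or fast_generate(prompts=...),
    # since vLLM callers commonly pass prompts as a keyword.
    if "prompts" in kwargs:
        prompts = kwargs.pop("prompts")
    elif args:
        prompts, args = args[0], args[1:]
    else:
        raise ValueError("HF fallback received no prompts to generate from.")
    if isinstance(prompts, str):
        prompts = [prompts]
    return prompts, args, kwargs

def _hf_generate_from_prompts(model, tokenizer, prompts, **hf_kwargs):
    # HF `generate` has no `prompts` keyword, so tokenize the text into
    # input_ids / attention_mask tensors first.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    device = getattr(model, "device", None) or "cpu"
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        output_ids = model.generate(**inputs, **hf_kwargs)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
```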
Code Review
This pull request introduces several valuable improvements, notably a robust fallback mechanism for vLLM inference and normalization for GRPO completions. The added compatibility fixes for various models and the defensive programming approach to imports and dynamic code generation enhance the library's stability.
I've identified a couple of areas for improvement:
- The vLLM fallback logic has a hardcoded CUDA device, which should be generalized for other hardware.
- A logging statement in the RL trainer patch could be updated to use the project's logger for consistency.
Overall, these are great changes that make Unsloth more resilient and user-friendly.
        return_tensors = "pt",
        padding = True,
    )
    device = getattr(model, "device", None) or "cuda"
The fallback device is hardcoded to "cuda". This might cause issues on non-NVIDIA hardware like AMD (ROCm/hip) or Intel (XPU) GPUs. It's better to use the dynamically determined DEVICE_TYPE_TORCH for broader compatibility.
-    device = getattr(model, "device", None) or "cuda"
+    device = getattr(model, "device", None) or DEVICE_TYPE_TORCH
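For illustration, a hedged sketch of resolving the device dynamically instead of hardcoding "cuda"; the fallback chain below is an assumption, and the codebase's own `DEVICE_TYPE_TORCH` constant, if adopted, would replace it:

```python
import torch

def _resolve_fallback_device(model):
    # Prefer the device the model already reports.
    device = getattr(model, "device", None)
    if device is not None:
        return device
    # Otherwise pick whichever accelerator backend is available.
    if torch.cuda.is_available():  # covers NVIDIA CUDA and AMD ROCm builds
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel GPUs
        return "xpu"
    return "cpu"
```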
    if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
        print(
For consistency with the rest of the codebase, it's better to use the configured logger for logging messages instead of print. This allows users to control log levels and formatting centrally.
-    if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
-        print(
+    if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
+        logger.warning(f"Unsloth: Failed to compile {RLTrainer_name} ({exc}), falling back to original trainer.")
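A minimal sketch of what the suggested logging path could look like, assuming a standard module-level `logging` logger; the exact logger object used in the codebase may differ:

```python
import logging
import os

logger = logging.getLogger(__name__)

def _warn_compile_fallback(trainer_name, exc):
    # Route the notice through the module logger (instead of print), still
    # gated behind the UNSLOTH_LOGGING_ENABLED environment variable.
    if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
        logger.warning(
            "Unsloth: Failed to compile %s (%s), falling back to original trainer.",
            trainer_name, exc,
        )
```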
    return _fast_generate_wrapper


def make_vllm_fast_generate_wrapper(model, vllm_generate):
In what cases is the fallback triggered? In which cases did you observe vLLM inference failing?
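For context, a minimal sketch of how such a wrapper could trigger the fallback: wrap the vLLM call in a try/except and route failures to the HF helper from the diff above. The broad `except Exception` and the listed failure cases are assumptions, not a description of the PR's exact trigger conditions.

```python
def make_vllm_fast_generate_wrapper(model, vllm_generate):
    # Sketch: try the vLLM path first; on any failure (engine never started,
    # engine crashed, fast inference disabled) fall back to plain HF generate.
    def _fast_generate_wrapper(*args, **kwargs):
        try:
            return vllm_generate(*args, **kwargs)
        except Exception as exc:
            print(f"Unsloth: vLLM generate failed ({exc}); falling back to HF generate.")
            return _vllm_fallback_to_hf_generate(model, args, kwargs)
    return _fast_generate_wrapper
```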
| "Unsloth: vLLM fast_generate failed and no tokenizer was cached for HF fallback." | ||
| ) | ||
| inputs = tokenizer( | ||
| first_arg, |
This can go wrong very easily. If we're passing something to the tokenizer, we should at least check what kind of argument it is.
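One defensive option, sketched under the assumption that callers pass a string, a list of strings, or an already-tokenized tensor (the helper name is hypothetical):

```python
import torch

def _normalize_first_arg(first_arg, tokenizer):
    # Only hand text to the tokenizer; pass pre-tokenized tensors through.
    if isinstance(first_arg, str):
        texts = [first_arg]
    elif isinstance(first_arg, (list, tuple)) and all(isinstance(x, str) for x in first_arg):
        texts = list(first_arg)
    elif isinstance(first_arg, torch.Tensor):
        return {"input_ids": first_arg}  # already token IDs, skip tokenization
    else:
        raise TypeError(
            f"Unsloth: HF fallback cannot handle prompts of type {type(first_arg)!r}."
        )
    return tokenizer(texts, return_tensors="pt", padding=True)
```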
    )


    function = function.replace(
        ' reward_kwargs["trainer_state"] = self.state\n',
| " return ''.join(_unsloth_completion_to_text(item) for item in completion)\n" | ||
| " return str(completion)\n" | ||
| " completion_texts = [_unsloth_completion_to_text(c) for c in completions]\n" | ||
| " completions_are_text = all(isinstance(c, str) for c in completions)\n", |
What else is a possibility? When are tokens not text for LLM/VLM GRPO?
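For context: in TRL's GRPO, completions are plain strings for standard text datasets but arrive as lists of message dicts (e.g. `[{"role": "assistant", "content": "..."}]`) for conversational or VLM-style datasets, so reward functions can see either shape. Below is a minimal sketch of the kind of normalization the patched lines imply; only the helper's name mirrors the diff, and the handled cases are assumptions.

```python
def _unsloth_completion_to_text(completion):
    # Strings pass through; message dicts yield their "content"; lists are
    # flattened recursively; anything else is stringified as a last resort.
    if isinstance(completion, str):
        return completion
    if isinstance(completion, dict):
        content = completion.get("content", "")
        return content if isinstance(content, str) else _unsloth_completion_to_text(content)
    if isinstance(completion, (list, tuple)):
        return "".join(_unsloth_completion_to_text(item) for item in completion)
    return str(completion)

# Both dataset shapes normalize to the same reward-function input:
print(_unsloth_completion_to_text("Paris"))                                      # "Paris"
print(_unsloth_completion_to_text([{"role": "assistant", "content": "Paris"}]))  # "Paris"
```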
Summary
Testing