
Add vLLM fallback and GRPO completion normalization #3958

Open

danielhanchen wants to merge 2 commits into main from fix/compat-guards

Conversation

@danielhanchen
Contributor

Summary

  • add safe vLLM import guard for guided decoding patch (see the sketch after this list)
  • normalize GRPO completions and retry reward funcs on string vs dict mismatch
  • wrap vLLM fast_generate with HF fallback and disable vLLM inference for FP8 models
  • attach HF model and tokenizer to vLLM engine for fallback
  • add DeepSeek v2 MoE alias and Qwen VL compatibility helpers
  • guard RL trainer compilation to fall back to original trainer on failure
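
For the import guard in the first bullet, here is a minimal sketch of the general pattern; the module path, flag name, and warning message are assumptions for illustration, not the exact code in this PR:

# Illustrative only: tolerate missing or renamed vLLM symbols instead of failing at import time.
try:
    from vllm.sampling_params import GuidedDecodingParams  # location may vary across vLLM versions
    HAS_GUIDED_DECODING = True
except Exception as exc:  # ImportError, AttributeError, or vLLM-internal failures
    GuidedDecodingParams = None
    HAS_GUIDED_DECODING = False
    print(f"Unsloth: skipping guided decoding patch, vLLM import failed ({exc}).")

The guided decoding patch would then check HAS_GUIDED_DECODING before touching any vLLM internals.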

Testing

  • python -m py_compile unsloth/import_fixes.py unsloth/__init__.py unsloth/models/rl_replacements.py unsloth/models/_utils.py unsloth/models/llama.py unsloth/models/vision.py unsloth/models/rl.py

@gemini-code-assist
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the stability and compatibility of the Unsloth library, particularly concerning vLLM integration and Reinforcement Learning (RL) trainers. It introduces mechanisms to gracefully handle potential failures in vLLM inference and improves the flexibility of reward function processing in GRPO. Additionally, it addresses specific model compatibility issues, ensuring a smoother user experience across a wider range of models and configurations.

Highlights

  • vLLM Fallback Mechanism: Implemented a robust fallback system for vLLM's fast_generate to seamlessly switch to Hugging Face's generate if vLLM encounters an error or is explicitly disabled (e.g., for FP8 models). This ensures continued inference capability even when vLLM is not optimal or fails. (A minimal sketch follows this list.)
  • GRPO Completion Normalization: Enhanced GRPO reward functions to normalize completions, converting non-string outputs (like dictionaries or lists) into a consistent string format. This includes a retry mechanism to handle reward functions that might expect dictionary-based completions, improving compatibility and robustness.
  • Improved vLLM Import Guarding: Strengthened the import guards for vLLM's GuidedDecodingParams patch, making it more resilient to various import errors and providing clearer warnings when the patch cannot be applied.
  • Model Compatibility Fixes: Added compatibility fixes for specific models, including an alias for DeepSeek v2 MoE and adjustments for Qwen VL models regarding max_pixels and tie_word_embeddings.
  • RL Trainer Compilation Safeguard: Introduced a safeguard for RL trainer compilation, allowing the system to fall back to the original trainer if the custom compilation process fails, preventing crashes and improving stability.
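
To make the first highlight concrete, here is a minimal sketch of a generate wrapper with an HF fallback; the function names, the disable flag, and the stripped kwargs are assumptions for illustration, not this PR's implementation:

# Illustrative sketch: try vLLM first, fall back to Hugging Face generate on failure.
def make_fast_generate_with_fallback(vllm_generate, hf_generate, disable_vllm = False):
    def fast_generate(*args, **kwargs):
        if not disable_vllm:  # e.g. FP8 models could skip vLLM inference entirely
            try:
                return vllm_generate(*args, **kwargs)
            except Exception as exc:
                print(f"Unsloth: vLLM fast_generate failed ({exc}); falling back to HF generate.")
        # Drop vLLM-only kwargs before handing off to Hugging Face's generate.
        kwargs.pop("sampling_params", None)
        kwargs.pop("lora_request", None)
        return hf_generate(*args, **kwargs)
    return fast_generate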



@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 697b5005c1


Comment on lines +2521 to +2526
def _vllm_fallback_to_hf_generate(model, args, kwargs):
    tokenizer = getattr(model, "_saved_temp_tokenizer", None)
    hf_kwargs = dict(kwargs)
    sampling_params = hf_kwargs.pop("sampling_params", None)
    hf_kwargs.pop("lora_request", None)


P2: Map or remove the vLLM prompts kwarg in the HF fallback

When vLLM falls back to HF, _vllm_fallback_to_hf_generate copies kwargs verbatim and only removes sampling_params/lora_request. If the caller used the common vLLM signature fast_generate(prompts=...), the fallback passes an unexpected prompts kw into model.generate, which raises a TypeError because HF generate doesn’t accept that argument. This only appears when vLLM fails or inference is disabled, but it will reliably break those fallback calls unless prompts is converted into tokenized input_ids or removed.
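
For illustration, one way to handle that inside the fallback could be to map prompts onto tokenized HF inputs; this is a sketch only, reusing the tokenizer and hf_kwargs names from the snippet above, and is not the PR's code:

# Sketch: translate vLLM's prompts= kwarg into input_ids/attention_mask for model.generate.
prompts = hf_kwargs.pop("prompts", None)
if prompts is not None and tokenizer is not None:
    if isinstance(prompts, str):
        prompts = [prompts]
    encoded = tokenizer(prompts, return_tensors = "pt", padding = True).to(model.device)
    hf_kwargs.update(encoded)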



@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces several valuable improvements, notably a robust fallback mechanism for vLLM inference and normalization for GRPO completions. The added compatibility fixes for various models and the defensive programming approach to imports and dynamic code generation enhance the library's stability.

I've identified a couple of areas for improvement:

  1. The vLLM fallback logic has a hardcoded CUDA device, which should be generalized for other hardware.
  2. A logging statement in the RL trainer patch could be updated to use the project's logger for consistency.

Overall, these are great changes that make Unsloth more resilient and user-friendly.

    return_tensors = "pt",
    padding = True,
)
device = getattr(model, "device", None) or "cuda"
Contributor


Severity: high

The fallback device is hardcoded to "cuda". This might cause issues on non-NVIDIA hardware like AMD (ROCm/HIP) or Intel (XPU) GPUs. It's better to use the dynamically determined DEVICE_TYPE_TORCH for broader compatibility.

Suggested change
device = getattr(model, "device", None) or "cuda"
device = getattr(model, "device", None) or DEVICE_TYPE_TORCH

Comment on lines +1143 to +1144
if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
    print(
Contributor


Severity: medium

For consistency with the rest of the codebase, it's better to use the configured logger for logging messages instead of print. This allows users to control log levels and formatting centrally.

Suggested change
if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
    print(
if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1":
    logger.warning(f"Unsloth: Failed to compile {RLTrainer_name} ({exc}), falling back to original trainer.")

    return _fast_generate_wrapper


def make_vllm_fast_generate_wrapper(model, vllm_generate):
Collaborator


In what cases is the fallback triggered? In what case did you observe vLLM inference failing?

"Unsloth: vLLM fast_generate failed and no tokenizer was cached for HF fallback."
)
inputs = tokenizer(
    first_arg,
Collaborator


This can go wrong very easily. If we're passing something to the tokenizer, we should at least check what the arg is.
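
For illustration, a guard along those lines might look like this (a sketch, not code from this PR):

# Sketch: validate the positional argument before handing it to the tokenizer.
if isinstance(first_arg, str):
    first_arg = [first_arg]
if not (isinstance(first_arg, (list, tuple)) and all(isinstance(p, str) for p in first_arg)):
    raise TypeError(
        f"Unsloth: HF fallback expects a prompt string or list of strings, got {type(first_arg).__name__}."
    )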

)

function = function.replace(
    ' reward_kwargs["trainer_state"] = self.state\n',
Collaborator


Why are we tracking this?

" return ''.join(_unsloth_completion_to_text(item) for item in completion)\n"
" return str(completion)\n"
" completion_texts = [_unsloth_completion_to_text(c) for c in completions]\n"
" completions_are_text = all(isinstance(c, str) for c in completions)\n",
Collaborator


What else is a possibility? When are completions not plain text for LLM/VLM GRPO?
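
For context, a hedged example of the non-string case: in conversational and VLM GRPO setups, each completion is commonly a list of chat messages rather than a bare string, e.g.:

# Example structure only, not taken from this PR.
completion = [
    {"role": "assistant", "content": "The answer is 42."},
]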

