Fix TRL 0.25.1+ GRPO vision crash and reward function TypeError #3975
danielhanchen wants to merge 2 commits into main
Conversation
Fix 2 (rl.py): Add `_patch_prepare_multimodal_messages()`
- Wraps `prepare_multimodal_messages` with an `isinstance(messages, str)` guard
- Fixes the vision GRPO crash when notebooks pre-apply chat templates
- String prompts now pass through unchanged

Fix 6 (rl_replacements.py): Add `grpo_trainer__calculate_rewards_text_fix()`
- Makes `_calculate_rewards` use `prompts_text`/`completions_text` for TRL 0.25.0+
- Ensures reward functions receive a consistent plain-text format
- Fixes the TypeError raised when reward functions expect strings but get dicts
Summary of Changes
Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces two compatibility fixes for TRL 0.25.0 and newer: it resolves crashes in vision models caused by pre-applied chat templates, and it eliminates a TypeError in reward functions caused by inconsistent prompt and completion formats.
Code Review
This pull request introduces two important fixes for compatibility with recent versions of TRL, addressing a crash in vision models with pre-applied chat templates and a TypeError in the reward function due to inconsistent data formats. The changes are well-implemented and follow the existing patching patterns in the codebase. The logic is sound and directly solves the described issues. I have one suggestion to improve the robustness of a string replacement by using a regular expression, which is consistent with other parts of the code.
```python
function = function.replace(
    "self._calculate_rewards(inputs, prompts, completions, completion_ids_list)",
    "self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)",
)
```
The use of function.replace for patching the _calculate_rewards call is a bit fragile as it's sensitive to whitespace changes in the upstream TRL library. Using re.sub would make this patch more robust against minor formatting variations. This is also consistent with how other patches are implemented in this file.
Suggested change:

```diff
-function = function.replace(
-    "self._calculate_rewards(inputs, prompts, completions, completion_ids_list)",
-    "self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)",
-)
+function = re.sub(
+    r"self\._calculate_rewards\(\s*inputs,\s*prompts,\s*completions,\s*completion_ids_list\s*\)",
+    "self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)",
+    function,
+)
```
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7e7184e783
```python
def _safe_prepare_multimodal_messages(messages, *args, **kwargs):
    # If messages is already a string (pre-applied chat template), return as-is
    if isinstance(messages, str):
        return messages
```
Preserve image placeholders when messages are strings
The new guard returns the original string whenever messages is a str, regardless of any images passed via *args/**kwargs. In TRL, prepare_multimodal_messages is responsible for inserting image placeholders into the prompt; bypassing it means any vision inputs provided alongside a pre-applied chat template are silently ignored, so GRPO will train on text-only prompts. This regression shows up whenever a notebook pre-applies apply_chat_template() but still supplies images to the vision model.
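A minimal sketch of such a stricter guard, assuming `images` arrives as the second positional argument and that `prepare_multimodal_messages` lives in `trl.data_utils` (its location may vary across TRL versions):

```python
from trl.data_utils import prepare_multimodal_messages as _original

def _safe_prepare_multimodal_messages(messages, images = None, *args, **kwargs):
    # Bypass only when the prompt is a pre-templated string AND no images
    # were supplied; otherwise defer to TRL so image placeholders are inserted.
    if isinstance(messages, str) and not images:
        return messages
    return _original(messages, images, *args, **kwargs)
```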
```python
function = function.replace(
    "self._calculate_rewards(inputs, prompts, completions, completion_ids_list)",
    "self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)",
)
```
Keep structured prompts available for reward functions
This replacement forces _calculate_rewards to always receive prompts_text/completions_text even in conversational mode. Reward functions that intentionally inspect message structure (e.g., roles, tool-call fields, or metadata in dicts) will now receive flattened strings and can no longer operate correctly. That’s a behavioral regression for existing conversational reward functions on TRL 0.25.0+.
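For instance, a conversational reward function like this hypothetical one would break once it starts receiving flattened strings:

```python
# Hypothetical reward function that indexes into the message dicts TRL's
# conversational mode normally provides. After this patch it receives plain
# strings, so completion[0]["content"] raises
# "TypeError: string indices must be integers".
def role_aware_reward(prompts, completions, **kwargs):
    rewards = []
    for completion in completions:
        text = completion[0]["content"]  # expects [{"role": ..., "content": ...}]
        rewards.append(1.0 if "</answer>" in text else 0.0)
    return rewards
```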
Summary
Changes
Fix 2: Vision GRPO crash (rl.py)
TRL 0.25.1+ calls `prepare_multimodal_messages()` unconditionally for vision models. When notebooks pre-apply `tokenizer.apply_chat_template()` (converting prompts to strings), the function crashes while iterating over the string's characters.

Solution: Add `_patch_prepare_multimodal_messages()`, which wraps the TRL function with an `isinstance(messages, str)` guard. String prompts now pass through unchanged; a sketch of the wrapping approach follows.
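A sketch of that wrapping approach (function names per the PR; the exact module attribute being reassigned is an assumption):

```python
import trl.data_utils as data_utils  # assumed home of the TRL helper

def _patch_prepare_multimodal_messages():
    original = data_utils.prepare_multimodal_messages
    def _safe_prepare_multimodal_messages(messages, *args, **kwargs):
        # Pre-applied chat template: already a string, nothing to prepare.
        if isinstance(messages, str):
            return messages
        return original(messages, *args, **kwargs)
    data_utils.prepare_multimodal_messages = _safe_prepare_multimodal_messages
```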
Fix 6: Reward function TypeError (rl_replacements.py)
TRL 0.25.0+ passes `prompts` and `completions` to `_calculate_rewards` in inconsistent formats, e.g. conversational dicts like `[{"role": "assistant", "content": "..."}]` alongside plain strings. This inconsistency causes reward functions to crash when they expect strings but receive dicts (or vice versa).
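Illustratively (the exact shapes are examples, not the precise payloads):

```python
# What a reward function may see on TRL 0.25.0+ without this fix:
prompts     = ["<|user|>Solve 2+2<|assistant|>"]          # plain decoded text
completions = [[{"role": "assistant", "content": "4"}]]   # conversational dicts
```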
Solution: Add `grpo_trainer__calculate_rewards_text_fix()`, which makes `_calculate_rewards` always use `prompts_text` and `completions_text` (plain decoded strings) for consistent behavior.

Test plan
- `prepare_multimodal_messages("test string", [])` returns the string unchanged
- `_calculate_rewards` uses `prompts_text, completions_text`