-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Fix TRL 0.25.1+ GRPO vision crash and reward function TypeError #3975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -577,6 +577,35 @@ def grpo_trainer__generate_and_score_completions(function_name, function): | |||||||||||||||||||
| RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| # Fix 6: TRL 0.25.0+ _calculate_rewards text arguments | ||||||||||||||||||||
| # | ||||||||||||||||||||
| # TRL 0.25.0+ passes `prompts` and `completions` to _calculate_rewards in different formats: | ||||||||||||||||||||
| # - For conversational inputs: list of dicts [{"role": "assistant", "content": "..."}] | ||||||||||||||||||||
| # - For non-conversational inputs: plain text strings | ||||||||||||||||||||
| # | ||||||||||||||||||||
| # This inconsistency causes reward functions to fail when they expect one format but get the other. | ||||||||||||||||||||
| # The variables `prompts_text` and `completions_text` always contain plain decoded text strings. | ||||||||||||||||||||
| # | ||||||||||||||||||||
| # Fix: Always pass plain text (prompts_text, completions_text) to _calculate_rewards for consistency. | ||||||||||||||||||||
| # This ensures reward functions receive predictable string format regardless of conversational mode. | ||||||||||||||||||||
| def grpo_trainer__calculate_rewards_text_fix(function_name, function): | ||||||||||||||||||||
| if function_name != "_generate_and_score_completions": | ||||||||||||||||||||
| return function | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Only apply if prompts_text and completions_text exist (TRL 0.25.0+) | ||||||||||||||||||||
| if "prompts_text" in function and "completions_text" in function: | ||||||||||||||||||||
| # Replace the _calculate_rewards call to use text versions | ||||||||||||||||||||
| function = function.replace( | ||||||||||||||||||||
| "self._calculate_rewards(inputs, prompts, completions, completion_ids_list)", | ||||||||||||||||||||
| "self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)", | ||||||||||||||||||||
|
Comment on lines
+598
to
+600
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This replacement forces Useful? React with 👍 / 👎. |
||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Comment on lines
+598
to
+601
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| return function | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__calculate_rewards_text_fix) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| # Fix {"reasoning_effort" : "high"} not applied | ||||||||||||||||||||
| def grpo_trainer_fix_maybe_apply_chat_template(function_name, function): | ||||||||||||||||||||
| spaces = function.find("def ") | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new guard returns the original string whenever
messagesis astr, regardless of any images passed via*args/**kwargs. In TRL,prepare_multimodal_messagesis responsible for inserting image placeholders into the prompt; bypassing it means any vision inputs provided alongside a pre-applied chat template are silently ignored, so GRPO will train on text-only prompts. This regression shows up whenever a notebook pre-appliesapply_chat_template()but still supplies images to the vision model.Useful? React with 👍 / 👎.