Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,10 +1559,48 @@ def patch_trl_openenv():
return


def _patch_prepare_multimodal_messages():
"""Fix 2: TRL >= 0.25.1 calls prepare_multimodal_messages unconditionally for
vision models. When notebooks pre-apply chat templates (converting prompts to strings),
the function crashes iterating over characters. This patch adds isinstance(messages, str)
guard to return strings unchanged."""
try:
import trl.data_utils as _du
except ImportError:
return

_original = getattr(_du, "prepare_multimodal_messages", None)
if _original is None:
return
if getattr(_original, "_unsloth_patched", False):
return

def _safe_prepare_multimodal_messages(messages, *args, **kwargs):
# If messages is already a string (pre-applied chat template), return as-is
if isinstance(messages, str):
return messages
Comment on lines +1578 to +1581

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve image placeholders when messages are strings

The new guard returns the original string whenever messages is a str, regardless of any images passed via *args/**kwargs. In TRL, prepare_multimodal_messages is responsible for inserting image placeholders into the prompt; bypassing it means any vision inputs provided alongside a pre-applied chat template are silently ignored, so GRPO will train on text-only prompts. This regression shows up whenever a notebook pre-applies apply_chat_template() but still supplies images to the vision model.

Useful? React with 👍 / 👎.

return _original(messages, *args, **kwargs)

_safe_prepare_multimodal_messages._unsloth_patched = True
_du.prepare_multimodal_messages = _safe_prepare_multimodal_messages

# Also patch in grpo_trainer module if imported
try:
import trl.trainer.grpo_trainer as _gt

if hasattr(_gt, "prepare_multimodal_messages"):
_gt.prepare_multimodal_messages = _safe_prepare_multimodal_messages
except ImportError:
pass

logger.info("Unsloth: Patched prepare_multimodal_messages with string guard")


def PatchFastRL(algorithm = None, FastLanguageModel = None):
    """Apply every Unsloth RL patch in one call.

    Patches the model class (when given), the TRL trainers, the OpenEnv
    integration, and the multimodal-message guard, then optionally hooks up
    statistics logging for the named algorithm.
    """
    if FastLanguageModel is not None:
        PatchRL(FastLanguageModel)
    # Global TRL-level patches — order mirrors the original call sequence.
    for _apply_patch in (
        patch_trl_rl_trainers,
        patch_trl_openenv,
        _patch_prepare_multimodal_messages,
    ):
        _apply_patch()
    # Statistics patching expects a lowercase algorithm name (e.g. "grpo").
    if type(algorithm) is str and algorithm.islower():
        PatchRLStatistics(algorithm)
29 changes: 29 additions & 0 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,35 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)


# Fix 6: TRL 0.25.0+ _calculate_rewards text arguments
#
# TRL 0.25.0+ passes `prompts` and `completions` to _calculate_rewards in different formats:
# - For conversational inputs: list of dicts [{"role": "assistant", "content": "..."}]
# - For non-conversational inputs: plain text strings
#
# This inconsistency causes reward functions to fail when they expect one format but get the other.
# The variables `prompts_text` and `completions_text` always contain plain decoded text strings.
#
# Fix: Always pass plain text (prompts_text, completions_text) to _calculate_rewards for consistency.
# This ensures reward functions receive predictable string format regardless of conversational mode.
# NOTE(review): reward functions that deliberately inspect conversational message
# structure (roles, tool-call fields) will now receive flattened strings — confirm
# this trade-off is acceptable for existing conversational reward functions.
def grpo_trainer__calculate_rewards_text_fix(function_name, function):
    """Rewrite `_generate_and_score_completions` source so that
    `_calculate_rewards` receives the decoded text variables.

    Args:
        function_name: Name of the TRL trainer function being patched.
        function: The function's source code as a string.

    Returns:
        The source string, rewritten when it targets TRL 0.25.0+'s
        `_generate_and_score_completions`; otherwise unchanged.
    """
    import re

    if function_name != "_generate_and_score_completions":
        return function

    # Only apply if prompts_text and completions_text exist (TRL 0.25.0+)
    if "prompts_text" in function and "completions_text" in function:
        # Use a whitespace-tolerant regex (not str.replace) so the patch
        # survives minor formatting changes in upstream TRL, consistent with
        # the other regex-based patches in this file.
        function = re.sub(
            r"self\._calculate_rewards\(\s*inputs,\s*prompts,\s*completions,\s*completion_ids_list\s*\)",
            "self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)",
            function,
        )
    return function


# Register Fix 6 so it runs with the other grpo_trainer source patches.
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__calculate_rewards_text_fix)


# Fix {"reasoning_effort" : "high"} not applied
def grpo_trainer_fix_maybe_apply_chat_template(function_name, function):
spaces = function.find("def ")
Expand Down