# Reconstructed from a git diff of unsloth/models/rl.py
# (hunk @@ -1559,10 +1559,48 @@, inserted after patch_trl_openenv() returns).


def _patch_prepare_multimodal_messages():
    """Fix 2: TRL >= 0.25.1 calls prepare_multimodal_messages unconditionally for
    vision models. When notebooks pre-apply chat templates (converting prompts to strings),
    the function crashes iterating over characters. This patch adds isinstance(messages, str)
    guard to return strings unchanged."""
    import functools  # local import: keeps the patch self-contained

    # trl is an optional dependency — silently skip the patch when it is absent.
    try:
        import trl.data_utils as _du
    except ImportError:
        return

    _original = getattr(_du, "prepare_multimodal_messages", None)
    if _original is None:
        return
    # Idempotency guard: a previous call already installed the wrapper.
    if getattr(_original, "_unsloth_patched", False):
        return

    @functools.wraps(_original)  # preserve __name__/__doc__ for introspection
    def _safe_prepare_multimodal_messages(messages, *args, **kwargs):
        # If messages is already a string (pre-applied chat template), return as-is
        if isinstance(messages, str):
            return messages
        return _original(messages, *args, **kwargs)

    # Set AFTER wraps(): wraps() copies _original.__dict__ onto the wrapper,
    # which would otherwise clobber this flag.
    _safe_prepare_multimodal_messages._unsloth_patched = True
    _du.prepare_multimodal_messages = _safe_prepare_multimodal_messages

    # Also patch in grpo_trainer module if imported
    try:
        import trl.trainer.grpo_trainer as _gt

        if hasattr(_gt, "prepare_multimodal_messages"):
            _gt.prepare_multimodal_messages = _safe_prepare_multimodal_messages
    except ImportError:
        pass

    logger.info("Unsloth: Patched prepare_multimodal_messages with string guard")


def PatchFastRL(algorithm = None, FastLanguageModel = None):
    # Apply model-level RL patches first, then the trainer-level patches.
    if FastLanguageModel is not None: PatchRL(FastLanguageModel)
    patch_trl_rl_trainers()
    patch_trl_openenv()
    _patch_prepare_multimodal_messages()
    # isinstance() instead of `type(...) is str`: idiomatic and subclass-safe.
    if isinstance(algorithm, str) and algorithm.islower(): PatchRLStatistics(algorithm)

# --- diff continues: unsloth/models/rl_replacements.py, hunk @@ -577,6 +577,35 @@ ---
# diff context: unsloth/models/rl_replacements.py, inside hunk
# @@ -577,6 +577,35 @@ after grpo_trainer__generate_and_score_completions(function_name, function).

RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)


# Fix 6: TRL 0.25.0+ _calculate_rewards text arguments
#
# TRL 0.25.0+ passes `prompts` and `completions` to _calculate_rewards in different formats:
# - For conversational inputs: list of dicts [{"role": "assistant", "content": "..."}]
# - For non-conversational inputs: plain text strings
#
# This inconsistency causes reward functions to fail when they expect one format but get the other.
# The variables `prompts_text` and `completions_text` always contain plain decoded text strings.
#
# Fix: Always pass plain text (prompts_text, completions_text) to _calculate_rewards for consistency.
# This ensures reward functions receive predictable string format regardless of conversational mode.
def grpo_trainer__calculate_rewards_text_fix(function_name, function):
    """Rewrite the trainer source so _calculate_rewards always receives the
    plain decoded text variables instead of the mixed-format ones."""
    # Only the GRPO generate-and-score function is touched.
    if function_name != "_generate_and_score_completions":
        return function

    # The text variables only exist in TRL 0.25.0+; older sources pass through.
    markers = ("prompts_text", "completions_text")
    if all(marker in function for marker in markers):
        old_call = "self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
        new_call = "self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)"
        # No-op when TRL formats the call differently — deliberately best-effort.
        function = function.replace(old_call, new_call)

    return function


RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__calculate_rewards_text_fix)


# Fix {"reasoning_effort" : "high"} not applied
def grpo_trainer_fix_maybe_apply_chat_template(function_name, function):
    spaces = function.find("def ")