Skip to content

Add prefix-preserving training chat template for GPT-OSS#5109

Open
qgallouedec wants to merge 23 commits intomainfrom
support-gpt-oss
Open

Add prefix-preserving training chat template for GPT-OSS#5109
qgallouedec wants to merge 23 commits intomainfrom
support-gpt-oss

Conversation

@qgallouedec
Copy link
Member

@qgallouedec qgallouedec commented Feb 17, 2026

PR Description

This PR adds support for tool-calling training with GPT-OSS models (e.g., gpt-oss-20b) in the GRPO agent training pipeline, extending the existing Qwen3 support.

Problem

For context, the main challenge is to ensure that the chat template used for training is prefix-preserving, meaning that when new messages are appended to a conversation, the previous sequence of tokens remains unchanged. This is crucial for multi-turn training.

The original GPT-OSS chat template is not prefix-preserving for two reasons:

  • <|return|> vs <|end|>: The last assistant message ends with <|return|> while all other turns use <|end|>. When a conversation is extended with new turns, the previously-final assistant message switches from <|return|> to <|end|>, breaking the prefix.

    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    >>> messages = [
    ...     {"role": "user", "content": "What is 2+2?."},
    ...     {"role": "assistant", "content": "4"},
    ...     {"role": "user", "content": "And what about 3+3?"},
    ... ]
    >>> tokenizer.apply_chat_template(messages[:2], tokenize=False)
    '<|start|>system<|message|>You are ChatGPT[...].<|end|><|start|>user<|message|>What is 2+2?.<|end|><|start|>assistant<|channel|>final<|message|>4<|return|>'
    >>> tokenizer.apply_chat_template(messages, tokenize=False)
    '<|start|>system<|message|>You are ChatGPT[...].<|end|><|start|>user<|message|>What is 2+2?.<|end|><|start|>assistant<|channel|>final<|message|>4<|end|><|start|>user<|message|>And what about 3+3?<|end|>'
  • Conditional thinking blocks: (Same as Qwen3) The thinking field is only rendered for the final assistant turn (via loop.last in the template), so earlier assistant turns lose their thinking content when new messages are appended.

    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    >>> messages = [
    ...     {"role": "user", "content": "What is 2+2?."},
    ...     {"role": "assistant", "thinking": "🤔", "content": "4"},
    ...     {"role": "user", "content": "And what about 3+3?"},
    ... ]
    >>> tokenizer.apply_chat_template(messages[:2], tokenize=False)
    '<|start|>system<|message|>You are ChatGPT[...]<|start|>assistant<|channel|>analysis<|message|>🤔<|end|><|start|>assistant<|channel|>final<|message|>4<|return|>'
    >>> tokenizer.apply_chat_template(messages, tokenize=False)
    '<|start|>system<|message|>You are ChatGPT[...]<|start|>assistant<|channel|>final<|message|>4<|end|><|start|>user<|message|>And what about 3+3?<|end|>'

This PR introduces

  • gpt_oss_chat_template — A reference copy of the original GPT-OSS chat template stored in chat_template_utils.py for template matching.
  • gpt_oss_training_chat_template — A modified training-safe template with two key changes:
    • Replaces <|return|> with <|end|> on the final assistant message to ensure consistent turn delimiters across all turns.
    • Changes loop.last to true so thinking blocks are always rendered, not just on the final turn.
  • Updated get_training_chat_template() — Now recognizes GPT-OSS templates and returns the training variant, alongside the existing Qwen3 support.
  • Updated is_chat_template_prefix_preserving() — Test messages now include both reasoning_content and thinking keys, since GPT-OSS uses thinking while Qwen3 uses reasoning_content.
  • Extended tests — All TestGetTrainingChatTemplate tests are now parameterized over both GPT-OSS and Qwen3, with a helper _assert_equal that accounts for the expected <|return|><|end|> difference.

Design notes (happy to get thoughts on this!)

The consequence of this is that the <|return|> token is not seen during training. My intuition is that, since GRPO uses a comparative objective (relative ranking within groups), the model's ability to produce <|return|> at inference is not expected to degrade.


Note

Medium Risk
Changes chat-template rendering used for training (including tool-call and termination-token handling) which can affect prompt formatting and GRPO learning behavior. Scope is limited to template selection/validation and is covered by expanded unit tests.

Overview
Adds GPT-OSS support to get_training_chat_template() by introducing gpt_oss_chat_template matching and a new gpt_oss_training_chat_template that enforces prefix preservation (consistent end delimiters and always-rendered thinking/analysis, including for tool-call turns).

Strengthens is_chat_template_prefix_preserving() to detect non-prefix-preserving behavior across both reasoning fields (reasoning_content vs thinking) and tool-calling sequences, and expands unit tests to cover GPT-OSS with expected <|return|><|end|> differences. Updates GRPO docs to list GPT-OSS as a tested model.

Written by Cursor Bugbot for commit 0d0edf8. This will update automatically on new commits. Configure here.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec changed the title Support tool-calling training for GPT-OSS Add prefix-preserving training chat template for GPT-OSS Feb 17, 2026
Base automatically changed from more-test-get_training_chat_template to main February 18, 2026 12:45
Copy link
Member

@albertvillanova albertvillanova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix.

I think:

  • The prefix-preserving fix is good and likely necessary.
  • But I would say the Design note’s justification is a weak point: I think comparative objective doesn't guarantee preservation of an unseen special token; this could also impact stopping and parsing.

I would prefer others to comment on this.

@qgallouedec
Copy link
Member Author

But I would say the Design note’s justification is a weak point: I think comparative objective doesn't guarantee preservation of an unseen special token; this could also impact stopping and parsing.

100% agree. At this point it's an intuition. I'll update the comment to be more conservative.

@qgallouedec
Copy link
Member Author

@codex review

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7416e5fae4

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +1041 to +1044
{%- elif message.content and not future_final_message.found %}
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
{%- elif message.thinking and not future_final_message.found %}
{{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep tool-call analysis stable across future assistant turns

The training GPT-OSS template still drops analysis text from an assistant tool-call turn when a later assistant final turn exists, because analysis is emitted only when not future_final_message.found. In a valid tool flow (e.g., user → assistant tool call with content/thinking → tool → assistant final), formatting the shorter prefix includes analysis but formatting the extended conversation removes it, so earlier tokens change and the template is not prefix-preserving for tool-calling traces.

Useful? React with 👍 / 👎.

Copy link

@cursor cursor bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

messages2 = [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
{"role": "assistant", "reasoning_content": "Hmmm", "thinking": "Hmmm", "content": "It is blue."},
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added reasoning_content breaks Qwen3.5 prefix preservation check

High Severity

Adding reasoning_content and thinking with value "Hmmm" to messages2 causes the prefix check to fail for the Qwen3.5 training template. The Qwen3.5 generation prompt emits <think>\n\n</think>\n\n (empty thinking) by default, but the assistant turn in messages2 now renders <think>\nHmmm\n</think>\n\n (non-empty thinking). Since text2 no longer starts with text1, is_chat_template_prefix_preserving returns False for the Qwen3.5 training template, causing test_new_chat_template_is_prefix_preserving to fail for Qwen3.5.

Additional Locations (1)

Fix in Cursor Fix in Web

{%- if future_message.role == 'assistant' and "tool_calls" not in future_message %}
{%- set future_final_message.found = true %}
{%- endif %}
{%- endfor %}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dead future_final_message computation in training template

Low Severity

In gpt_oss_training_chat_template, the future_final_message namespace is still computed (iterating over remaining messages) inside the tool_calls branch, but it's never read — the not future_final_message.found guards were intentionally removed from the elif conditions on lines 1261 and 1263. This dead loop adds unnecessary iteration in the Jinja2 template and may confuse maintainers into thinking the variable is still used.

Additional Locations (1)

Fix in Cursor Fix in Web

Copy link
Member

@albertvillanova albertvillanova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI is red.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants