[GRPO] generate with prompt containing the first <think> tag #283

Open · wants to merge 11 commits into base: main
2 changes: 2 additions & 0 deletions src/open_r1/grpo.py
@@ -163,10 +163,12 @@ def main(script_args, training_args, model_args):

     # Format into conversation
     def make_conversation(example):
+        # start the assistant with a <think> tag
         return {
             "prompt": [
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": example["problem"]},
+                {"role": "assistant", "content": "Let me solve this step by step.\n<think>"},
             ],
         }

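For context: ending the prompt with a partial assistant turn only has the intended effect if the chat template is rendered so that the model continues that turn instead of opening a new one. Below is a minimal sketch of what the rendered prompt looks like, assuming transformers' continue_final_message flag and an instruct model such as Qwen/Qwen2.5-1.5B-Instruct (both are assumptions for illustration, not part of this PR):

# Illustrative sketch only: render a prompt whose final (partial) assistant
# message ends with "<think>", so generation continues inside the think block.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

prompt = [
    {"role": "system", "content": "You are a helpful assistant."},  # stand-in for SYSTEM_PROMPT
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "Let me solve this step by step.\n<think>"},
]

# continue_final_message=True keeps the last assistant message open rather than
# appending an end-of-turn marker, so sampling picks up right after "<think>".
text = tokenizer.apply_chat_template(prompt, tokenize=False, continue_final_message=True)
print(text)
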
2 changes: 1 addition & 1 deletion src/open_r1/rewards.py
@@ -51,7 +51,7 @@ def accuracy_reward(completions, solution, **kwargs):

 def format_reward(completions, **kwargs):
     """Reward function that checks if the completion has a specific format."""
-    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
+    pattern = r"^.+(?:<think>.*?</think>\s*)?<answer>.*?</answer>$"
     completion_contents = [completion[0]["content"] for completion in completions]
     matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
     return [1.0 if match else 0.0 for match in matches]
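
A quick standalone check of what the relaxed pattern accepts, using the same completion styles exercised in tests/test_rewards.py (old_pattern and new_pattern are local names for illustration only):

# Compare the previous and the proposed format patterns on sample completions.
import re

old_pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
new_pattern = r"^.+(?:<think>.*?</think>\s*)?<answer>.*?</answer>$"

samples = [
    "<think>Some reasoning</think><answer>The answer</answer>",
    "Some reasoning</think><answer>The answer</answer>",  # no opening <think>, as produced with the new prompt
    "<think><think>Some reasoning</think><answer>The answer</answer>",
]

for sample in samples:
    old_ok = bool(re.match(old_pattern, sample, re.DOTALL | re.MULTILINE))
    new_ok = bool(re.match(new_pattern, sample, re.DOTALL | re.MULTILINE))
    print(f"old={old_ok} new={new_ok}: {sample!r}")
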
20 changes: 16 additions & 4 deletions tests/test_rewards.py
@@ -28,9 +28,15 @@ def test_accuracy_reward_wrong_answer(self):

     def test_format_reward_correct(self):
         """Test format_reward with correct format."""
-        completion = [[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}]]
-        rewards = format_reward(completion)
-        self.assertEqual(rewards[0], 1.0)
+        formats = [
+            "<think>Some reasoning</think><answer>The answer</answer>",
+            "Some reasoning</think><answer>The answer</answer>",
+            "<think><think>Some reasoning</think><answer>The answer</answer>",
+        ]
+        for fmt in formats:
+            completion = [[{"content": fmt}]]
+            rewards = format_reward(completion)
+            self.assertEqual(rewards[0], 1.0, msg=f"Expected format reward of 1.0 for {fmt}")

     def test_format_reward_incorrect(self):
         """Test format_reward with incorrect format."""
@@ -45,7 +45,7 @@ def test_format_reward_incorrect(self):
         for fmt in incorrect_formats:
             completion = [[{"content": fmt}]]
             rewards = format_reward(completion)
-            self.assertEqual(rewards[0], 0.0)
+            self.assertEqual(rewards[0], 0.0, msg=f"Expected format reward of 0.0 for {fmt}")

     def test_reasoning_steps_reward(self):
         """Test reasoning_steps_reward with various formats."""
@@ -118,6 +124,12 @@ def test_positive_max_penalty_raises_value_error(self):
         with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be positive"):
             get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5)

+    def test_zero_max_penalty_returns_zero(self):
+        reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=0.0)
+        completions = [[{"content": "this is a test sentence"}]]
+        rewards = reward_fn(completions)
+        self.assertEqual(rewards, [0.0])
+
     def test_no_repetition(self):
         reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
         completions = [[{"content": "this is a test sentence"}]]
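
Taken together, these tests pin down the expected behaviour of the repetition-penalty reward: a positive max_penalty raises a ValueError, a max_penalty of 0 yields a reward of 0.0, and a completion with no repeated n-grams is not penalised. A minimal sketch consistent with those expectations follows; this is an illustration only, not the implementation in src/open_r1/rewards.py:

# Sketch of a repetition-penalty reward matching the behaviour the tests expect.
# The real get_repetition_penalty_reward may differ in its scaling details.
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")

    def repetition_penalty_reward(completions, **kwargs):
        rewards = []
        for completion in completions:
            words = completion[0]["content"].split()
            if max_penalty == 0 or len(words) < ngram_size:
                rewards.append(0.0)
                continue
            ngrams = [tuple(words[i : i + ngram_size]) for i in range(len(words) - ngram_size + 1)]
            # Fraction of duplicated n-grams scales the (non-positive) penalty.
            duplication = 1.0 - len(set(ngrams)) / len(ngrams)
            rewards.append(duplication * max_penalty)
        return rewards

    return repetition_penalty_reward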