Skip to content

Commit 4f5b21e

Browse files
authored
Fix accuracy reward for math (#566)
* Fix accuracy reward for math
* Add typing
* Add unit test
* Return None for invalid samples
* Fix order of answers
* Fix type
* Use None for non-verifiable answers
1 parent 9915e06 commit 4f5b21e

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

src/open_r1/rewards.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import math
2121
import re
2222
from functools import partial, update_wrapper
23-
from typing import Callable, Dict
23+
from typing import Callable, Dict, Optional
2424

2525
from latex2sympy2_extended import NormalizationConfig
2626
from math_verify import LatexExtractionConfig, parse, verify
@@ -38,15 +38,14 @@
3838
AsyncSandbox = None
3939

4040

41-
def accuracy_reward(completions, solution, **kwargs):
41+
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
4242
"""Reward function that checks if the completion is the same as the ground truth."""
4343
contents = [completion[0]["content"] for completion in completions]
4444
rewards = []
4545
for content, sol in zip(contents, solution):
4646
gold_parsed = parse(
4747
sol,
4848
extraction_mode="first_match",
49-
extraction_config=[LatexExtractionConfig()],
5049
)
5150
if len(gold_parsed) != 0:
5251
# We require the answer to be provided in correct latex (no malformed operators)
@@ -69,15 +68,15 @@ def accuracy_reward(completions, solution, **kwargs):
6968
],
7069
extraction_mode="first_match",
7170
)
72-
# Reward 1 if the content is the same as the ground truth, 0 otherwise
71+
# Compute binary rewards if verifiable, `None` otherwise to skip this example
7372
try:
74-
reward = float(verify(answer_parsed, gold_parsed))
73+
reward = float(verify(gold_parsed, answer_parsed))
7574
except Exception as e:
7675
print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
77-
reward = 0.0
76+
reward = None
7877
else:
79-
# If the gold solution is not parseable, we reward 1 to skip this example
80-
reward = 1.0
78+
# If the gold solution is not parseable, we assign `None` to skip this example
79+
reward = None
8180
print("Failed to parse gold solution: ", sol)
8281
rewards.append(reward)
8382

tests/test_rewards.py

+6
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,13 @@ def test_accuracy_reward_wrong_answer(self):
8282
"""Test accuracy_reward with an incorrect answer."""
8383
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
8484
solution = [r"\frac{63}{400}"]
85+
rewards = accuracy_reward(completion, solution)
86+
self.assertEqual(rewards[0], 0.0)
8587

88+
def test_accuracy_reward_wrong_answer_no_latex(self):
89+
"""Test accuracy_reward with an incorrect answer and gold solution with no latex."""
90+
completion = [[{"content": r"\boxed{3}"}]]
91+
solution = ["6"]
8692
rewards = accuracy_reward(completion, solution)
8793
self.assertEqual(rewards[0], 0.0)
8894

0 commit comments

Comments (0)