20
20
import math
21
21
import re
22
22
from functools import partial , update_wrapper
23
- from typing import Callable , Dict
23
+ from typing import Callable , Dict , Optional
24
24
25
25
from latex2sympy2_extended import NormalizationConfig
26
26
from math_verify import LatexExtractionConfig , parse , verify
38
38
AsyncSandbox = None
39
39
40
40
41
- def accuracy_reward (completions , solution , ** kwargs ):
41
+ def accuracy_reward (completions : list [ list [ dict [ str , str ]]], solution : list [ str ] , ** kwargs ) -> list [ Optional [ float ]] :
42
42
"""Reward function that checks if the completion is the same as the ground truth."""
43
43
contents = [completion [0 ]["content" ] for completion in completions ]
44
44
rewards = []
45
45
for content , sol in zip (contents , solution ):
46
46
gold_parsed = parse (
47
47
sol ,
48
48
extraction_mode = "first_match" ,
49
- extraction_config = [LatexExtractionConfig ()],
50
49
)
51
50
if len (gold_parsed ) != 0 :
52
51
# We require the answer to be provided in correct latex (no malformed operators)
@@ -69,15 +68,15 @@ def accuracy_reward(completions, solution, **kwargs):
69
68
],
70
69
extraction_mode = "first_match" ,
71
70
)
72
- # Reward 1 if the content is the same as the ground truth, 0 otherwise
71
+ # Compute binary rewards if verifiable, `None` otherwise to skip this example
73
72
try :
74
- reward = float (verify (answer_parsed , gold_parsed ))
73
+ reward = float (verify (gold_parsed , answer_parsed ))
75
74
except Exception as e :
76
75
print (f"verify failed: { e } , answer: { answer_parsed } , gold: { gold_parsed } " )
77
- reward = 0.0
76
+ reward = None
78
77
else :
79
- # If the gold solution is not parseable, we reward 1 to skip this example
80
- reward = 1.0
78
+ # If the gold solution is not parseable, we assign `None` to skip this example
79
+ reward = None
81
80
print ("Failed to parse gold solution: " , sol )
82
81
rewards .append (reward )
83
82
0 commit comments