44from __future__ import annotations
55
66from dataclasses import dataclass
7+ import re
78
8- from marin .rl .math_utils import grade_answer , last_boxed_only_string , normalize_answer
9+ from marin .rl .environments .tinker_environments .math_grading import extract_boxed , grade_answer , normalize_answer
10+ from marin .rl .math_utils import last_boxed_only_string
911from marin .test_time_scaling .config import ScoringMode
1012
1113
@@ -18,6 +20,22 @@ class CandidateScore:
1820 is_correct : bool | None
1921
2022
23+ _SIMPLE_FRAC_PATTERN = re .compile (r"^\\(?:dfrac|tfrac|frac)\{([^{}]+)\}\{([^{}]+)\}$" )
24+
25+
26+ def _normalize_extracted_answer (answer : str ) -> str :
27+ normalized = normalize_answer (answer )
28+ if normalized is None :
29+ return answer
30+
31+ match = _SIMPLE_FRAC_PATTERN .fullmatch (normalized )
32+ if match is None :
33+ return normalized
34+
35+ numerator , denominator = match .groups ()
36+ return f"{ numerator } /{ denominator } "
37+
38+
2139def score_candidate_text (text : str , expected_answer : str | None , scoring_mode : ScoringMode ) -> CandidateScore :
2240 """Score a generated candidate against the prompt's configured scoring mode."""
2341
@@ -34,6 +52,7 @@ def score_candidate_text(text: str, expected_answer: str | None, scoring_mode: S
3452 if boxed is None :
3553 return CandidateScore (extracted_answer = None , parse_valid = False , is_correct = False if expected_answer else None )
3654
37- extracted_answer = normalize_answer (boxed )
38- is_correct = grade_answer (boxed , expected_answer ) if expected_answer is not None else None
55+ extracted_answer_raw = extract_boxed (boxed )
56+ extracted_answer = _normalize_extracted_answer (extracted_answer_raw )
57+ is_correct = grade_answer (extracted_answer_raw , expected_answer ) if expected_answer is not None else None
3958 return CandidateScore (extracted_answer = extracted_answer , parse_valid = True , is_correct = is_correct )
0 commit comments