Skip to content

Commit 9a0c247

Browse files
committed
[tts] Restore math scorer canonicalization on main
1 parent 89aac6f commit 9a0c247

1 file changed

Lines changed: 22 additions & 3 deletions

File tree

lib/marin/src/marin/test_time_scaling/scorers.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from __future__ import annotations
55

66
from dataclasses import dataclass
7+
import re
78

8-
from marin.rl.math_utils import grade_answer, last_boxed_only_string, normalize_answer
9+
from marin.rl.environments.tinker_environments.math_grading import extract_boxed, grade_answer, normalize_answer
10+
from marin.rl.math_utils import last_boxed_only_string
911
from marin.test_time_scaling.config import ScoringMode
1012

1113

@@ -18,6 +20,22 @@ class CandidateScore:
1820
is_correct: bool | None
1921

2022

23+
_SIMPLE_FRAC_PATTERN = re.compile(r"^\\(?:dfrac|tfrac|frac)\{([^{}]+)\}\{([^{}]+)\}$")
24+
25+
26+
def _normalize_extracted_answer(answer: str) -> str:
27+
normalized = normalize_answer(answer)
28+
if normalized is None:
29+
return answer
30+
31+
match = _SIMPLE_FRAC_PATTERN.fullmatch(normalized)
32+
if match is None:
33+
return normalized
34+
35+
numerator, denominator = match.groups()
36+
return f"{numerator}/{denominator}"
37+
38+
2139
def score_candidate_text(text: str, expected_answer: str | None, scoring_mode: ScoringMode) -> CandidateScore:
2240
"""Score a generated candidate against the prompt's configured scoring mode."""
2341

@@ -34,6 +52,7 @@ def score_candidate_text(text: str, expected_answer: str | None, scoring_mode: S
3452
if boxed is None:
3553
return CandidateScore(extracted_answer=None, parse_valid=False, is_correct=False if expected_answer else None)
3654

37-
extracted_answer = normalize_answer(boxed)
38-
is_correct = grade_answer(boxed, expected_answer) if expected_answer is not None else None
55+
extracted_answer_raw = extract_boxed(boxed)
56+
extracted_answer = _normalize_extracted_answer(extracted_answer_raw)
57+
is_correct = grade_answer(extracted_answer_raw, expected_answer) if expected_answer is not None else None
3958
return CandidateScore(extracted_answer=extracted_answer, parse_valid=True, is_correct=is_correct)

0 commit comments

Comments
 (0)