Skip to content

Commit 8c7c1ae

Browse files
fix the Pair problem: extract helpers, prepend at sandbox runtime, expose in prompt
1 parent 1e26631 commit 8c7c1ae

7 files changed

Lines changed: 219 additions & 28 deletions

File tree

scripts/build_grpo_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def main() -> int:
9999
"task_id": t.task_id,
100100
"test": t.test,
101101
"entry_point": t.entry_point,
102+
"helpers": t.helpers,
102103
"benchmark": BENCHMARK,
103104
},
104105
}

src/verifiable_rl_coder/benchmarks/base.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ class Task:
3232
canonical_solution: str
3333
test: str
3434
entry_point: str
35+
# Optional helpers (imports, classes, helper functions) extracted from
36+
# canonical_solution. For MBPP tasks where tests reference custom types
37+
# like `Pair`, these definitions are prepended to the model's solution
38+
# at sandbox-execution time so tests can run regardless of whether the
39+
# model copied them into its output.
40+
helpers: str = ""
3541

3642

3743
def to_sandbox_inputs(
@@ -64,15 +70,22 @@ def to_sandbox_inputs(
6470
return solution, tests
6571

6672
if benchmark == "mbpp_plus":
67-
# MBPP+'s `assertion` field is a block of `assert ...` lines that call
68-
# the entry point by name. Indent them into a test_main body.
73+
# Prepend helpers (e.g. `class Pair`) so test assertions can reference
74+
# custom types/imports without depending on whether the model copied
75+
# them into its output.
76+
if task.helpers:
77+
solution = f"{task.helpers}\n\n{completion}"
78+
else:
79+
solution = completion
80+
# `from solution import *` so the helpers (which are at solution.py's
81+
# module level) are visible to the test assertions.
6982
lines = [line for line in task.test.splitlines() if line.strip()]
7083
indented = "\n ".join(lines)
7184
tests = (
72-
f"from solution import {task.entry_point}\n"
85+
"from solution import * # noqa: F401, F403\n"
7386
"def test_main() -> None:\n"
7487
f" {indented}\n"
7588
)
76-
return completion, tests
89+
return solution, tests
7790

7891
raise ValueError(f"unknown benchmark: {benchmark!r}")

src/verifiable_rl_coder/benchmarks/mbpp_plus.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,43 @@
1818
from .base import Task
1919

2020

21+
def _extract_mbpp_helpers(canonical_solution: str, entry_point: str) -> str:
22+
"""Same helper-extraction logic as mbpp_train (kept independent to
23+
avoid an import cycle between mbpp_plus and mbpp_train)."""
24+
import re
25+
pattern = re.compile(
26+
rf"^def\s+{re.escape(entry_point)}\s*\(",
27+
re.MULTILINE,
28+
)
29+
match = pattern.search(canonical_solution)
30+
if not match:
31+
return ""
32+
return canonical_solution[: match.start()].rstrip()
33+
34+
2135
def load_mbpp_plus() -> list[Task]:
2236
"""Return all MBPP+ tasks in deterministic task_id order.
2337
2438
MBPP+ stores tests under `assertion` (raw `assert ...` lines referencing
2539
the entry_point by name), not `test` (HumanEval+'s `def check(candidate)`
26-
wrapper). We load the assertion block verbatim into `Task.test`; the
27-
sandbox executor in Task 5 will wrap it appropriately per benchmark.
40+
wrapper). We load the assertion block verbatim into `Task.test`. We also
41+
extract any preamble (imports, classes) from `canonical_solution` into
42+
`Task.helpers` so the sandbox executor can prepend them at run time.
2843
"""
2944
raw = cast("dict[str, dict[str, Any]]", get_mbpp_plus())
30-
tasks = [
31-
Task(
32-
task_id=task_id,
33-
prompt=item["prompt"],
34-
canonical_solution=item["canonical_solution"],
35-
test=item["assertion"],
36-
entry_point=item["entry_point"],
45+
tasks: list[Task] = []
46+
for task_id, item in raw.items():
47+
canonical = item["canonical_solution"]
48+
entry_point = item["entry_point"]
49+
tasks.append(
50+
Task(
51+
task_id=task_id,
52+
prompt=item["prompt"],
53+
canonical_solution=canonical,
54+
test=item["assertion"],
55+
entry_point=entry_point,
56+
helpers=_extract_mbpp_helpers(canonical, entry_point),
57+
)
3758
)
38-
for task_id, item in raw.items()
39-
]
4059
tasks.sort(key=lambda t: int(t.task_id.split("/")[-1]))
4160
return tasks

src/verifiable_rl_coder/benchmarks/mbpp_train.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,48 @@ def infer_entry_point(test_list: list[str]) -> str | None:
3131
return None
3232

3333

34-
def build_mbpp_prompt(text: str, test_list: list[str]) -> str:
35-
"""Compose the user-turn prompt for an MBPP-style task."""
36-
example = test_list[0].strip() if test_list else ""
37-
return (
38-
f"Task: {text.strip()}\n"
39-
"\n"
40-
"Your function must satisfy this example:\n"
41-
f" {example}\n"
42-
"\n"
43-
"Write the complete Python function."
34+
def extract_helpers(canonical_solution: str, entry_point: str) -> str:
35+
"""Return everything before the entry_point's `def` line.
36+
37+
Captures imports, helper classes (e.g. `Pair`), and helper functions
38+
that the canonical solution defines alongside the target function.
39+
Returns "" if the entry_point's def isn't found, or the canonical
40+
solution starts with the function (no preamble).
41+
"""
42+
pattern = re.compile(
43+
rf"^def\s+{re.escape(entry_point)}\s*\(",
44+
re.MULTILINE,
4445
)
46+
match = pattern.search(canonical_solution)
47+
if not match:
48+
return ""
49+
return canonical_solution[: match.start()].rstrip()
50+
51+
52+
def build_mbpp_prompt(text: str, test_list: list[str], helpers: str = "") -> str:
53+
"""Compose the user-turn prompt for an MBPP-style task.
54+
55+
If `helpers` is non-empty (i.e. the task's canonical solution defines
56+
a class/import the tests rely on), include it in the prompt so the
57+
model knows what types/symbols it can use.
58+
"""
59+
example = test_list[0].strip() if test_list else ""
60+
parts: list[str] = [f"Task: {text.strip()}", ""]
61+
if helpers:
62+
parts.extend([
63+
"Supporting definitions (already available — do not redefine):",
64+
"```python",
65+
helpers,
66+
"```",
67+
"",
68+
])
69+
parts.extend([
70+
"Your function must satisfy this example:",
71+
f" {example}",
72+
"",
73+
"Write the complete Python function.",
74+
])
75+
return "\n".join(parts)
4576

4677

4778
def load_mbpp_train() -> list[Task]:
@@ -62,13 +93,16 @@ def load_mbpp_train() -> list[Task]:
6293
entry_point = infer_entry_point(test_list)
6394
if not entry_point or not test_list:
6495
continue
96+
canonical = str(item_d["code"])
97+
helpers = extract_helpers(canonical, entry_point)
6598
tasks.append(
6699
Task(
67100
task_id=f"Mbpp/{task_id_int}",
68-
prompt=build_mbpp_prompt(str(item_d["text"]), test_list),
69-
canonical_solution=str(item_d["code"]),
101+
prompt=build_mbpp_prompt(str(item_d["text"]), test_list, helpers),
102+
canonical_solution=canonical,
70103
test="\n".join(test_list),
71104
entry_point=entry_point,
105+
helpers=helpers,
72106
)
73107
)
74108
return tasks

src/verifiable_rl_coder/training/grpo_reward.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def compute_reward(
8282
canonical_solution="",
8383
test=test,
8484
entry_point=entry_point,
85+
# Helpers (e.g. `class Pair` for some MBPP tasks) get prepended to
86+
# the model's solution at sandbox-execution time.
87+
helpers=str(extra_info.get("helpers", "")),
8588
)
8689

8790
try:

tests/test_mbpp_train.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Tests for `extract_helpers` and helper-aware MBPP prompt construction."""
2+
3+
from __future__ import annotations
4+
5+
from verifiable_rl_coder.benchmarks.mbpp_train import (
6+
build_mbpp_prompt,
7+
extract_helpers,
8+
infer_entry_point,
9+
)
10+
11+
12+
# --- entry-point inference -------------------------------------------------
13+
14+
15+
def test_infer_entry_point_simple() -> None:
16+
assert infer_entry_point(["assert add(1, 2) == 3"]) == "add"
17+
18+
19+
def test_infer_entry_point_picks_first() -> None:
20+
assert infer_entry_point(["assert max_chain([Pair(1,2)]) == 1"]) == "max_chain"
21+
22+
23+
def test_infer_entry_point_returns_none_when_no_assert() -> None:
24+
assert infer_entry_point(["# no assertion here"]) is None
25+
26+
27+
# --- helper extraction -----------------------------------------------------
28+
29+
30+
def test_extract_helpers_returns_class_definition() -> None:
31+
canonical = (
32+
"class Pair:\n"
33+
" def __init__(self, a, b):\n"
34+
" self.a = a\n"
35+
" self.b = b\n"
36+
"\n"
37+
"def max_chain_length(arr, n):\n"
38+
" return n\n"
39+
)
40+
helpers = extract_helpers(canonical, "max_chain_length")
41+
assert "class Pair" in helpers
42+
assert "def max_chain_length" not in helpers
43+
assert helpers.endswith("self.b = b")
44+
45+
46+
def test_extract_helpers_returns_imports() -> None:
47+
canonical = (
48+
"from heapq import nlargest\n"
49+
"import math\n"
50+
"\n"
51+
"def largest_n(arr, n):\n"
52+
" return nlargest(n, arr)\n"
53+
)
54+
helpers = extract_helpers(canonical, "largest_n")
55+
assert "from heapq import nlargest" in helpers
56+
assert "import math" in helpers
57+
assert "def largest_n" not in helpers
58+
59+
60+
def test_extract_helpers_returns_empty_when_no_preamble() -> None:
61+
canonical = "def add(a, b):\n return a + b\n"
62+
assert extract_helpers(canonical, "add") == ""
63+
64+
65+
def test_extract_helpers_returns_empty_when_def_not_found() -> None:
66+
canonical = " return a + b\n" # body only, no def line — like HumanEval+
67+
assert extract_helpers(canonical, "add") == ""
68+
69+
70+
# --- prompt construction ---------------------------------------------------
71+
72+
73+
def test_prompt_without_helpers_omits_supporting_section() -> None:
74+
prompt = build_mbpp_prompt("Add two numbers.", ["assert add(1, 2) == 3"])
75+
assert "Supporting definitions" not in prompt
76+
assert "Task: Add two numbers." in prompt
77+
assert "assert add(1, 2) == 3" in prompt
78+
79+
80+
def test_prompt_with_helpers_includes_supporting_section() -> None:
81+
helpers = "class Pair:\n pass"
82+
prompt = build_mbpp_prompt(
83+
"Find longest chain.",
84+
["assert max_chain([Pair(1, 2)]) == 1"],
85+
helpers=helpers,
86+
)
87+
assert "Supporting definitions" in prompt
88+
assert "class Pair" in prompt
89+
assert "do not redefine" in prompt

tests/test_sandbox_inputs.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@
88
from verifiable_rl_coder.benchmarks.base import Task, to_sandbox_inputs
99

1010

11-
def _task(test: str, entry_point: str = "add", prompt: str = "def add(a, b):\n") -> Task:
11+
def _task(
12+
test: str,
13+
entry_point: str = "add",
14+
prompt: str = "def add(a, b):\n",
15+
helpers: str = "",
16+
) -> Task:
1217
return Task(
1318
task_id="t/0",
1419
prompt=prompt,
1520
canonical_solution=" return a + b\n",
1621
test=test,
1722
entry_point=entry_point,
23+
helpers=helpers,
1824
)
1925

2026

@@ -62,3 +68,29 @@ def test_unknown_benchmark_raises() -> None:
6268
task = _task(test="")
6369
with pytest.raises(ValueError, match="unknown benchmark"):
6470
to_sandbox_inputs(task, "def add(a, b): ...", "swe_bench_lite")
71+
72+
73+
def test_mbpp_prepends_helpers_to_solution() -> None:
74+
helpers = "class Pair:\n def __init__(self, a, b):\n self.a = a\n self.b = b"
75+
task = _task(
76+
test="assert max_chain([Pair(1, 2)]) == 1\n",
77+
entry_point="max_chain",
78+
helpers=helpers,
79+
)
80+
completion = "def max_chain(arr):\n return len(arr)\n"
81+
solution, tests = to_sandbox_inputs(task, completion, "mbpp_plus")
82+
assert "class Pair" in solution
83+
assert solution.endswith(completion)
84+
# Tests must use wildcard import so Pair (defined in solution.py) is in scope.
85+
assert "from solution import *" in tests
86+
87+
88+
def test_mbpp_no_helpers_passes_completion_through() -> None:
89+
task = _task(
90+
test="assert add(1, 2) == 3\n",
91+
helpers="",
92+
)
93+
completion = "def add(a, b):\n return a + b\n"
94+
solution, _ = to_sandbox_inputs(task, completion, "mbpp_plus")
95+
assert solution == completion # unchanged when no helpers
96+

0 commit comments

Comments
 (0)