Skip to content

Commit 56379ef

Browse files
fix: filter mbpp+ overlap instead of aborting; mbpp+ spans multiple hf splits
1 parent 7b4b610 commit 56379ef

1 file changed

Lines changed: 31 additions & 16 deletions

File tree

scripts/build_sft_dataset.py

Lines changed: 31 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -93,24 +93,39 @@ def _build_prompt(text: str, test_list: list[str]) -> str:
93 93

94 94

95 95
def load_mbpp_train() -> list[Task]:
96-
"""Return the MBPP train split as `Task` objects (skips ones we can't parse)."""
97-
ds = cast("Any", load_dataset("mbpp", split="train"))
96+
"""Return MBPP tasks disjoint from MBPP+ (our held-out eval set).
97+
98+
EvalPlus's MBPP+ draws task_ids from across the whole of MBPP, not just
99+
the "test" split — HuggingFace's `mbpp[train]` alone overlaps MBPP+ by
100+
~107 tasks. Strategy: load train + validation + prompt (skip HF's
101+
"test" since MBPP+ is sourced from it), then filter out anything whose
102+
task_id appears in MBPP+.
103+
"""
104+
plus_ids: set[int] = {
105+
int(t.task_id.split("/")[-1]) for t in load_mbpp_plus()
106+
}
107+
98 108
tasks: list[Task] = []
99-
for item in ds:
100-
item_d = cast("dict[str, Any]", item)
101-
test_list: list[str] = list(item_d.get("test_list") or [])
102-
entry_point = _infer_entry_point(test_list)
103-
if not entry_point or not test_list:
104-
continue
105-
tasks.append(
106-
Task(
107-
task_id=f"Mbpp/{item_d['task_id']}",
108-
prompt=_build_prompt(str(item_d["text"]), test_list),
109-
canonical_solution=str(item_d["code"]),
110-
test="\n".join(test_list),
111-
entry_point=entry_point,
109+
for split in ("train", "validation", "prompt"):
110+
ds = cast("Any", load_dataset("mbpp", split=split))
111+
for item in ds:
112+
item_d = cast("dict[str, Any]", item)
113+
task_id_int = int(item_d["task_id"])
114+
if task_id_int in plus_ids:
115+
continue # would leak into eval — skip
116+
test_list: list[str] = list(item_d.get("test_list") or [])
117+
entry_point = _infer_entry_point(test_list)
118+
if not entry_point or not test_list:
119+
continue
120+
tasks.append(
121+
Task(
122+
task_id=f"Mbpp/{task_id_int}",
123+
prompt=_build_prompt(str(item_d["text"]), test_list),
124+
canonical_solution=str(item_d["code"]),
125+
test="\n".join(test_list),
126+
entry_point=entry_point,
127+
)
112 128
)
113-
)
114 129
return tasks
115 130

116 131

0 commit comments

Comments
 (0)