@@ -93,24 +93,39 @@ def _build_prompt(text: str, test_list: list[str]) -> str:
9393
9494
def load_mbpp_train() -> list[Task]:
    """Return MBPP tasks disjoint from MBPP+ (our held-out eval set).

    EvalPlus's MBPP+ draws task_ids from across the whole of MBPP, not just
    the "test" split — HuggingFace's `mbpp[train]` alone overlaps MBPP+ by
    ~107 tasks. Strategy: load train + validation + prompt (skip HF's
    "test" since MBPP+ is sourced from it), then filter out anything whose
    task_id appears in MBPP+.

    Items are also skipped when they have no `test_list` or when no entry
    point can be inferred from their tests.

    Returns:
        The surviving items as `Task` objects with ids of the form
        ``Mbpp/<task_id>``, in split order (train, validation, prompt).
    """
    # MBPP+ task ids look like "<prefix>/<number>"; keep only the numeric
    # part so it can be compared against the raw integer ids in MBPP.
    plus_ids: set[int] = {
        int(t.task_id.split("/")[-1]) for t in load_mbpp_plus()
    }

    tasks: list[Task] = []
    for split in ("train", "validation", "prompt"):
        ds = cast("Any", load_dataset("mbpp", split=split))
        for item in ds:
            item_d = cast("dict[str, Any]", item)
            task_id_int = int(item_d["task_id"])
            if task_id_int in plus_ids:
                continue  # would leak into eval — skip
            test_list: list[str] = list(item_d.get("test_list") or [])
            entry_point = _infer_entry_point(test_list)
            if not entry_point or not test_list:
                continue  # nothing to run against, or unparseable tests
            tasks.append(
                Task(
                    task_id=f"Mbpp/{task_id_int}",
                    prompt=_build_prompt(str(item_d["text"]), test_list),
                    canonical_solution=str(item_d["code"]),
                    test="\n".join(test_list),
                    entry_point=entry_point,
                )
            )
    return tasks
115130
116131
0 commit comments