|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +"""HumanEval benchmark. |
| 3 | +
|
| 4 | +Tests code generation ability using function completion problems. |
| 5 | +Model receives a function signature + docstring and must complete the body. |
| 6 | +Verification: generated code + unit tests run in sandboxed subprocess. |
| 7 | +Dataset bundled from openai/openai_humaneval on HuggingFace (164 problems). |
| 8 | +
|
| 9 | +SECURITY NOTE: This benchmark executes model-generated code on the local |
| 10 | +machine. Mitigations: subprocess with timeout, memory limits, temp file cleanup. |
| 11 | +""" |
| 12 | + |
import asyncio
import json
import logging
import os
import re
import resource
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, Callable, Optional

from .base import BaseBenchmark, BenchmarkResult, QuestionResult
from .datasets import deterministic_sample, load_jsonl
| 27 | + |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | +DATA_DIR = Path(__file__).parent / "data" |
| 31 | + |
| 32 | +EXEC_TIMEOUT_SECONDS = 15 |
| 33 | +EXEC_MEMORY_LIMIT_BYTES = 256 * 1024 * 1024 # 256 MB |
| 34 | + |
| 35 | + |
| 36 | +def _extract_code(response: str, prompt: str) -> str: |
| 37 | + """Extract the function body from model response. |
| 38 | +
|
| 39 | + The model may return the full function (including signature) or just the body. |
| 40 | + We need to combine it with the original prompt to form a complete function. |
| 41 | + """ |
| 42 | + response = response.strip() |
| 43 | + |
| 44 | + # If response contains a code block, extract it |
| 45 | + match = re.search(r"```python\s*\n(.*?)```", response, re.DOTALL) |
| 46 | + if match: |
| 47 | + code = match.group(1).strip() |
| 48 | + # If the code block contains the function def, use it standalone |
| 49 | + if "def " in code: |
| 50 | + return code |
| 51 | + # Otherwise it's just the body, combine with prompt |
| 52 | + return prompt + code |
| 53 | + |
| 54 | + match = re.search(r"```\s*\n(.*?)```", response, re.DOTALL) |
| 55 | + if match: |
| 56 | + code = match.group(1).strip() |
| 57 | + if "def " in code: |
| 58 | + return code |
| 59 | + return prompt + code |
| 60 | + |
| 61 | + # No code block — response is the continuation of the prompt |
| 62 | + # Check if response starts with the function def (model repeated the signature) |
| 63 | + if response.startswith("def ") or response.startswith("from ") or response.startswith("import "): |
| 64 | + return response |
| 65 | + |
| 66 | + # Response is just the function body — combine with prompt |
| 67 | + return prompt + response |
| 68 | + |
| 69 | + |
def _set_resource_limits():
    """Best-effort resource limits for the sandbox subprocess.

    Intended as a ``subprocess`` ``preexec_fn``: runs in the child between
    fork and exec. Each limit is applied independently and failures are
    ignored (some platforms reject RLIMIT_AS, and hard limits may already
    be lower than what we request).
    """
    try:
        # Cap the address space so runaway allocations are killed by the OS.
        resource.setrlimit(resource.RLIMIT_AS, (EXEC_MEMORY_LIMIT_BYTES, EXEC_MEMORY_LIMIT_BYTES))
    except (ValueError, OSError):  # resource.error is a deprecated alias of OSError
        pass
    try:
        # CPU ceiling slightly above the wall-clock timeout as a backstop.
        resource.setrlimit(resource.RLIMIT_CPU, (EXEC_TIMEOUT_SECONDS + 5, EXEC_TIMEOUT_SECONDS + 5))
    except (ValueError, OSError):
        pass
| 80 | + |
| 81 | + |
def _execute_with_tests(code: str, test_code: str, entry_point: str) -> tuple[bool, str]:
    """Execute generated code with its unit tests in a sandboxed subprocess.

    Combines the generated function, the dataset's test harness (which
    defines ``check(candidate)``), and a final ``check(entry_point)`` call
    into a temp script, then runs it under CPU/memory limits with a
    wall-clock timeout.

    Args:
        code: Complete function definition produced by the model.
        test_code: Dataset test source defining ``check(candidate)``.
        entry_point: Name of the function under test.

    Returns:
        ``(passed, error_message)`` -- ``error_message`` is ``""`` on
        success and truncated to 500 characters otherwise.
    """
    # Build the complete test script.
    script = f"""{code}

{test_code}

check({entry_point})
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(script)
        tmp_path = f.name

    try:
        result = subprocess.run(
            # sys.executable guarantees the same interpreter running this
            # harness; a bare "python3" may resolve to a different install.
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=EXEC_TIMEOUT_SECONDS,
            # Applied in the child between fork and exec.
            preexec_fn=_set_resource_limits,
            # Closed stdin: generated code that reads input fails fast
            # instead of hanging until the timeout.
            stdin=subprocess.DEVNULL,
            # Minimal environment so generated code can't read our secrets.
            env={
                "PATH": os.environ.get("PATH", "/usr/bin:/usr/local/bin"),
                "HOME": os.environ.get("HOME", "/tmp"),
                "LANG": "en_US.UTF-8",
            },
        )
        if result.returncode == 0:
            return True, ""
        return False, result.stderr[:500]
    except subprocess.TimeoutExpired:
        return False, "Execution timed out"
    except Exception as e:  # e.g. OSError from a failed spawn
        return False, str(e)[:500]
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
| 127 | + |
| 128 | + |
class HumanEvalBenchmark(BaseBenchmark):
    """HumanEval: function completion with unit test verification.

    Each dataset item supplies a function signature + docstring (``prompt``),
    a test harness defining ``check(candidate)`` (``test``), and the function
    name (``entry_point``). Scoring executes the model's completion against
    the unit tests in a sandboxed subprocess (pass@1).
    """

    name = "humaneval"
    quick_size = 100  # subset size for quick runs (full set is 164 problems)

    async def load_dataset(self, sample_size: int = 0) -> list[dict]:
        """Load HumanEval problems from the bundled JSONL file.

        Args:
            sample_size: 0 loads every problem; otherwise a deterministic
                subsample of that size is returned.
        """
        items = load_jsonl(DATA_DIR / "humaneval.jsonl")

        normalized = [
            {
                "id": item["task_id"],
                "prompt": item["prompt"],
                "test": item["test"],
                "entry_point": item["entry_point"],
                "question": item["prompt"],  # for get_question_text
            }
            for item in items
        ]

        logger.info("HumanEval: loaded %d problems", len(normalized))

        if sample_size == 0:
            return normalized
        return deterministic_sample(normalized, sample_size)

    def get_max_tokens(self) -> int:
        """Generation budget: 512 tokens is ample for a single function body."""
        return 512

    def format_prompt(self, item: dict) -> list[dict[str, str]]:
        """Format the item as a single-turn function-completion prompt."""
        content = (
            "Complete the following Python function. "
            "Provide only the complete function implementation, no explanations.\n\n"
            f"{item['prompt']}"
        )
        return [{"role": "user", "content": content}]

    def extract_answer(self, response: str, item: dict) -> str:
        """Extract the complete function from the model response."""
        return _extract_code(response, item["prompt"])

    def check_answer(self, predicted: str, item: dict) -> bool:
        """Execute the generated code against the item's unit tests."""
        if not predicted.strip():
            return False

        # Error detail is intentionally discarded: pass/fail is the score.
        passed, _ = _execute_with_tests(
            predicted, item["test"], item["entry_point"]
        )
        return passed

    async def run(
        self,
        engine: Any,
        items: list[dict],
        on_progress: Optional[Callable[[int, int], Any]] = None,
        batch_size: int = 1,
        sampling_kwargs: Optional[dict] = None,
    ) -> BenchmarkResult:
        """Override run: generation is batched, code execution is sequential.

        Generation for each batch fans out through ``asyncio.gather``; the
        (blocking) subprocess verification then runs one item at a time so
        at most one sandbox executes concurrently.

        Args:
            engine: Inference engine passed through to ``_eval_single``.
            items: Normalized dataset items from ``load_dataset``.
            on_progress: Optional ``(completed, total)`` callback; both sync
                and async callables are accepted.
            batch_size: Number of generations to run concurrently.
            sampling_kwargs: Extra sampling parameters for generation.
        """
        results: list[QuestionResult] = []
        correct = 0
        # monotonic() is immune to wall-clock adjustments while timing.
        run_started = time.monotonic()
        completed = 0

        for batch_start in range(0, len(items), batch_size):
            batch = items[batch_start:batch_start + batch_size]
            batch_started = time.monotonic()

            gen_results = await asyncio.gather(*(
                self._eval_single(engine, item, batch_start + j, sampling_kwargs)
                for j, item in enumerate(batch)
            ))
            gen_elapsed = time.monotonic() - batch_started

            # Restore dataset order; gather may complete out of order.
            for idx, item, response_text, prompt_text in sorted(gen_results, key=lambda r: r[0]):
                code = self.extract_answer(response_text, item)
                is_correct = self.check_answer(code, item)

                if is_correct:
                    correct += 1

                results.append(
                    QuestionResult(
                        question_id=str(item.get("id", idx)),
                        correct=is_correct,
                        expected="(unit tests)",
                        predicted=code[:200] + "..." if len(code) > 200 else code,
                        # Generation time amortized across the batch.
                        time_seconds=gen_elapsed / len(batch),
                        question_text=prompt_text,
                        raw_response=response_text,
                    )
                )

            completed += len(batch)
            if on_progress:
                # Accept both sync and async progress callbacks.
                maybe_coro = on_progress(completed, len(items))
                if asyncio.iscoroutine(maybe_coro):
                    await maybe_coro

        total_time = time.monotonic() - run_started
        total = len(items)

        return BenchmarkResult(
            benchmark_name=self.name,
            accuracy=correct / total if total > 0 else 0.0,
            total_questions=total,
            correct_count=correct,
            time_seconds=total_time,
            question_results=results,
        )
0 commit comments