
Commit 9915e06

edbeeching and lewtun authored
Async code reward fixes (#546)
* expose num parallel code executions
* add e2b benchmarking script
* adds new parallel code execution with better exception handling
* style
* update default
* increase sandbox timeout
* Add pretty table and Sandbox IDs
* Add Sandbox ID
* fix merge

Co-authored-by: Lewis Tunstall <[email protected]>
1 parent 1802bec commit 9915e06

File tree

4 files changed (+152, -25 lines)


scripts/benchmark_e2b.py (new file, +84)
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmark script for the code_reward function with E2B.

This script measures the performance of the code_reward function with varying numbers
of samples and parallelization levels.

Each sample is a CodeForces problem with a gold standard solution that is executed against a set of public test cases.
"""

from datasets import load_dataset
from open_r1.rewards import code_reward
import time
from tqdm.auto import tqdm

from dotenv import load_dotenv
load_dotenv()


def benchmark_code_reward(example):
    start_time = time.time()
    test_completions = [[{"content": example["gold_standard_solution"]}]]
    reward_kwargs = {"verification_info": [example["verification_info"]]}
    rewards = code_reward(test_completions, **reward_kwargs)
    end_time = time.time()
    example["test_reward"] = rewards[0]
    example["reward_time"] = end_time - start_time
    return example


if __name__ == "__main__":
    parallel_dict = {
        16: [1, 4, 16],
        64: [4, 16, 64],
        256: [16, 64, 96],  # cap at 96 as PRO account is limited to 100
    }
    # Store results for table formatting
    results = []

    for num_samples in tqdm([16, 64, 256], desc="Benchmarking samples"):
        for num_parallel in parallel_dict[num_samples]:
            code_dataset = load_dataset("open-r1/verifiable-coding-problems-python_decontaminated")
            code_dataset = code_dataset["train"].shuffle(seed=42).select(range(num_samples))

            test_completions = [[{"content": example["gold_standard_solution"]}] for example in code_dataset]
            reward_kwargs = {"verification_info": [example["verification_info"] for example in code_dataset]}

            start_time = time.time()
            rewards = code_reward(test_completions, num_parallel=num_parallel, **reward_kwargs)
            execution_time = time.time() - start_time

            # Calculate some statistics about rewards
            mean_reward = sum(rewards) / len(rewards)
            min_reward = min(rewards)
            max_reward = max(rewards)

            # Store results
            results.append({
                "num_samples": num_samples,
                "num_parallel": num_parallel,
                "execution_time": execution_time,
                "mean_reward": mean_reward,
                "min_reward": min_reward,
                "max_reward": max_reward
            })

    print("\n## Benchmark Results\n")
    print("| Sample Size | Parallelization | Execution Time (s) | Mean Reward | Min Reward | Max Reward |")
    print("|:-----------:|:---------------:|------------------:|:-----------:|:-----------:|:-----------:|")

    for result in results:
        print(f"| {result['num_samples']:^11} | {result['num_parallel']:^15} | {result['execution_time']:17.2f} | {result['mean_reward']:^11.4f} | {result['min_reward']:^11.4f} | {result['max_reward']:^11.4f} |")
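
Assuming the `code` extra is installed, the benchmark can be run as a plain script (`python scripts/benchmark_e2b.py`); because it calls `load_dotenv()`, the E2B credentials (typically an `E2B_API_KEY` entry) can live in a local `.env` file. The script prints the timing and reward statistics for each sample-size/parallelism pair as a Markdown table.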

setup.py (+1, -1)
@@ -91,7 +91,7 @@ def deps_list(*pkgs):
 extras["quality"] = deps_list("ruff", "isort", "flake8")
 extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
 extras["eval"] = deps_list("lighteval", "math-verify")
-extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
+extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["code"]

 # core dependencies shared across the whole project - keep this to a bare minimum :)
 install_requires = [
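
Since `extras["code"]` is now folded into `extras["dev"]`, a development install (`pip install -e ".[dev]"`) also pulls in `e2b-code-interpreter` and `python-dotenv`, so the E2B-backed rewards and the benchmark script above work without a separate install step.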

src/open_r1/configs.py (+7)
@@ -154,6 +154,13 @@ class GRPOScriptArguments(trl.ScriptArguments):
             "help": "for each generation, evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases. Useful to avoid overloading the eval server + save time on wrong solutions"
         },
     )
+    parallel_code_exec_per_proc: int = field(
+        default=2,
+        metadata={
+            "help": "Number of parallel E2B code executions per process. Default of 2 is suitable for the Free Hobby tier of E2B with 8 GPUs used for training."
+        },
+    )
+
     dataset_prompt_column: str = field(
         default="prompt",
         metadata={"help": "Column to use as prompts for training."},

src/open_r1/rewards.py (+60, -24)
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 """Reward functions for GRPO training."""

 import asyncio
@@ -368,13 +383,13 @@ def extract_code(completion: str, language: str = "python") -> str:
     return extracted_answer


-def binary_code_reward(completions, **kwargs) -> list[float]:
-    rewards = code_reward(completions, **kwargs)
+def binary_code_reward(completions, num_parallel: int = 2, **kwargs) -> list[float]:
+    rewards = code_reward(completions, num_parallel=num_parallel, **kwargs)
     BINARY_THRESHOLD = 0.99
     return [1.0 if reward > BINARY_THRESHOLD else 0.0 for reward in rewards]


-def code_reward(completions, **kwargs) -> list[float]:
+def code_reward(completions, num_parallel: int = 2, **kwargs) -> list[float]:
     """Reward function that evaluates code snippets using the E2B code interpreter.

     Assumes the dataset contains a `verification_info` column with test cases.
@@ -438,7 +453,7 @@ def evaluate_code(code, test_cases):
     if not all(v["language"] == language for v in verification_info):
         raise ValueError("All verification_info must have the same language", verification_info)
     try:
-        rewards = run_async_from_sync(scripts, language)
+        rewards = run_async_from_sync(scripts, language, num_parallel)

     except Exception as e:
         print(f"Error from E2B executor: {e}")
@@ -463,45 +478,62 @@ def code_format_reward(completions, **kwargs):
     return code_format_reward


-def run_async_from_sync(scripts: list[str], language: str) -> list[float]:
+def run_async_from_sync(scripts: list[str], language: str, num_parallel: int) -> list[float]:
     """Function wrapping the `run_async` function."""
     # Create a new event loop and set it
     try:
         # Run the async function and get the result
-        rewards = asyncio.run(run_async(scripts, language))
+        rewards = asyncio.run(run_async(scripts, language, num_parallel))
     except Exception as e:
         print(f"Error from E2B executor async: {e}")
         raise e

     return rewards


-async def run_async(scripts: list[str], language: str) -> list[float]:
-    # Create the sandbox by hand, currently there's no context manager for this version
-    sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
+async def run_async(scripts: list[str], language: str, num_parallel: int) -> list[float]:
+    # Limit the number of concurrent tasks
+    semaphore = asyncio.Semaphore(num_parallel)

     # Create a list of tasks for running scripts concurrently
-    tasks = [run_script(sbx, script, language) for script in scripts]
+    tasks = [run_script(script, language, semaphore) for script in scripts]

     # Wait for all tasks to complete and gather their results as they finish
     results = await asyncio.gather(*tasks)
     rewards = list(results)  # collect results

-    # Kill the sandbox after all the tasks are complete
-    await sbx.kill()
-
     return rewards


-async def run_script(sbx: AsyncSandbox, script: str, language: str) -> float:
-    execution = await sbx.run_code(script, language=language)
-    try:
-        return float(execution.text)
-    except (TypeError, ValueError):
-        return 0.0
-    except Exception as e:
-        print(f"Error from E2B executor run_script: {e}")
-        return 0.0
+async def run_script(script: str, language: str, semaphore: asyncio.Semaphore) -> float:
+    # We set a timeout margin, as the AsyncSandbox timeout does not seem to work
+    # These values are based on running 256 examples with the gold solution
+    # from open-r1/verifiable-coding-problems-python_decontaminated
+    # see scripts/benchmark_e2b.py
+
+    SANDBOX_TIMEOUT = 30
+    MARGIN = 2
+    REQUEST_TIMEOUT = SANDBOX_TIMEOUT - MARGIN
+    ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN
+
+    async with semaphore:
+        try:
+            sandbox = await AsyncSandbox.create(timeout=SANDBOX_TIMEOUT, request_timeout=REQUEST_TIMEOUT)
+            execution = await asyncio.wait_for(sandbox.run_code(script, language=language), timeout=ASYNCIO_TIMEOUT)
+            return float(execution.text)
+        except (TypeError, ValueError):
+            return 0.0
+        except asyncio.TimeoutError:
+            print("Operation timed out")
+            return 0.0
+        except Exception as e:
+            print(f"Error in `run_script` from E2B sandbox ID {sandbox.sandbox_id} : {e}")
+            return 0.0
+        finally:
+            try:
+                await sandbox.kill()
+            except Exception as e:
+                print(f"Error from E2B executor kill with sandbox ID {sandbox.sandbox_id} : {e}")


 def get_reward_funcs(script_args) -> list[Callable]:
@@ -521,8 +553,12 @@ def get_reward_funcs(script_args) -> list[Callable]:
             max_penalty=script_args.repetition_max_penalty,
         ),
         "length": len_reward,
-        "code": code_reward,
-        "binary_code": binary_code_reward,
+        "code": update_wrapper(
+            partial(code_reward, num_parallel=script_args.parallel_code_exec_per_proc), code_reward
+        ),
+        "binary_code": update_wrapper(
+            partial(binary_code_reward, num_parallel=script_args.parallel_code_exec_per_proc), binary_code_reward
+        ),
         "ioi_code": update_wrapper(
             partial(ioi_code_reward, test_batch_size=script_args.code_eval_test_batch_size), ioi_code_reward
         ),
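
The registry now binds `num_parallel` with `functools.partial`, and `update_wrapper` copies the original function's metadata onto the partial. The sketch below is not part of this diff and uses a toy stand-in for the reward; it shows why that matters: a bare partial has no `__name__`, which would break anything downstream that identifies reward functions by name (for example per-reward logging).

```python
from functools import partial, update_wrapper


def code_reward(completions, num_parallel: int = 2, **kwargs):
    """Toy stand-in for the real reward function."""
    return [float(num_parallel)] * len(completions)


bare = partial(code_reward, num_parallel=4)
wrapped = update_wrapper(partial(code_reward, num_parallel=4), code_reward)

print(hasattr(bare, "__name__"))  # False: a plain partial drops the function metadata
print(wrapped.__name__)           # "code_reward": copied over by update_wrapper
print(wrapped([["a"], ["b"]]))    # [4.0, 4.0]: num_parallel is already bound
```

The `code` and `binary_code` entries therefore behave exactly like the original functions, except that the per-process parallelism from `parallel_code_exec_per_proc` is already baked in.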
