1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
"""Reward functions for GRPO training."""
2
17
3
18
import asyncio
@@ -368,13 +383,13 @@ def extract_code(completion: str, language: str = "python") -> str:
368
383
return extracted_answer
369
384
370
385
371
def binary_code_reward(completions, num_parallel: int = 2, **kwargs) -> list[float]:
    """Binarized variant of `code_reward`.

    Scores the completions with `code_reward` and maps every score strictly
    above the threshold to 1.0 and every other score to 0.0.

    Args:
        completions: Completions to score; forwarded unchanged to `code_reward`.
        num_parallel: Forwarded to `code_reward` as its `num_parallel` argument.
        **kwargs: Extra keyword arguments forwarded to `code_reward`.

    Returns:
        One value per reward returned by `code_reward`, each either 1.0 or 0.0.
    """
    BINARY_THRESHOLD = 0.99
    raw_scores = code_reward(completions, num_parallel=num_parallel, **kwargs)
    # float(True) == 1.0 and float(False) == 0.0
    return [float(score > BINARY_THRESHOLD) for score in raw_scores]
375
390
376
391
377
- def code_reward (completions , ** kwargs ) -> list [float ]:
392
+ def code_reward (completions , num_parallel : int = 2 , ** kwargs ) -> list [float ]:
378
393
"""Reward function that evaluates code snippets using the E2B code interpreter.
379
394
380
395
Assumes the dataset contains a `verification_info` column with test cases.
@@ -438,7 +453,7 @@ def evaluate_code(code, test_cases):
438
453
if not all (v ["language" ] == language for v in verification_info ):
439
454
raise ValueError ("All verification_info must have the same language" , verification_info )
440
455
try :
441
- rewards = run_async_from_sync (scripts , language )
456
+ rewards = run_async_from_sync (scripts , language , num_parallel )
442
457
443
458
except Exception as e :
444
459
print (f"Error from E2B executor: { e } " )
@@ -463,45 +478,62 @@ def code_format_reward(completions, **kwargs):
463
478
return code_format_reward
464
479
465
480
466
def run_async_from_sync(scripts: list[str], language: str, num_parallel: int) -> list[float]:
    """Synchronous entry point around the `run_async` coroutine.

    Args:
        scripts: Scripts handed to `run_async`.
        language: Language identifier handed to `run_async`.
        num_parallel: Concurrency limit handed to `run_async`.

    Returns:
        The list of rewards produced by `run_async`.

    Raises:
        Exception: Any exception raised while running the coroutine is printed
            and then re-raised to the caller.
    """
    try:
        # asyncio.run creates a fresh event loop, runs the coroutine to completion and closes the loop
        return asyncio.run(run_async(scripts, language, num_parallel))
    except Exception as e:
        print(f"Error from E2B executor async: {e}")
        raise e
477
492
478
493
479
async def run_async(scripts: list[str], language: str, num_parallel: int) -> list[float]:
    """Run all scripts concurrently, with at most `num_parallel` running at once.

    Args:
        scripts: Scripts to execute; each one is passed to `run_script`.
        language: Language identifier passed to `run_script` for every script.
        num_parallel: Maximum number of scripts allowed to run at the same time.
            Must be at least 1.

    Returns:
        One reward per script, in the same order as `scripts`.

    Raises:
        ValueError: If `num_parallel` is smaller than 1.
    """
    # A semaphore initialised with 0 never lets any task through, so every
    # script would wait forever; reject that configuration up front.
    if num_parallel < 1:
        raise ValueError(f"num_parallel must be >= 1, got {num_parallel}")

    # Limit the number of concurrent tasks
    semaphore = asyncio.Semaphore(num_parallel)

    # Create a list of tasks for running scripts concurrently
    tasks = [run_script(script, language, semaphore) for script in scripts]

    # gather returns the results in the order of `tasks`, regardless of completion order
    results = await asyncio.gather(*tasks)
    return list(results)
494
506
495
507
496
async def run_script(script: str, language: str, semaphore: asyncio.Semaphore) -> float:
    """Execute one script in its own E2B sandbox and return its output as a float.

    The semaphore is held for the whole lifetime of the sandbox, so it bounds
    the number of sandboxes that exist at the same time.

    Args:
        script: Source code to run inside the sandbox.
        language: Language identifier passed to `run_code`.
        semaphore: Semaphore shared between concurrent calls.

    Returns:
        `float(execution.text)` on success. 0.0 if the output cannot be parsed
        as a float, if the execution times out, or if any other error occurs
        (including a failure to create the sandbox).
    """
    # We set a timeout margin, as the AsyncSandbox timeout does not seem to work
    # These values are based on running 256 examples with the gold solution
    # from open-r1/verifiable-coding-problems-python_decontaminated
    # see scripts/benchmark_e2b.py
    SANDBOX_TIMEOUT = 30
    MARGIN = 2
    REQUEST_TIMEOUT = SANDBOX_TIMEOUT - MARGIN
    ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN

    async with semaphore:
        # Stays None when sandbox creation fails, so the error handler and the
        # cleanup below never touch an unbound variable.
        sandbox = None
        try:
            sandbox = await AsyncSandbox.create(timeout=SANDBOX_TIMEOUT, request_timeout=REQUEST_TIMEOUT)
            execution = await asyncio.wait_for(sandbox.run_code(script, language=language), timeout=ASYNCIO_TIMEOUT)
            return float(execution.text)
        except (TypeError, ValueError):
            # Output was missing or not a number
            return 0.0
        except asyncio.TimeoutError:
            print("Operation timed out")
            return 0.0
        except Exception as e:
            sandbox_id = getattr(sandbox, "sandbox_id", None)
            print(f"Error in `run_script` from E2B sandbox ID {sandbox_id}: {e}")
            return 0.0
        finally:
            # Only a sandbox that was actually created needs to be killed
            if sandbox is not None:
                try:
                    await sandbox.kill()
                except Exception as e:
                    print(f"Error from E2B executor kill with sandbox ID {sandbox.sandbox_id}: {e}")
505
537
506
538
507
539
def get_reward_funcs (script_args ) -> list [Callable ]:
@@ -521,8 +553,12 @@ def get_reward_funcs(script_args) -> list[Callable]:
521
553
max_penalty = script_args .repetition_max_penalty ,
522
554
),
523
555
"length" : len_reward ,
524
- "code" : code_reward ,
525
- "binary_code" : binary_code_reward ,
556
+ "code" : update_wrapper (
557
+ partial (code_reward , num_parallel = script_args .parallel_code_exec_per_proc ), code_reward
558
+ ),
559
+ "binary_code" : update_wrapper (
560
+ partial (binary_code_reward , num_parallel = script_args .parallel_code_exec_per_proc ), binary_code_reward
561
+ ),
526
562
"ioi_code" : update_wrapper (
527
563
partial (ioi_code_reward , test_batch_size = script_args .code_eval_test_batch_size ), ioi_code_reward
528
564
),
0 commit comments