
Commit c6e7fd0

feat: add sandbox fusion tools for multi-turn
1 parent a3c4cb3 commit c6e7fd0

12 files changed: +1176 -81 lines

.github/workflows/sgl.yml

Lines changed: 4 additions & 0 deletions
@@ -82,3 +82,7 @@ jobs:
         run: |
           cd tests/workers/rollout
           torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_async_rollout_w_tools.py
+      - name: Test the latest SGLang Rollout async with sandbox fusion tool
+        run: |
+          cd tests/workers/rollout
+          pytest -s test_sglang_async_rollout_sf_tools.py
Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
===============================
Sandbox Fusion Tool Integration
===============================

Motivations
===========

- As users of veRL, we want to allow the model to call certain tools during Actor rollout and to incorporate the results into the training process.
- A colleague from ByteDance has proposed a paper aimed at enhancing model capability through code-execution tools.
- We aim to support the tool-calling capabilities of inference engines, using `sandbox-fusion` as the code-execution system, and to provide the community with a reimplementation of `retools`.

Reward Compute with Sandbox Fusion + FaaS Integration
=====================================================

- For current datasets and tasks, similar work already exists (e.g., Prime), which uses local processes as runners to execute model-generated code for reward computation.
- Building on this, #1429 advanced the design by integrating FaaS as the runner for reward computation.

Goals
=====

- Adapt to the `sglang` tool-calling protocol and define tools for sandbox fusion.
- Integrate with the `async-rollout` process, ensuring sandbox fusion tools follow asyncIO conventions.
- Design and implement a basic rate limiter to prevent issues such as 429 errors.

Non-Goals
=========

- Training effectiveness is out of scope.
- Observability metrics are not considered.
- Distributed failover and component fault tolerance are not addressed.

Design Details
==============

Tool Schema Definition
----------------------

- Currently, only code execution is considered, so the schema requires a `code` field in the JSON emitted by the model.
- Only Python code is supported for now, so no `language` parameter is defined.

.. code-block:: python

    OpenAIFunctionToolSchema(
        type="function",
        function=OpenAIFunctionSchema(
            name="code_interpreter",
            description="A tool for executing code.",
            parameters=OpenAIFunctionParametersSchema(
                type="object",
                properties={
                    "code": OpenAIFunctionPropertySchema(
                        type="string",
                        description="The code to execute.",
                        enum=None,
                    )
                },
                required=["code"],
            ),
            strict=False,
        )
    )
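
For reference, a call that satisfies this schema looks like the following (an illustrative payload, not from the commit; the exact `<tool_call>` wire format appears in the e2e sample at the end of this document):

.. code-block:: python

    # Illustrative only: the object the model must emit for this schema,
    # shown as a Python dict. "code" is the sole required argument.
    tool_call = {
        "name": "code_interpreter",
        "arguments": {"code": "print(1 + 1)"},
    }
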
Configuration Parameters
------------------------

+----------------------------+--------------------------------------------------------------+
| Parameter Name             | Description                                                  |
+============================+==============================================================+
| `num_workers`              | Number of worker threads/processes per DP to request runner. |
+----------------------------+--------------------------------------------------------------+
| `rate_limit`               | Global limit on concurrent code executions. Default: 10      |
+----------------------------+--------------------------------------------------------------+
| `default_timeout`          | Timeout (in seconds) for each code execution. Default: 30    |
+----------------------------+--------------------------------------------------------------+
| `default_language`         | Default programming language. Default: "python"              |
+----------------------------+--------------------------------------------------------------+
| `enable_global_rate_limit` | Whether to enable global rate limiting. Default: True        |
+----------------------------+--------------------------------------------------------------+
| `sandbox_fusion_url`       | URL of the veFaaS sandbox code-execution service.            |
+----------------------------+--------------------------------------------------------------+
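
For illustration, these parameters map onto the `config` dict passed to the tool at construction time (a sketch; the values shown are the documented defaults, the `num_workers` value and the URL are placeholders, and the real values come from the tool config YAML shown later in this commit):

.. code-block:: python

    # A sketch only: configuration keys from the table above with their
    # documented defaults. Real values come from the tool config YAML.
    sandbox_fusion_config = {
        "num_workers": 4,                  # illustrative value; workers per DP rank
        "rate_limit": 10,                  # global cap on concurrent executions
        "default_timeout": 30,             # seconds per code execution
        "default_language": "python",      # only Python is supported for now
        "enable_global_rate_limit": True,  # toggle the Ray-based rate limiter
        "sandbox_fusion_url": "https://<your-gateway>/run_code",  # placeholder
    }
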
Rate Limiting Design
--------------------

Objective:

- Limit the number of inflight requests using a token-bucket model.
- Ensure ordered submission to code runners, so that backoff does not cause starvation.

Design Highlights:

- Use a Ray global actor as a singleton distributed counter at the cluster level.
- A semaphore does the counting, with `acquire` and `release` in separate thread pools (Ray concurrency groups) to preserve order.
- Use Ray's cloudpickle to serialize functions, keeping the `ExecutionWorker` decoupled from the code it runs.

.. code-block:: python

    import logging
    import threading
    from contextlib import ExitStack
    from enum import Enum
    from typing import Callable, TypeVar

    import ray

    # Helpers implied by the excerpt:
    T = TypeVar("T")
    logger = logging.getLogger(__name__)

    class PoolMode(Enum):
        ThreadMode = 1
        ProcessMode = 2

    @ray.remote(concurrency_groups={"acquire": 1, "release": 10})
    class TokenBucketWorker:
        def __init__(self, rate_limit: int):
            self.rate_limit = rate_limit
            self.current_count = 0
            self._semaphore = threading.Semaphore(rate_limit)

        @ray.method(concurrency_group="acquire")
        def acquire(self):
            # The "acquire" group has concurrency 1, so callers block here in order.
            self._semaphore.acquire()
            self.current_count += 1

        @ray.method(concurrency_group="release")
        def release(self):
            self._semaphore.release()
            self.current_count -= 1

        def get_current_count(self):
            return self.current_count

    class ExecutionWorker:
        def __init__(self, enable_global_rate_limit=True, rate_limit=10):
            self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None

        def _init_rate_limit(self, rate_limit):
            # A named actor with get_if_exists=True acts as a cluster-level singleton.
            return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit)

        def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:
            if self.rate_limit_worker is None:
                # Rate limiting disabled: run directly.
                return fn(*fn_args, **fn_kwargs)
            with ExitStack() as stack:
                # Release the token even if fn raises.
                stack.callback(self.rate_limit_worker.release.remote)
                ray.get(self.rate_limit_worker.acquire.remote())
                try:
                    return fn(*fn_args, **fn_kwargs)
                except Exception as e:
                    logger.warning(f"Error when executing code: {e}")

    def init_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode):
        if mode == PoolMode.ThreadMode:
            return ray.remote(ExecutionWorker).options(max_concurrency=num_workers).remote(
                enable_global_rate_limit=enable_global_rate_limit,
                rate_limit=rate_limit,
            )
        else:
            raise NotImplementedError("Process mode is not implemented yet")
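
To make the call flow concrete, here is a minimal usage sketch (illustrative only, not part of the commit; `run_user_code` is a hypothetical stand-in for a real sandbox call):

.. code-block:: python

    # Illustrative usage of the pool above; run_user_code is a hypothetical stand-in.
    def run_user_code(code: str) -> str:
        return f"executed: {code}"

    pool = init_execution_pool(num_workers=4, enable_global_rate_limit=True, rate_limit=10)

    # Each .remote() call acquires a token from the shared bucket, runs the
    # function in the actor's thread pool, then releases the token.
    ref = pool.execute.remote(run_user_code, "print(1)")
    print(ray.get(ref))
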
Tool Implementation
-------------------

- Use `instance_id` to identify a request across multiple dialogue rounds.
- Use `execution_pool` to implement asynchronous invocation.
- Clean up state after rollout completion.

.. code-block:: python

    class SandboxFusionTool(BaseTool):
        def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
            ...
            self.execution_pool = init_execution_pool(...)
            ...

        async def create(self, instance_id: Optional[str] = None, ...):
            ...

        async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
            code = parameters.get("code", "")
            timeout = parameters.get("timeout", self.default_timeout)
            language = parameters.get("language", self.default_language)
            if not isinstance(code, str):
                code = str(code)

            # The Ray ObjectRef is awaitable, so the synchronous execute_code runs
            # in the pool actor without blocking the async rollout loop.
            result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)
            self._instance_dict[instance_id]["reward"].append(result.strip())

            return result, result, {}

        def execute_code(self, instance_id, code, timeout=30, language="python"):
            result_status, metadata = _process_single_case(0, None, None, self.sandbox_fusion_url, code, timeout, language)
            # We should always expect "Finished" here, since there is no reference answer to check against.
            if metadata["run_status"] == "Finished":
                actual_output = metadata["stdout"] if metadata["stdout"] is not None else ""
                return actual_output
            else:
                return "no stdout here"

        async def calc_reward(self, instance_id: str, ...):
            ...

        async def release(self, instance_id: str, ...):
            ...
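
For orientation, the async rollout drives a tool instance through its lifecycle roughly as follows (a hedged sketch; the method names come from the class above, while the driver loop and return conventions are assumptions):

.. code-block:: python

    # Hypothetical driver loop, for illustration only.
    async def run_one_tool_call(tool: SandboxFusionTool):
        instance_id = await tool.create()              # one instance per dialogue
        response, reward, metrics = await tool.execute(
            instance_id,
            {"code": "print(6 * 7)"},                  # parsed from the model's <tool_call>
        )
        score = await tool.calc_reward(instance_id)    # aggregate per-instance reward
        await tool.release(instance_id)                # drop per-instance state
        return response, score
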
Test Plan
=========

Unit Tests
----------

- **test_tools_registration**: Test tool registration and initialization.
- **test_rollout_req_creation**: Validate that `AsyncRolloutReq` is built correctly.
- **test_over_size_case**: Ensure rollout terminates early when exceeding `max_seq_len`.
- **test_tool_call_basic_case**: Mock `sglang` output, then validate the tool call and its result.
- **test_tool_call_batch_case**: Test batch processing of tool calls.
- **test_basic_multi_process_init**: Validate that the Ray global actor behaves as a singleton (see the sketch after this list).
- **TestSingleNodeRateLimiterCase**: Verify the rate limiter works in single-node mode.
- **test_rotten_execution**: Ensure the rate limiter recovers from function errors.
- **TestMultiNodeRateLimiterCase**: Verify behavior in multi-node environments.
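
As an illustration of the singleton check, a test along these lines would pass under the design above (a sketch, not the actual test code; it assumes the `ExecutionWorker` and `TokenBucketWorker` from the rate-limiting section are importable):

.. code-block:: python

    import ray

    def test_rate_limiter_is_singleton():
        # Sketch only: two independently created workers should share one
        # named "rate-limiter" actor, because get_if_exists=True makes
        # actor creation idempotent.
        ray.init(ignore_reinit_error=True)
        w1 = ExecutionWorker(enable_global_rate_limit=True, rate_limit=2)
        w2 = ExecutionWorker(enable_global_rate_limit=True, rate_limit=2)

        # Acquire through the first worker's handle...
        ray.get(w1.rate_limit_worker.acquire.remote())
        # ...and observe the shared count through the second worker's handle.
        assert ray.get(w2.rate_limit_worker.get_current_count.remote()) == 1

        ray.get(w1.rate_limit_worker.release.remote())
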
e2e Tests
---------

We provide an e2e test script in the `tests/e2e` folder: `tests/e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh`.

Here is a sample rollout with tool calling:

.. code-block:: json

    {
        "input": "
        system\nYou are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### <answer>`.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"code_interpreter\", \"description\": \"A tool for executing code.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"code\": {\"type\": \"string\", \"description\": \"The code to execute.\", \"enum\": null}}, \"required\": [\"code\"]}, \"strict\": false}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>\n
        user\nJohn gets a bonus that's the same percentage every year. Last year he made $100,000 and got a $10,000 bonus. This year he makes $200,000. How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`.\n
        assistant\n",
        "output": "assistant\n<think>\nOkay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together.\n\nFirst, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year.\n\nLast year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that.\n\n10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year.\n\nThis year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy.\n
        <tool_call>\n{\"name\": \"code_interpreter\", \"arguments\": {\"code\": \"total_pay_this_year = 200000\\nbonus_percentage = 10 / 100\\nbonus_this_year = total_pay_this_year * bonus_percentage\\ntotal_income_this_year = total_pay_this_year + bonus_this_year\\nprint(total_income_this_year)\", \"executes\": \"True\"}}\n</tool_call>\n
        tool\n220000.0\n\n
        assistant\n<think>\nSo the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000.\n\nI need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters.\n</think>\n\n#### 220000.0",
        "score": 0,
        "step": 1
    }
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer
  - _self_

data:
  max_prompt_length: 1024
  max_response_length: 1024
  train_batch_size: 256
  return_raw_chat: True

actor_rollout_ref:
  hybrid_engine: True
  rollout:
    name: sglang_async
    multi_turn:
      enable: True
      max_turns: 5
      tool_config_path: "./config/tool_config/sandbox_fusion_tool_config.yaml"
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
tools:
  - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool"
    config: {
      "sandbox_fusion_url": "https://xx.apigateway-cn-beijing.volceapi.com/run_code"
    }
    tool_schema:
      type: "function"
      function:
        name: "code_interpreter"
        description: "A tool for executing code."
        parameters:
          type: "object"
          properties:
            code:
              type: "string"
              description: "The code to execute."
          required: ["code"]
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# run on 8xH20
# make sure your current working directory is the root of the project

set -x

# huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-3B-Instruct

ulimit -n 65535

PROJECT_DIR="$(pwd)"
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"

python3 -m verl.trainer.main_ppo \
    --config-path="$CONFIG_PATH" \
    --config-name='gsm8k_multiturn_sf_grpo' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=512 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path=swordfaith/ReTool-Qwen3-4B-SFT-cold-started \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=512 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=sglang_async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console'] \
    trainer.project_name='gsm8k_async_rl' \
    trainer.rollout_data_dir='/dev/shm/rollout_data' \
    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-sf-tool' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=-1 \
    data.train_files=/demo-huabei2/wxc/workspace/src/github-fork-verl/verl/tests/workers/rollout/resource/gsm8k_verl_sgl_multi_turn_preprocessed/train.parquet \
    data.val_files=/demo-huabei2/wxc/workspace/src/github-fork-verl/verl/tests/workers/rollout/resource/gsm8k_verl_sgl_multi_turn_preprocessed/test.parquet \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml" \
    trainer.val_before_train=False \
    trainer.total_training_steps=5 "$@"
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
tools:
  - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool"
    config: {
      "sandbox_fusion_url": "https://xxx.apigateway-cn-beijing.volceapi.com/run_code"
    }
    tool_schema:
      type: "function"
      function:
        name: "code_interpreter"
        description: "A tool for executing code."
        parameters:
          type: "object"
          properties:
            code:
              type: "string"
              description: "The code to execute."
          required: ["code"]
