Skip to content

Commit eafa233

Browse files
Lawh and yitianlian authored
fix(examples): update strands_sglang example to strands-sglang v0.2.x API (#1593)
Co-authored-by: Chengxing Xie <91449279+yitianlian@users.noreply.github.com>
1 parent 14413cf commit eafa233

File tree

4 files changed

+20
-21
lines changed

4 files changed

+20
-21
lines changed

examples/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ These examples provide concrete examples to leverage slime in your own RL workfl
1414
- **[reproducibility](./reproducibility)**: Guides on achieving bitwise experiment reproduction using deterministic modes.
1515
- **[retool](./retool)**: Demonstrates the retool functionality for tool-enabled language model generation.
1616
- **[search-r1](./search-r1)**: A minimal reproduction of Search-R1, featuring multi-turn conversation and tool-calling.
17-
- **[strands-agents](./strands-agents)**: Integration example with the Strands-Agents scaffolding framework.
17+
- **[strands_sglang](./strands_sglang)**: Integration example with the Strands-Agents scaffolding framework.
1818
- **[tau-bench](./tau-bench)**: Training in an agentic multi-turn tool use environment (Tau-bench).
1919
- **[train_infer_mismatch_helper](./train_infer_mismatch_helper)**: Algorithmic methods for rollout correction (e.g., TIS, MIS).
2020
- **[true_on_policy](./true_on_policy)**: Ensures strictly equal log probabilities between inference (SGLang) and training engines.

examples/strands_sglang/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ This example connects `slime` with [`strands-sglang`](https://github.com/horizon
1414

1515
- Captures exact token IDs during generation (no retokenization drift)
1616
- Automatically tracks `loss_mask` via `token_manager`
17-
- Provides `ToolIterationLimiter` for clean trajectory truncation
17+
- Provides `ToolLimiter` for clean trajectory truncation
1818

1919
## Install Dependencies
2020

examples/strands_sglang/generate_with_strands.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from camel.interpreters import SubprocessInterpreter
44
from strands import Agent, tool
5-
from strands_sglang import SGLangClient, SGLangModel
6-
from strands_sglang.tool_limiter import ToolIterationLimiter
5+
from strands_sglang import SGLangClient, SGLangModel, ToolLimiter
6+
from strands_sglang.tool_parsers import HermesToolParser
77

88
from slime.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score
99
from slime.rollout.sglang_rollout import GenerateState
@@ -20,16 +20,17 @@
2020
- After completing your reasoning, present the final result enclosed in \\boxed{}.
2121
""".strip()
2222

23-
MAX_TOOL_ITERATIONS = 5
23+
MAX_TOOL_ITERS = 5
24+
MAX_TOOL_CALLS = None # No limit
2425

2526
_client_cache: dict[str, SGLangClient] = {}
2627

2728

2829
def get_client(args) -> SGLangClient:
29-
"""Get shared client for connection pooling (like SLIME)."""
30+
"""Get shared client for connection pooling."""
3031
base_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}"
3132
if base_url not in _client_cache:
32-
_client_cache[base_url] = SGLangClient.from_slime_args(args)
33+
_client_cache[base_url] = SGLangClient.from_slime_args(args, timeout=300.0)
3334
return _client_cache[base_url]
3435

3536

@@ -55,15 +56,15 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
5556
model = SGLangModel(
5657
tokenizer=state.tokenizer,
5758
client=get_client(args),
58-
model_id=args.hf_checkpoint.split("/")[-1],
59-
params={k: sampling_params[k] for k in ["max_new_tokens", "temperature", "top_p"]},
59+
tool_parser=HermesToolParser(), # tool parsing for wrapped JSON tool calls
60+
sampling_params=sampling_params,
6061
)
6162

62-
limiter = ToolIterationLimiter(max_iterations=MAX_TOOL_ITERATIONS)
63+
tool_limiter = ToolLimiter(max_tool_iters=MAX_TOOL_ITERS, max_tool_calls=MAX_TOOL_CALLS)
6364
agent = Agent(
6465
model=model,
6566
tools=[execute_python_code],
66-
hooks=[limiter],
67+
hooks=[tool_limiter],
6768
callback_handler=None,
6869
system_prompt=SYSTEM_PROMPT,
6970
)
@@ -74,12 +75,12 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
7475
await agent.invoke_async(prompt)
7576
sample.status = Sample.Status.COMPLETED
7677
except Exception as e:
77-
# Always use TRUNCATED instead of ABORTED because Slime doesn't properly
78+
# Always use TRUNCATED instead of ABORTED because slime doesn't properly
7879
# handle ABORTED samples in reward processing. See: https://github.com/THUDM/slime/issues/200
7980
sample.status = Sample.Status.TRUNCATED
8081
logger.warning(f"TRUNCATED: {type(e).__name__}: {e}")
8182

82-
# TITO: extract trajectory from token_manager
83+
# Extract token trajectory from token_manager
8384
tm = model.token_manager
8485
prompt_len = len(tm.segments[0]) # system + user are first segment
8586
sample.tokens = tm.token_ids
@@ -88,9 +89,8 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
8889
sample.response_length = len(sample.tokens) - prompt_len
8990
sample.response = model.tokenizer.decode(sample.tokens[prompt_len:], skip_special_tokens=False)
9091
# Tool iteration and tool call count are different because multiple parallel tool calls count as 1 iteration
91-
sample.tool_iterations = limiter.iteration_count
92-
trajectory = model.format_request_messages(agent.messages, None)
93-
sample.tool_call_count = [message["role"] == "tool" for message in trajectory].count(True)
92+
sample.tool_iters = tool_limiter.tool_iter_count
93+
sample.tool_calls = tool_limiter.tool_call_count
9494

9595
model.reset()
9696
agent.cleanup()
@@ -100,18 +100,19 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
100100
async def reward_func(args, sample: Sample, **kwargs):
101101
"""Reward function using math_dapo scoring."""
102102
ground_truth = sample.label or ""
103-
tool_iterations = getattr(sample, "tool_iterations", 0)
103+
tool_iters = getattr(sample, "tool_iters", 0)
104+
tool_calls = getattr(sample, "tool_calls", 0)
104105

105106
result = math_dapo_compute_score(sample.response, ground_truth, strict_box_verify=False)
106107
if result["pred"] == "[INVALID]":
107108
result = math_dapo_compute_score(sample.response, ground_truth, strict_box_verify=True)
108109

109110
# Encourage tool use on failures
110111
if result["score"] < 0:
111-
result["score"] = min(-0.6, result["score"] + (tool_iterations - 2) / 2 * 0.1)
112+
result["score"] = min(-0.6, result["score"] + (tool_iters - 2) / 2 * 0.1)
112113

113114
result["pred"] = result["pred"] or ""
114115
logger.info(
115-
f"reward={result['score']:.2f} | status={sample.status.name} | tool_iters={tool_iterations} | tool_calls={getattr(sample, 'tool_call_count', 0)} | tokens={len(sample.tokens)} | resp_len={sample.response_length} | "
116+
f"reward={result['score']:.2f} | status={sample.status.name} | tool_iters={tool_iters} | tool_calls={tool_calls} | tokens={len(sample.tokens)} | resp_len={sample.response_length} | "
116117
)
117118
return result["score"]
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
11
camel-ai
2-
strands-agents
3-
strands-agents-tools
42
strands-sglang

0 commit comments

Comments (0)