Skip to content

Commit f600569

Browse files
authored
Add strands-agents example (#976)
1 parent 763f18d commit f600569

File tree

4 files changed

+484
-0
lines changed

4 files changed

+484
-0
lines changed

examples/strands-agents/README.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Slime x Strands-Agents
2+
3+
This is a running example that connects the [Strands-Agents](https://github.com/strands-agents/sdk-python) agent scaffolding framework with Slime for RL training.
4+
5+
## Install Dependencies
6+
7+
1. Pull the `slimerl/slime:latest` image and enter it
8+
2. Go to the slime folder: `cd /root/slime` (Clone the repository if not already there: `cd /root && git clone https://github.com/THUDM/slime.git`)
9+
3. Install Slime: `pip install -e .`
10+
4. Go to the example folder: `cd /root/slime/examples/strands-agents`
11+
5. Install other dependencies: `pip install -r requirements.txt`
12+
13+
> NOTE: we use camel-ai's subprocess code interpreter for Python code execution, which is NOT a good practice in general; it is used here only for convenience, since the dependencies needed for solving math problems are usually already available in `slime`'s Docker image
14+
15+
## Prepare Model
16+
17+
```bash
18+
# hf checkpoint
19+
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507 --local-dir /root/models/Qwen/Qwen3-4B-Instruct-2507
20+
21+
# mcore checkpoint
22+
cd /root/slime
23+
source scripts/models/qwen3-4B.sh
24+
PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
25+
${MODEL_ARGS[@]} \
26+
--hf-checkpoint /root/models/Qwen/Qwen3-4B-Instruct-2507 \
27+
--save /root/models/Qwen/Qwen3-4B-Instruct-2507_torch_dist
28+
```
29+
30+
## Prepare Dataset
31+
32+
Following [Retool](https://arxiv.org/abs/2504.11536), we use `dapo-math-17k` as the training data:
33+
34+
```
35+
from datasets import load_dataset
36+
ds = load_dataset("zhuzilin/dapo-math-17k", split="train")
37+
ds.to_json("/root/data/dapo-math-17k.jsonl", orient="records", lines=True)
38+
```
39+
40+
and `aime-2024` as eval data:
41+
42+
```
43+
from datasets import load_dataset
44+
ds = load_dataset("zhuzilin/aime-2024", split="train")
45+
ds.to_json("/root/data/aime-2024.jsonl", orient="records", lines=True)
46+
```
47+
48+
## Run Training
49+
50+
Assuming `/root/slime` is up to date (if this PR is not yet merged, you may need to switch to the corresponding branch):
51+
52+
```
53+
cd /root/slime
54+
export WANDB_KEY=$your_wandb_key
55+
bash examples/strands-agents/strands_qwen3_4b.sh
56+
```
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import logging
2+
3+
import openai
4+
import wandb
5+
from camel.interpreters import SubprocessInterpreter
6+
from strands import Agent, tool
7+
from strands.models.openai import OpenAIModel
8+
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, MaxTokensReachedException
9+
10+
from slime.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score
11+
from slime.rollout.sglang_rollout import GenerateState
12+
from slime.utils.types import Sample
13+
14+
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


# System prompt steering the model toward tool-assisted, step-by-step math
# solving, with the final answer in \boxed{} so the boxed-answer reward
# verification in `reward_func` can parse it.
SYSTEM_PROMPT = """
You are a helpful math-solving assistant with access to the `execute_python_code` tool.

Guidelines:
- For any numerical or symbolic computation, always use the `execute_python_code` tool rather than performing calculations mentally.
- Break problems into clear steps, calling the Python tool whenever computation is required.
- After completing your reasoning, present the final result enclosed in \\boxed{}.
""".strip()

# Hard cap on conversation length; longer trajectories are cut to this many
# messages in `generate` and the sample is marked TRUNCATED.
MAX_NUM_MESSAGES = 16  # messages beyond this will be truncated
29+
30+
31+
def create_strands_agent(args, sampling_params):
    """Build a strands Agent wired to the SGLang rollout server.

    Args:
        args: Rollout arguments providing `sglang_router_ip`,
            `sglang_router_port`, and `hf_checkpoint`.
        sampling_params: Dict with `max_new_tokens`, `temperature`, `top_p`.

    Returns:
        Agent: A strands agent exposing the `execute_python_code` tool.
    """
    # Sampling configuration forwarded to the OpenAI-compatible endpoint.
    generation_params = {
        "max_tokens": sampling_params["max_new_tokens"],
        "temperature": sampling_params["temperature"],
        "top_p": sampling_params["top_p"],
    }
    server_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/v1"
    logger.info(
        f"[Strands Agents] Creating OpenAIModel from SGLang server at {server_url} with parameters: {generation_params}"
    )
    llm = OpenAIModel(
        client_args={
            "api_key": "EMPTY",
            "base_url": server_url,
            "timeout": 300.0,  # needed for tool calls
        },
        # The served model id is the last path component of the checkpoint.
        model_id=args.hf_checkpoint.split("/")[-1],
        params=generation_params,
    )

    # `execute_python_code` tool backed by camel-ai's subprocess interpreter.
    # NOTE: the name and docstring below are part of the tool spec the model
    # sees, so they must not be changed casually.
    @tool
    def execute_python_code(code: str) -> str:
        r"""Execute a given Python code snippet.

        Args:
            code (str): The input Python code to the Code Execution tool call.

        Returns:
            str: The text output from the Code Execution tool call.
        """
        # A fresh interpreter per call keeps executions independent.
        runner = SubprocessInterpreter(
            require_confirm=False,
            print_stdout=False,
            print_stderr=False,
            execution_timeout=60.0,
        )
        output = runner.run(code=code, code_type="python")
        logger.info(
            f"[Strands Agents] executing Python code: ```python\n{code}\n``` and get execution result: ```python\n{output}\n```"
        )
        return output

    # Assemble the agent; callback_handler=None disables console streaming.
    return Agent(
        model=llm,
        tools=[execute_python_code],
        system_prompt=SYSTEM_PROMPT,
        callback_handler=None,
    )
87+
88+
89+
async def run_strands_agent(agent: Agent, prompt: str) -> Sample.Status:
    """Drive the agent on `prompt` and map the outcome to a Sample.Status.

    Returns COMPLETED on success, TRUNCATED when the failure indicates the
    token/context budget was exhausted, and ABORTED for any other error.
    """

    def _is_truncation(exc: Exception) -> bool:
        # Length-budget failures count as truncation rather than hard errors.
        if isinstance(exc, (MaxTokensReachedException, ContextWindowOverflowException)):
            return True
        return (
            isinstance(exc, EventLoopException)
            and isinstance(exc.original_exception, openai.APIError)
            and "context length" in str(exc.original_exception).lower()
        )

    try:
        logger.info(f"[Strands Agents] running agent with prompt: {prompt}")
        await agent.invoke_async(prompt=prompt)
        return Sample.Status.COMPLETED
    except Exception as e:
        if _is_truncation(e):
            logger.warning(f"[Strands Agents] sample is TRUNCATED due to {type(e).__name__}: {e}")
            return Sample.Status.TRUNCATED
        logger.error(f"[Strands Agents] sample is ABORTED due to {type(e).__name__}: {e}")
        return Sample.Status.ABORTED
111+
112+
113+
def get_trajectory(agent: Agent) -> list[dict]:
    """Get the chat template-compatible trajectory from strands agent's messages.

    The OpenAI-format messages carry `content` as a list of parts; chat
    templates expect a plain string, so list contents are flattened in place.

    Args:
        agent (Agent): The agent whose `messages` are converted.

    Returns:
        list[dict]: Messages with string `content`, suitable for
            `tokenizer.apply_chat_template`.
    """
    openai_model: OpenAIModel = agent.model
    trajectory = openai_model.format_request_messages(messages=agent.messages, system_prompt=agent.system_prompt)
    for message in trajectory:
        content = message.get("content")
        if isinstance(content, list):
            # Join every text part rather than keeping only the first one:
            # a message may arrive as multiple content parts, and dropping
            # the tail would silently lose text. Parts without a "text" key
            # are skipped (yields "" when no text part exists).
            message["content"] = "".join(part["text"] for part in content if "text" in part)
    return trajectory
124+
125+
126+
async def generate(args, sample: Sample, sampling_params) -> Sample:
    """Generate function using strands-agents as agent scaffolding.

    Runs one strands agent rollout for `sample.prompt`, then re-tokenizes the
    resulting multi-turn trajectory so the sample carries token ids and a
    per-token loss mask (1 on assistant tokens, 0 on everything else) for RL
    training. The sample is mutated in place and also returned.

    Args:
        args: Rollout arguments (SGLang address, hf_checkpoint, flags).
        sample (Sample): Sample holding the prompt; status, tokens, response,
            response_length, loss_mask, and tool_call_count are set here.
        sampling_params: Sampling settings forwarded to the model.

    Returns:
        Sample: The same `sample` object, fully populated (or ABORTED early).
    """
    assert not args.partial_rollout, "Partial rollout is not supported for this function at the moment."

    state = GenerateState(args)

    # Create strands agent
    agent = create_strands_agent(args, sampling_params)

    # Run the strands agent; prompt may be a plain string or an OpenAI-style
    # message list, in which case the first message's content is used.
    prompt_text = sample.prompt if isinstance(sample.prompt, str) else sample.prompt[0]["content"]
    sample.status = await run_strands_agent(agent, prompt_text)

    # Early return if sample is aborted — nothing usable to tokenize.
    if sample.status == Sample.Status.ABORTED:
        agent.cleanup()
        return sample

    # Get the trajectory from the agent and further truncate if necessary
    trajectory = get_trajectory(agent)
    if len(trajectory) > MAX_NUM_MESSAGES:
        logger.warning(
            f"[Strands Agents] sample is TRUNCATED due to number of messages (={len(trajectory)}) exceeding limit (={MAX_NUM_MESSAGES})"
        )
        # This post-processing is not optimal but just for simplicity
        # We should implement a hook in strands-agents to handle this truncation
        trajectory = trajectory[:MAX_NUM_MESSAGES]
        sample.status = Sample.Status.TRUNCATED

    # Get the initial prompt (system + user message)
    initial_prompt_messages = [msg for msg in trajectory if msg["role"] in ["system", "user"]]
    assert len(initial_prompt_messages) == 2, "Initial prompt messages must be exactly 2 for single-turn conversations"
    prompt_text = state.tokenizer.apply_chat_template(
        initial_prompt_messages,
        tokenize=False,
        add_generation_prompt=True,  # Add generation prompt for the assistant
    )
    prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]

    # Build (re-tokenize) the response incrementally
    response_token_ids = []
    loss_masks = []
    response_text = ""

    # Start with the initial prompt messages for progressive chat template application
    # NOTE(review): the incremental diffing below assumes the chat template is
    # prefix-stable, i.e. re-templating messages[:k] yields a token prefix of
    # messages[:k+1] — confirm for the target tokenizer.
    current_messages = list(initial_prompt_messages)
    prev_token_count = len(prompt_tokens_ids)

    # Iterate through remaining messages (assistant and tool messages)
    for message in trajectory[len(initial_prompt_messages) :]:
        # Add this message to the conversation
        current_messages.append(message)

        # Apply chat template and tokenize up to this point
        current_text = state.tokenizer.apply_chat_template(
            current_messages, tokenize=False, add_generation_prompt=False
        )
        current_token_ids = state.tokenizer(current_text, add_special_tokens=False)["input_ids"]

        # Calculate how many new tokens this message added
        new_token_count = len(current_token_ids)
        message_token_length = new_token_count - prev_token_count

        # Extract the new tokens for this message
        message_tokens = current_token_ids[prev_token_count:]
        assert len(message_tokens) == message_token_length, "Message tokens length mismatch"
        response_token_ids.extend(message_tokens)

        # Align message tokens with loss masks
        if message["role"] == "assistant":
            # We train on assistant messages
            loss_masks.extend([1] * message_token_length)
        else:
            # We don't train on tool messages
            loss_masks.extend([0] * message_token_length)

        prev_token_count = new_token_count

    # Extract the response text (everything after the initial prompt).
    # NOTE(review): this slice assumes the full-trajectory template output
    # starts with `prompt_text` verbatim (including the generation prompt) —
    # verify against the tokenizer's chat template.
    full_conversation_text = state.tokenizer.apply_chat_template(
        trajectory, tokenize=False, add_generation_prompt=False
    )
    response_text = full_conversation_text[len(prompt_text) :]

    # Set sample attributes and some debug information
    sample.tokens = prompt_tokens_ids + response_token_ids
    sample.response_length = len(response_token_ids)
    sample.response = response_text
    sample.loss_mask = loss_masks
    # Store tool call count for reward calculation
    sample.tool_call_count = [message["role"] == "tool" for message in trajectory].count(True)

    # Log to wandb if available
    if wandb.run is not None:
        wandb.log(
            {
                "debug/response_length": sample.response_length,
                "debug/available_tools": len(agent.tool_names),
                "debug/tool_calls": sample.tool_call_count,
                "debug/num_messages": len(trajectory),
                "debug/truncated": sample.status == Sample.Status.TRUNCATED,
            }
        )

    agent.cleanup()
    return sample
232+
233+
234+
async def reward_func(args, sample, **kwargs):
    """Tool call reward function using math_dapo as primary reward model.

    Args:
        args: Rollout arguments (unused here; kept for interface compatibility).
        sample (Sample): Finished rollout sample; reads `response`, `label`,
            `status`, `response_length`, and the optional `tool_call_count`
            attribute set by `generate`.

    Returns:
        dict: The math_dapo result dict (contains at least `score` and `pred`).

    Raises:
        TypeError: If `sample` is not a `Sample` instance.
    """
    if not isinstance(sample, Sample):
        raise TypeError("Sample must be an instance of Sample class.")

    # Extract information from sample; fall back to safe defaults so that
    # partially populated samples (e.g. aborted rollouts with no response or
    # tool_call_count) do not crash scoring.
    solution_str = sample.response if sample.response is not None else ""
    ground_truth = sample.label if sample.label is not None else ""
    tool_call_count = getattr(sample, "tool_call_count", 0)

    # Accept both Answer: ... and \\boxed{...} answer
    result = math_dapo_compute_score(solution_str, ground_truth, strict_box_verify=False)
    result_boxed = math_dapo_compute_score(solution_str, ground_truth, strict_box_verify=True)
    if result["pred"] == "[INVALID]":
        result = result_boxed

    # Encourage model to call tools: on wrong answers, add a small bonus
    # (or penalty, below 2 calls) proportional to tool usage, capped at -0.6
    # so an incorrect answer can never approach a correct one's score.
    if result["score"] < 0:
        tool_call_reward = (tool_call_count - 2) / 2 * 0.1
        result["score"] = min(-0.6, result["score"] + tool_call_reward)

    if result["pred"] is None:
        result["pred"] = ""

    # Use the local `tool_call_count` here: the original accessed
    # `sample.tool_call_count` directly, which raises AttributeError for
    # samples that never went through `generate`.
    logger.info(
        f"[Strands Agents] sample summary: "
        f"status={sample.status} | "
        f"tool_call_count={tool_call_count} | "
        f"response_length={sample.response_length} | "
        f"reward={result} | "
        f"ground_truth={ground_truth}"
    )

    return result
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
camel-ai
2+
strands-agents
3+
strands-agents-tools

0 commit comments

Comments
 (0)