22
33from camel .interpreters import SubprocessInterpreter
44from strands import Agent , tool
5- from strands_sglang import SGLangClient , SGLangModel
6- from strands_sglang .tool_limiter import ToolIterationLimiter
5+ from strands_sglang import SGLangClient , SGLangModel , ToolLimiter
6+ from strands_sglang .tool_parsers import HermesToolParser
77
88from slime .rollout .rm_hub .math_dapo_utils import compute_score as math_dapo_compute_score
99from slime .rollout .sglang_rollout import GenerateState
2020- After completing your reasoning, present the final result enclosed in \\ boxed{}.
2121""" .strip ()
2222
23- MAX_TOOL_ITERATIONS = 5
23+ MAX_TOOL_ITERS = 5
24+ MAX_TOOL_CALLS = None # No limit
2425
2526_client_cache : dict [str , SGLangClient ] = {}
2627
2728
2829def get_client (args ) -> SGLangClient :
29- """Get shared client for connection pooling (like SLIME) ."""
30+ """Get shared client for connection pooling."""
3031 base_url = f"http://{ args .sglang_router_ip } :{ args .sglang_router_port } "
3132 if base_url not in _client_cache :
32- _client_cache [base_url ] = SGLangClient .from_slime_args (args )
33+ _client_cache [base_url ] = SGLangClient .from_slime_args (args , timeout = 300.0 )
3334 return _client_cache [base_url ]
3435
3536
@@ -55,15 +56,15 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
5556 model = SGLangModel (
5657 tokenizer = state .tokenizer ,
5758 client = get_client (args ),
58- model_id = args . hf_checkpoint . split ( "/" )[ - 1 ],
59- params = { k : sampling_params [ k ] for k in [ "max_new_tokens" , "temperature" , "top_p" ]} ,
59+ tool_parser = HermesToolParser (), # tool parsing for wrapped JSON tool calls
60+ sampling_params = sampling_params ,
6061 )
6162
62- limiter = ToolIterationLimiter ( max_iterations = MAX_TOOL_ITERATIONS )
63+ tool_limiter = ToolLimiter ( max_tool_iters = MAX_TOOL_ITERS , max_tool_calls = MAX_TOOL_CALLS )
6364 agent = Agent (
6465 model = model ,
6566 tools = [execute_python_code ],
66- hooks = [limiter ],
67+ hooks = [tool_limiter ],
6768 callback_handler = None ,
6869 system_prompt = SYSTEM_PROMPT ,
6970 )
@@ -74,12 +75,12 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
7475 await agent .invoke_async (prompt )
7576 sample .status = Sample .Status .COMPLETED
7677 except Exception as e :
77- # Always use TRUNCATED instead of ABORTED because Slime doesn't properly
78+ # Always use TRUNCATED instead of ABORTED because slime doesn't properly
7879 # handle ABORTED samples in reward processing. See: https://github.com/THUDM/slime/issues/200
7980 sample .status = Sample .Status .TRUNCATED
8081 logger .warning (f"TRUNCATED: { type (e ).__name__ } : { e } " )
8182
82- # TITO: extract trajectory from token_manager
83+ # Extract token trajectory from token_manager
8384 tm = model .token_manager
8485 prompt_len = len (tm .segments [0 ]) # system + user are first segment
8586 sample .tokens = tm .token_ids
@@ -88,9 +89,8 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
8889 sample .response_length = len (sample .tokens ) - prompt_len
8990 sample .response = model .tokenizer .decode (sample .tokens [prompt_len :], skip_special_tokens = False )
9091 # Tool iteration and tool call count are different because multiple parallel tool calls count as 1 iteration
91- sample .tool_iterations = limiter .iteration_count
92- trajectory = model .format_request_messages (agent .messages , None )
93- sample .tool_call_count = [message ["role" ] == "tool" for message in trajectory ].count (True )
92+ sample .tool_iters = tool_limiter .tool_iter_count
93+ sample .tool_calls = tool_limiter .tool_call_count
9494
9595 model .reset ()
9696 agent .cleanup ()
@@ -100,18 +100,19 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
100100async def reward_func (args , sample : Sample , ** kwargs ):
101101 """Reward function using math_dapo scoring."""
102102 ground_truth = sample .label or ""
103- tool_iterations = getattr (sample , "tool_iterations" , 0 )
103+ tool_iters = getattr (sample , "tool_iters" , 0 )
104+ tool_calls = getattr (sample , "tool_calls" , 0 )
104105
105106 result = math_dapo_compute_score (sample .response , ground_truth , strict_box_verify = False )
106107 if result ["pred" ] == "[INVALID]" :
107108 result = math_dapo_compute_score (sample .response , ground_truth , strict_box_verify = True )
108109
109110 # Encourage tool use on failures
110111 if result ["score" ] < 0 :
111- result ["score" ] = min (- 0.6 , result ["score" ] + (tool_iterations - 2 ) / 2 * 0.1 )
112+ result ["score" ] = min (- 0.6 , result ["score" ] + (tool_iters - 2 ) / 2 * 0.1 )
112113
113114 result ["pred" ] = result ["pred" ] or ""
114115 logger .info (
115- f"reward={ result ['score' ]:.2f} | status={ sample .status .name } | tool_iters={ tool_iterations } | tool_calls={ getattr ( sample , 'tool_call_count' , 0 ) } | tokens={ len (sample .tokens )} | resp_len={ sample .response_length } | "
116+ f"reward={ result ['score' ]:.2f} | status={ sample .status .name } | tool_iters={ tool_iters } | tool_calls={ tool_calls } | tokens={ len (sample .tokens )} | resp_len={ sample .response_length } | "
116117 )
117118 return result ["score" ]
0 commit comments