Skip to content

Commit 88b8fbc

Browse files
Lawhyclaude
andcommitted
feat(cli): add --tool-parser option for custom tool parsers
Add CLI option to specify tool parser by name (e.g., 'hermes', 'qwen_xml') or path to a hook file exporting a custom ToolParser. Also refactor CLI utils: - Extract _load_hook_module() helper to reduce duplication - Move ToolParser import to TYPE_CHECKING to avoid unnecessary deps - Split build_model_factory() into backend-specific helpers - Add tool_parser field to ModelConfig for proper serialization Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent bfd929e commit 88b8fbc

File tree

4 files changed

+200
-67
lines changed

4 files changed

+200
-67
lines changed

src/strands_env/cli/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ def list_cmd():
110110
default=None,
111111
help="AWS role ARN for Bedrock (optional).",
112112
)
113+
@click.option(
114+
"--tool-parser",
115+
type=str,
116+
default=None,
117+
help="Tool parser: name (e.g., 'hermes', 'qwen_xml') or path to hook file.",
118+
)
113119
# Sampling params
114120
@click.option(
115121
"--temperature",
@@ -199,6 +205,7 @@ def eval_cmd(
199205
region: str | None,
200206
profile_name: str | None,
201207
role_arn: str | None,
208+
tool_parser: str | None,
202209
# Sampling
203210
temperature: float,
204211
max_tokens: int,
@@ -260,6 +267,7 @@ def eval_cmd(
260267
base_url=base_url,
261268
model_id=model_id,
262269
tokenizer_path=tokenizer_path,
270+
tool_parser=tool_parser,
263271
region=region,
264272
profile_name=profile_name,
265273
role_arn=role_arn,

src/strands_env/cli/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class ModelConfig:
5151
# SGLang
5252
base_url: str = "http://localhost:30000"
5353
tokenizer_path: str | None = None # Auto-detected if None
54+
tool_parser: str | None = None # Parser name or path to hook file
5455

5556
# Bedrock
5657
model_id: str | None = None
@@ -67,6 +68,7 @@ def to_dict(self) -> dict:
6768
"backend": self.backend,
6869
"base_url": self.base_url,
6970
"tokenizer_path": self.tokenizer_path,
71+
"tool_parser": self.tool_parser,
7072
"model_id": self.model_id,
7173
"region": self.region,
7274
"profile_name": self.profile_name,

src/strands_env/cli/utils.py

Lines changed: 185 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import importlib.util
2020
from pathlib import Path
21+
from types import ModuleType
2122
from typing import TYPE_CHECKING, Callable
2223

2324
import click
@@ -27,6 +28,8 @@
2728
from .config import EnvConfig, ModelConfig
2829

2930
if TYPE_CHECKING:
31+
from strands_sglang.tool_parsers import ToolParser
32+
3033
from strands_env.eval import AsyncEnvFactory, Evaluator
3134

3235
#: Type for the create_env_factory function exported by hook files.
@@ -36,82 +39,50 @@
3639
EvaluatorClass = type["Evaluator"]
3740

3841

39-
def build_model_factory(config: ModelConfig, max_concurrency: int) -> ModelFactory:
40-
"""Build a ModelFactory from ModelConfig.
42+
# ---------------------------------------------------------------------------
43+
# Hook Loading
44+
# ---------------------------------------------------------------------------
4145

42-
Args:
43-
config: Model configuration.
44-
max_concurrency: Max concurrent connections (for SGLang client pooling).
4546

46-
Returns:
47-
ModelFactory callable.
48-
"""
49-
sampling = config.sampling.to_dict()
47+
def _load_hook_module(path: Path, hook_name: str) -> ModuleType:
48+
"""Load a Python module from a file path.
5049
51-
if config.backend == "sglang":
52-
from strands_env.utils.sglang import (
53-
check_server_health,
54-
get_cached_client,
55-
get_cached_tokenizer,
56-
get_model_id,
57-
)
58-
59-
# Check server health before proceeding
60-
try:
61-
check_server_health(config.base_url)
62-
except ConnectionError as e:
63-
raise click.ClickException(str(e))
64-
65-
client = get_cached_client(config.base_url, max_concurrency)
66-
67-
# Resolve and backfill model_id/tokenizer_path for reproducibility
68-
if not config.model_id:
69-
config.model_id = get_model_id(config.base_url)
70-
if not config.tokenizer_path:
71-
config.tokenizer_path = config.model_id
72-
73-
tokenizer = get_cached_tokenizer(config.tokenizer_path)
74-
return sglang_model_factory(
75-
client=client, model_id=config.model_id, tokenizer=tokenizer, sampling_params=sampling
76-
)
50+
Args:
51+
path: Path to the Python file.
52+
hook_name: Name for the module (used in error messages).
7753
78-
elif config.backend == "bedrock":
79-
from strands_env.utils.aws import get_assumed_role_session, get_boto3_session
54+
Returns:
55+
The loaded module.
8056
81-
if not config.model_id:
82-
raise click.ClickException("--model-id is required for Bedrock backend")
83-
if config.role_arn:
84-
boto_session = get_assumed_role_session(config.role_arn, config.region)
85-
else:
86-
boto_session = get_boto3_session(config.region, config.profile_name)
87-
return bedrock_model_factory(model_id=config.model_id, boto_session=boto_session, sampling_params=sampling)
57+
Raises:
58+
click.ClickException: If the file cannot be loaded.
59+
"""
60+
spec = importlib.util.spec_from_file_location(hook_name, path)
61+
if spec is None or spec.loader is None:
62+
raise click.ClickException(f"Could not load {hook_name} file: {path}")
8863

89-
else:
90-
raise click.ClickException(f"Unknown backend: {config.backend}")
64+
module = importlib.util.module_from_spec(spec)
65+
spec.loader.exec_module(module)
66+
return module
9167

9268

93-
def load_env_hook(env_path: Path) -> EnvFactoryCreator:
69+
def load_env_hook(path: Path) -> EnvFactoryCreator:
9470
"""Load environment hook file and return create_env_factory function.
9571
9672
The hook file must export a `create_env_factory(model_factory, env_config)` function.
9773
9874
Args:
99-
env_path: Path to the Python hook file.
75+
path: Path to the Python hook file.
10076
10177
Returns:
10278
The create_env_factory function from the hook file.
10379
10480
Raises:
10581
click.ClickException: If the file cannot be loaded or doesn't export the function.
10682
"""
107-
spec = importlib.util.spec_from_file_location("env_hook", env_path)
108-
if spec is None or spec.loader is None:
109-
raise click.ClickException(f"Could not load hook file: {env_path}")
83+
module = _load_hook_module(path, "env_hook")
11084

111-
hook = importlib.util.module_from_spec(spec)
112-
spec.loader.exec_module(hook)
113-
114-
if not hasattr(hook, "create_env_factory"):
85+
if not hasattr(module, "create_env_factory"):
11586
raise click.ClickException(
11687
"Hook file must export 'create_env_factory(model_factory, env_config)' function.\n"
11788
"Example:\n"
@@ -125,16 +96,16 @@ def load_env_hook(env_path: Path) -> EnvFactoryCreator:
12596
" return env_factory"
12697
)
12798

128-
return hook.create_env_factory
99+
return module.create_env_factory
129100

130101

131-
def load_evaluator_hook(evaluator_path: Path) -> EvaluatorClass:
102+
def load_evaluator_hook(path: Path) -> EvaluatorClass:
132103
"""Load evaluator hook file and return the Evaluator class.
133104
134105
The hook file must export an `EvaluatorClass` that extends `Evaluator`.
135106
136107
Args:
137-
evaluator_path: Path to the Python hook file.
108+
path: Path to the Python hook file.
138109
139110
Returns:
140111
The Evaluator subclass from the hook file.
@@ -144,14 +115,9 @@ def load_evaluator_hook(evaluator_path: Path) -> EvaluatorClass:
144115
"""
145116
from strands_env.eval import Evaluator
146117

147-
spec = importlib.util.spec_from_file_location("evaluator_hook", evaluator_path)
148-
if spec is None or spec.loader is None:
149-
raise click.ClickException(f"Could not load evaluator hook file: {evaluator_path}")
150-
151-
hook = importlib.util.module_from_spec(spec)
152-
spec.loader.exec_module(hook)
118+
module = _load_hook_module(path, "evaluator_hook")
153119

154-
if not hasattr(hook, "EvaluatorClass"):
120+
if not hasattr(module, "EvaluatorClass"):
155121
raise click.ClickException(
156122
"Evaluator hook file must export 'EvaluatorClass' (an Evaluator subclass).\n"
157123
"Example:\n"
@@ -166,8 +132,161 @@ def load_evaluator_hook(evaluator_path: Path) -> EvaluatorClass:
166132
" EvaluatorClass = MyEvaluator"
167133
)
168134

169-
evaluator_cls = hook.EvaluatorClass
135+
evaluator_cls = module.EvaluatorClass
170136
if not isinstance(evaluator_cls, type) or not issubclass(evaluator_cls, Evaluator):
171137
raise click.ClickException("EvaluatorClass must be a subclass of Evaluator")
172138

173139
return evaluator_cls
140+
141+
142+
def load_tool_parser(tool_parser_arg: str | None) -> ToolParser | None:
143+
"""Load tool parser from name or hook file path.
144+
145+
Args:
146+
tool_parser_arg: Either a parser name (e.g., "hermes", "qwen_xml") or path to hook file.
147+
148+
Returns:
149+
ToolParser instance, or None if not specified.
150+
151+
Raises:
152+
click.ClickException: If the parser name is unknown or hook file is invalid.
153+
"""
154+
if tool_parser_arg is None:
155+
return None
156+
157+
# Check if it's a file path
158+
path = Path(tool_parser_arg)
159+
if path.exists() and path.is_file():
160+
return _load_tool_parser_hook(path)
161+
162+
# Otherwise treat as parser name
163+
from strands_sglang.tool_parsers import get_tool_parser
164+
165+
try:
166+
return get_tool_parser(tool_parser_arg)
167+
except KeyError as e:
168+
raise click.ClickException(str(e))
169+
170+
171+
def _load_tool_parser_hook(path: Path) -> ToolParser:
172+
"""Load tool parser from hook file.
173+
174+
The hook file must export either:
175+
- `tool_parser`: A ToolParser instance
176+
- `ToolParserClass`: A ToolParser subclass (will be instantiated)
177+
178+
Args:
179+
path: Path to the Python hook file.
180+
181+
Returns:
182+
ToolParser instance from the hook file.
183+
184+
Raises:
185+
click.ClickException: If the file cannot be loaded or doesn't export the parser.
186+
"""
187+
from strands_sglang.tool_parsers import ToolParser
188+
189+
module = _load_hook_module(path, "tool_parser_hook")
190+
191+
# Check for tool_parser instance first
192+
if hasattr(module, "tool_parser"):
193+
parser = module.tool_parser
194+
if not isinstance(parser, ToolParser):
195+
raise click.ClickException("'tool_parser' must be a ToolParser instance")
196+
return parser
197+
198+
# Check for ToolParserClass
199+
if hasattr(module, "ToolParserClass"):
200+
parser_cls = module.ToolParserClass
201+
if not isinstance(parser_cls, type) or not issubclass(parser_cls, ToolParser):
202+
raise click.ClickException("'ToolParserClass' must be a ToolParser subclass")
203+
return parser_cls()
204+
205+
raise click.ClickException(
206+
"Tool parser hook file must export 'tool_parser' (instance) or 'ToolParserClass' (subclass).\n"
207+
"Example:\n"
208+
" from strands_sglang.tool_parsers import ToolParser, ToolParseResult\n"
209+
"\n"
210+
" class MyToolParser(ToolParser):\n"
211+
" def parse(self, text: str) -> list[ToolParseResult]:\n"
212+
" ...\n"
213+
"\n"
214+
" tool_parser = MyToolParser()\n"
215+
" # OR\n"
216+
" ToolParserClass = MyToolParser"
217+
)
218+
219+
220+
# ---------------------------------------------------------------------------
221+
# Model Factory
222+
# ---------------------------------------------------------------------------
223+
224+
225+
def build_model_factory(config: ModelConfig, max_concurrency: int) -> ModelFactory:
226+
"""Build a ModelFactory from ModelConfig.
227+
228+
Args:
229+
config: Model configuration.
230+
max_concurrency: Max concurrent connections (for SGLang client pooling).
231+
232+
Returns:
233+
ModelFactory callable.
234+
"""
235+
sampling = config.sampling.to_dict()
236+
237+
if config.backend == "sglang":
238+
return _build_sglang_model_factory(config, max_concurrency, sampling)
239+
elif config.backend == "bedrock":
240+
return _build_bedrock_model_factory(config, sampling)
241+
else:
242+
raise click.ClickException(f"Unknown backend: {config.backend}")
243+
244+
245+
def _build_sglang_model_factory(config: ModelConfig, max_concurrency: int, sampling: dict) -> ModelFactory:
246+
"""Build SGLang model factory."""
247+
from strands_env.utils.sglang import (
248+
check_server_health,
249+
get_cached_client,
250+
get_cached_tokenizer,
251+
get_model_id,
252+
)
253+
254+
# Check server health before proceeding
255+
try:
256+
check_server_health(config.base_url)
257+
except ConnectionError as e:
258+
raise click.ClickException(str(e))
259+
260+
client = get_cached_client(config.base_url, max_concurrency)
261+
262+
# Resolve and backfill model_id/tokenizer_path for reproducibility
263+
if not config.model_id:
264+
config.model_id = get_model_id(config.base_url)
265+
if not config.tokenizer_path:
266+
config.tokenizer_path = config.model_id
267+
268+
tokenizer = get_cached_tokenizer(config.tokenizer_path)
269+
tool_parser = load_tool_parser(config.tool_parser)
270+
271+
return sglang_model_factory(
272+
client=client,
273+
model_id=config.model_id,
274+
tokenizer=tokenizer,
275+
tool_parser=tool_parser,
276+
sampling_params=sampling,
277+
)
278+
279+
280+
def _build_bedrock_model_factory(config: ModelConfig, sampling: dict) -> ModelFactory:
281+
"""Build Bedrock model factory."""
282+
from strands_env.utils.aws import get_assumed_role_session, get_boto3_session
283+
284+
if not config.model_id:
285+
raise click.ClickException("--model-id is required for Bedrock backend")
286+
287+
if config.role_arn:
288+
boto_session = get_assumed_role_session(config.role_arn, config.region)
289+
else:
290+
boto_session = get_boto3_session(config.region, config.profile_name)
291+
292+
return bedrock_model_factory(model_id=config.model_id, boto_session=boto_session, sampling_params=sampling)

src/strands_env/core/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def sglang_model_factory(
6767
model_id: str,
6868
tokenizer: PreTrainedTokenizerBase,
6969
client: SGLangClient,
70-
tool_parser: ToolParser = HermesToolParser(),
70+
tool_parser: ToolParser | None = None,
7171
sampling_params: dict[str, Any] = DEFAULT_SAMPLING_PARAMS,
7272
enable_thinking: bool | None = None,
7373
) -> ModelFactory:
@@ -77,9 +77,13 @@ def sglang_model_factory(
7777
model_id: SGLang model identifier.
7878
tokenizer: HuggingFace tokenizer for chat template and tokenization.
7979
client: `SGLangClient` for HTTP communication with the SGLang server.
80+
tool_parser: Tool parser for extracting tool calls from model output. Defaults to `HermesToolParser`.
8081
sampling_params: Sampling parameters for the model (e.g. `{"max_new_tokens": 4096}`).
8182
enable_thinking: Enable thinking mode for Qwen3 hybrid models.
8283
"""
84+
if tool_parser is None:
85+
tool_parser = HermesToolParser()
86+
8387
return lambda: SGLangModel(
8488
tokenizer=tokenizer,
8589
client=client,

0 commit comments

Comments
 (0)