1818
1919import importlib .util
2020from pathlib import Path
21+ from types import ModuleType
2122from typing import TYPE_CHECKING , Callable
2223
2324import click
2728from .config import EnvConfig , ModelConfig
2829
2930if TYPE_CHECKING :
31+ from strands_sglang .tool_parsers import ToolParser
32+
3033 from strands_env .eval import AsyncEnvFactory , Evaluator
3134
3235#: Type for the create_env_factory function exported by hook files.
3639EvaluatorClass = type ["Evaluator" ]
3740
3841
39- def build_model_factory (config : ModelConfig , max_concurrency : int ) -> ModelFactory :
40- """Build a ModelFactory from ModelConfig.
42+ # ---------------------------------------------------------------------------
43+ # Hook Loading
44+ # ---------------------------------------------------------------------------
4145
42- Args:
43- config: Model configuration.
44- max_concurrency: Max concurrent connections (for SGLang client pooling).
4546
46- Returns:
47- ModelFactory callable.
48- """
49- sampling = config .sampling .to_dict ()
47+ def _load_hook_module (path : Path , hook_name : str ) -> ModuleType :
48+ """Load a Python module from a file path.
5049
51- if config .backend == "sglang" :
52- from strands_env .utils .sglang import (
53- check_server_health ,
54- get_cached_client ,
55- get_cached_tokenizer ,
56- get_model_id ,
57- )
58-
59- # Check server health before proceeding
60- try :
61- check_server_health (config .base_url )
62- except ConnectionError as e :
63- raise click .ClickException (str (e ))
64-
65- client = get_cached_client (config .base_url , max_concurrency )
66-
67- # Resolve and backfill model_id/tokenizer_path for reproducibility
68- if not config .model_id :
69- config .model_id = get_model_id (config .base_url )
70- if not config .tokenizer_path :
71- config .tokenizer_path = config .model_id
72-
73- tokenizer = get_cached_tokenizer (config .tokenizer_path )
74- return sglang_model_factory (
75- client = client , model_id = config .model_id , tokenizer = tokenizer , sampling_params = sampling
76- )
50+ Args:
51+ path: Path to the Python file.
52+ hook_name: Name for the module (used in error messages).
7753
78- elif config . backend == "bedrock" :
79- from strands_env . utils . aws import get_assumed_role_session , get_boto3_session
54+ Returns :
55+ The loaded module.
8056
81- if not config .model_id :
82- raise click .ClickException ("--model-id is required for Bedrock backend" )
83- if config .role_arn :
84- boto_session = get_assumed_role_session (config .role_arn , config .region )
85- else :
86- boto_session = get_boto3_session (config .region , config .profile_name )
87- return bedrock_model_factory (model_id = config .model_id , boto_session = boto_session , sampling_params = sampling )
57+ Raises:
58+ click.ClickException: If the file cannot be loaded.
59+ """
60+ spec = importlib .util .spec_from_file_location (hook_name , path )
61+ if spec is None or spec .loader is None :
62+ raise click .ClickException (f"Could not load { hook_name } file: { path } " )
8863
89- else :
90- raise click .ClickException (f"Unknown backend: { config .backend } " )
64+ module = importlib .util .module_from_spec (spec )
65+ spec .loader .exec_module (module )
66+ return module
9167
9268
93- def load_env_hook (env_path : Path ) -> EnvFactoryCreator :
69+ def load_env_hook (path : Path ) -> EnvFactoryCreator :
9470 """Load environment hook file and return create_env_factory function.
9571
9672 The hook file must export a `create_env_factory(model_factory, env_config)` function.
9773
9874 Args:
99- env_path : Path to the Python hook file.
75+ path : Path to the Python hook file.
10076
10177 Returns:
10278 The create_env_factory function from the hook file.
10379
10480 Raises:
10581 click.ClickException: If the file cannot be loaded or doesn't export the function.
10682 """
107- spec = importlib .util .spec_from_file_location ("env_hook" , env_path )
108- if spec is None or spec .loader is None :
109- raise click .ClickException (f"Could not load hook file: { env_path } " )
83+ module = _load_hook_module (path , "env_hook" )
11084
111- hook = importlib .util .module_from_spec (spec )
112- spec .loader .exec_module (hook )
113-
114- if not hasattr (hook , "create_env_factory" ):
85+ if not hasattr (module , "create_env_factory" ):
11586 raise click .ClickException (
11687 "Hook file must export 'create_env_factory(model_factory, env_config)' function.\n "
11788 "Example:\n "
@@ -125,16 +96,16 @@ def load_env_hook(env_path: Path) -> EnvFactoryCreator:
12596 " return env_factory"
12697 )
12798
128- return hook .create_env_factory
99+ return module .create_env_factory
129100
130101
131- def load_evaluator_hook (evaluator_path : Path ) -> EvaluatorClass :
102+ def load_evaluator_hook (path : Path ) -> EvaluatorClass :
132103 """Load evaluator hook file and return the Evaluator class.
133104
134105 The hook file must export an `EvaluatorClass` that extends `Evaluator`.
135106
136107 Args:
137- evaluator_path : Path to the Python hook file.
108+ path : Path to the Python hook file.
138109
139110 Returns:
140111 The Evaluator subclass from the hook file.
@@ -144,14 +115,9 @@ def load_evaluator_hook(evaluator_path: Path) -> EvaluatorClass:
144115 """
145116 from strands_env .eval import Evaluator
146117
147- spec = importlib .util .spec_from_file_location ("evaluator_hook" , evaluator_path )
148- if spec is None or spec .loader is None :
149- raise click .ClickException (f"Could not load evaluator hook file: { evaluator_path } " )
150-
151- hook = importlib .util .module_from_spec (spec )
152- spec .loader .exec_module (hook )
118+ module = _load_hook_module (path , "evaluator_hook" )
153119
154- if not hasattr (hook , "EvaluatorClass" ):
120+ if not hasattr (module , "EvaluatorClass" ):
155121 raise click .ClickException (
156122 "Evaluator hook file must export 'EvaluatorClass' (an Evaluator subclass).\n "
157123 "Example:\n "
@@ -166,8 +132,161 @@ def load_evaluator_hook(evaluator_path: Path) -> EvaluatorClass:
166132 " EvaluatorClass = MyEvaluator"
167133 )
168134
169- evaluator_cls = hook .EvaluatorClass
135+ evaluator_cls = module .EvaluatorClass
170136 if not isinstance (evaluator_cls , type ) or not issubclass (evaluator_cls , Evaluator ):
171137 raise click .ClickException ("EvaluatorClass must be a subclass of Evaluator" )
172138
173139 return evaluator_cls
140+
141+
142+ def load_tool_parser (tool_parser_arg : str | None ) -> ToolParser | None :
143+ """Load tool parser from name or hook file path.
144+
145+ Args:
146+ tool_parser_arg: Either a parser name (e.g., "hermes", "qwen_xml") or path to hook file.
147+
148+ Returns:
149+ ToolParser instance, or None if not specified.
150+
151+ Raises:
152+ click.ClickException: If the parser name is unknown or hook file is invalid.
153+ """
154+ if tool_parser_arg is None :
155+ return None
156+
157+ # Check if it's a file path
158+ path = Path (tool_parser_arg )
159+ if path .exists () and path .is_file ():
160+ return _load_tool_parser_hook (path )
161+
162+ # Otherwise treat as parser name
163+ from strands_sglang .tool_parsers import get_tool_parser
164+
165+ try :
166+ return get_tool_parser (tool_parser_arg )
167+ except KeyError as e :
168+ raise click .ClickException (str (e ))
169+
170+
171+ def _load_tool_parser_hook (path : Path ) -> ToolParser :
172+ """Load tool parser from hook file.
173+
174+ The hook file must export either:
175+ - `tool_parser`: A ToolParser instance
176+ - `ToolParserClass`: A ToolParser subclass (will be instantiated)
177+
178+ Args:
179+ path: Path to the Python hook file.
180+
181+ Returns:
182+ ToolParser instance from the hook file.
183+
184+ Raises:
185+ click.ClickException: If the file cannot be loaded or doesn't export the parser.
186+ """
187+ from strands_sglang .tool_parsers import ToolParser
188+
189+ module = _load_hook_module (path , "tool_parser_hook" )
190+
191+ # Check for tool_parser instance first
192+ if hasattr (module , "tool_parser" ):
193+ parser = module .tool_parser
194+ if not isinstance (parser , ToolParser ):
195+ raise click .ClickException ("'tool_parser' must be a ToolParser instance" )
196+ return parser
197+
198+ # Check for ToolParserClass
199+ if hasattr (module , "ToolParserClass" ):
200+ parser_cls = module .ToolParserClass
201+ if not isinstance (parser_cls , type ) or not issubclass (parser_cls , ToolParser ):
202+ raise click .ClickException ("'ToolParserClass' must be a ToolParser subclass" )
203+ return parser_cls ()
204+
205+ raise click .ClickException (
206+ "Tool parser hook file must export 'tool_parser' (instance) or 'ToolParserClass' (subclass).\n "
207+ "Example:\n "
208+ " from strands_sglang.tool_parsers import ToolParser, ToolParseResult\n "
209+ "\n "
210+ " class MyToolParser(ToolParser):\n "
211+ " def parse(self, text: str) -> list[ToolParseResult]:\n "
212+ " ...\n "
213+ "\n "
214+ " tool_parser = MyToolParser()\n "
215+ " # OR\n "
216+ " ToolParserClass = MyToolParser"
217+ )
218+
219+
220+ # ---------------------------------------------------------------------------
221+ # Model Factory
222+ # ---------------------------------------------------------------------------
223+
224+
225+ def build_model_factory (config : ModelConfig , max_concurrency : int ) -> ModelFactory :
226+ """Build a ModelFactory from ModelConfig.
227+
228+ Args:
229+ config: Model configuration.
230+ max_concurrency: Max concurrent connections (for SGLang client pooling).
231+
232+ Returns:
233+ ModelFactory callable.
234+ """
235+ sampling = config .sampling .to_dict ()
236+
237+ if config .backend == "sglang" :
238+ return _build_sglang_model_factory (config , max_concurrency , sampling )
239+ elif config .backend == "bedrock" :
240+ return _build_bedrock_model_factory (config , sampling )
241+ else :
242+ raise click .ClickException (f"Unknown backend: { config .backend } " )
243+
244+
245+ def _build_sglang_model_factory (config : ModelConfig , max_concurrency : int , sampling : dict ) -> ModelFactory :
246+ """Build SGLang model factory."""
247+ from strands_env .utils .sglang import (
248+ check_server_health ,
249+ get_cached_client ,
250+ get_cached_tokenizer ,
251+ get_model_id ,
252+ )
253+
254+ # Check server health before proceeding
255+ try :
256+ check_server_health (config .base_url )
257+ except ConnectionError as e :
258+ raise click .ClickException (str (e ))
259+
260+ client = get_cached_client (config .base_url , max_concurrency )
261+
262+ # Resolve and backfill model_id/tokenizer_path for reproducibility
263+ if not config .model_id :
264+ config .model_id = get_model_id (config .base_url )
265+ if not config .tokenizer_path :
266+ config .tokenizer_path = config .model_id
267+
268+ tokenizer = get_cached_tokenizer (config .tokenizer_path )
269+ tool_parser = load_tool_parser (config .tool_parser )
270+
271+ return sglang_model_factory (
272+ client = client ,
273+ model_id = config .model_id ,
274+ tokenizer = tokenizer ,
275+ tool_parser = tool_parser ,
276+ sampling_params = sampling ,
277+ )
278+
279+
280+ def _build_bedrock_model_factory (config : ModelConfig , sampling : dict ) -> ModelFactory :
281+ """Build Bedrock model factory."""
282+ from strands_env .utils .aws import get_assumed_role_session , get_boto3_session
283+
284+ if not config .model_id :
285+ raise click .ClickException ("--model-id is required for Bedrock backend" )
286+
287+ if config .role_arn :
288+ boto_session = get_assumed_role_session (config .role_arn , config .region )
289+ else :
290+ boto_session = get_boto3_session (config .region , config .profile_name )
291+
292+ return bedrock_model_factory (model_id = config .model_id , boto_session = boto_session , sampling_params = sampling )
0 commit comments