|
22 | 22 | RetryPromptPart, |
23 | 23 | ) |
24 | 24 | from pydantic_ai.models import Model |
25 | | -from pydantic_ai.toolsets import FunctionToolset, ToolsetTool |
| 25 | +from pydantic_ai.settings import ModelSettings |
| 26 | +from pydantic_ai.toolsets import FunctionToolset |
26 | 27 |
|
27 | 28 | from haiku.skills.models import Skill |
28 | 29 | from haiku.skills.prompts import SKILL_PROMPT |
@@ -139,10 +140,22 @@ def _discover_scripts(skill: Skill) -> list[str]: |
139 | 140 | ) |
140 | 141 |
|
141 | 142 |
|
142 | | -def _create_run_script(skill: Skill) -> Callable[..., Any]: |
| 143 | +SCRIPT_TIMEOUT_DEFAULT = 120.0 |
| 144 | + |
| 145 | + |
| 146 | +def _create_run_script( |
| 147 | + skill: Skill, timeout: float | None = None |
| 148 | +) -> Callable[..., Any]: |
143 | 149 | """Create a run_script tool bound to a specific skill.""" |
144 | 150 | assert skill.path is not None |
145 | 151 | scripts_dir = (skill.path / "scripts").resolve() |
| 152 | + resolved_timeout = ( |
| 153 | + timeout |
| 154 | + if timeout is not None |
| 155 | + else float( |
| 156 | + os.environ.get("HAIKU_SKILLS_SCRIPT_TIMEOUT", SCRIPT_TIMEOUT_DEFAULT) |
| 157 | + ) |
| 158 | + ) |
146 | 159 |
|
147 | 160 | async def run_script(script: str, arguments: str = "") -> str: |
148 | 161 | """Execute a script from the skill's scripts/ directory. |
@@ -171,7 +184,14 @@ async def run_script(script: str, arguments: str = "") -> str: |
171 | 184 | stdout=asyncio.subprocess.PIPE, |
172 | 185 | stderr=asyncio.subprocess.PIPE, |
173 | 186 | ) |
174 | | - stdout, stderr = await proc.communicate() |
| 187 | + try: |
| 188 | + stdout, stderr = await asyncio.wait_for( |
| 189 | + proc.communicate(), timeout=resolved_timeout |
| 190 | + ) |
| 191 | + except TimeoutError: |
| 192 | + proc.kill() |
| 193 | + await proc.wait() |
| 194 | + raise RuntimeError(f"Script {script} timed out after {resolved_timeout}s") |
175 | 195 | if proc.returncode != 0: |
176 | 196 | output = stderr.decode().strip() or stdout.decode().strip() |
177 | 197 | raise RuntimeError( |
@@ -249,11 +269,15 @@ async def event_handler( |
249 | 269 | tools=tools, |
250 | 270 | toolsets=skill.toolsets or None, |
251 | 271 | ) |
| 272 | + model_settings = ( |
| 273 | + ModelSettings(thinking=skill.thinking) if skill.thinking is not None else None |
| 274 | + ) |
252 | 275 | result = await agent.run( |
253 | 276 | request, |
254 | 277 | deps=deps, |
255 | 278 | usage_limits=UsageLimits(request_limit=20), |
256 | 279 | event_stream_handler=event_handler, |
| 280 | + model_settings=model_settings, |
257 | 281 | ) |
258 | 282 | text = result.output |
259 | 283 | return text, collected_events, emitted_events |
@@ -308,30 +332,27 @@ def _register_skill_state(self, skill: Skill) -> None: |
308 | 332 | else: |
309 | 333 | self._namespaces[namespace] = skill.state_type() |
310 | 334 |
|
311 | | - async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: |
312 | | - # Overridden to restore AG-UI state from deps before returning tools. |
313 | | - # get_tools() is the only per-run hook with RunContext access in the |
314 | | - # toolset API — there is no dedicated per-run setup method. |
315 | | - self._maybe_restore_state(ctx) |
316 | | - return await super().get_tools(ctx) |
317 | | - |
318 | | - def _maybe_restore_state(self, ctx: RunContext[Any]) -> None: |
319 | | - """Restore namespace state from deps if it carries AG-UI state. |
| 335 | + async def for_run(self, ctx: RunContext[Any]) -> "SkillToolset": |
| 336 | + """Restore AG-UI state from deps before the run starts. |
320 | 337 |
|
321 | 338 | Uses identity check (``is``) so we restore once per AG-UI request |
322 | | - (each request creates a new dict) but not on every model step within |
323 | | - a single run. |
| 339 | + (each request creates a new dict) but not redundantly within a run. |
324 | 340 | """ |
325 | 341 | deps = ctx.deps |
326 | | - if deps is None or not hasattr(deps, "state"): |
327 | | - return |
328 | | - state = deps.state |
329 | | - if not isinstance(state, dict) or not state: |
330 | | - return |
331 | | - if state is self._last_restored_state: |
332 | | - return |
333 | | - self._last_restored_state = state |
334 | | - self.restore_state_snapshot(state) |
| 342 | + if deps is not None and hasattr(deps, "state"): |
| 343 | + state = deps.state |
| 344 | + if ( |
| 345 | + isinstance(state, dict) |
| 346 | + and state |
| 347 | + and state is not self._last_restored_state |
| 348 | + ): |
| 349 | + self._last_restored_state = state |
| 350 | + self.restore_state_snapshot(state) |
| 351 | + return self |
| 352 | + |
| 353 | + @property |
| 354 | + def use_subagents(self) -> bool: |
| 355 | + return self._use_subagents |
335 | 356 |
|
336 | 357 | @property |
337 | 358 | def registry(self) -> SkillRegistry: |
|
0 commit comments