diff --git a/strands-py/strands/_wasm_host.py b/strands-py/strands/_wasm_host.py index 696b685b5..ac6448044 100644 --- a/strands-py/strands/_wasm_host.py +++ b/strands-py/strands/_wasm_host.py @@ -101,11 +101,13 @@ def _run_sync(coro: typing.Coroutine) -> typing.Any: # Already inside a running loop — run in a fresh thread to avoid nesting result = [None] exc = [None] + def _target(): try: result[0] = asyncio.run(coro) except Exception as e: exc[0] = e + t = threading.Thread(target=_target) t.start() t.join() @@ -169,6 +171,7 @@ def _get_engine_and_component() -> tuple[Engine, Component]: # Record / Variant builders (Python → WIT kebab-case) # --------------------------------------------------------------------------- + def _rec(**kwargs: typing.Any) -> Record: """Build a wasmtime-py Record with the given kebab-case fields.""" r = Record.__new__(Record) @@ -178,7 +181,9 @@ def _rec(**kwargs: typing.Any) -> Record: def _build_tool_spec(ts: ToolSpec) -> Record: - return _rec(name=ts.name, description=ts.description, **{"input-schema": ts.input_schema}) + return _rec( + name=ts.name, description=ts.description, **{"input-schema": ts.input_schema} + ) def _build_model_config_variant(cfg: ModelConfigInput) -> Variant: @@ -266,7 +271,9 @@ def _build_conversation_manager_variant( "should-truncate-results": False, "summary-ratio": config.get("summary_ratio"), "preserve-recent-messages": config.get("preserve_recent_messages"), - "summarization-system-prompt": config.get("summarization_system_prompt"), + "summarization-system-prompt": config.get( + "summarization_system_prompt" + ), "summarization-model-config": config.get("summarization_model_config"), }, ) @@ -319,6 +326,7 @@ def _build_stream_args( # Variant → flat StreamEvent converters (WIT → Python types) # --------------------------------------------------------------------------- + def _opt_attr(rec: typing.Any, name: str) -> typing.Any: """Read an optional attribute from a wasmtime Record (kebab-case).""" return getattr(rec, name, None) if rec is not None else None @@ -413,6 +421,7 @@ def _convert_stream_event(v: Variant) -> StreamEvent: # AWS credential injection # --------------------------------------------------------------------------- + def _resolve_aws_credentials() -> tuple[str, str, str | None] | None: key_id = os.environ.get("AWS_ACCESS_KEY_ID") secret = os.environ.get("AWS_SECRET_ACCESS_KEY") @@ -485,7 +494,10 @@ def _inject_aws_credentials_default() -> Variant | None: # Import callback factories # --------------------------------------------------------------------------- -def _make_call_tool_fn(dispatcher: ToolDispatcherBase | None) -> typing.Callable[..., typing.Any]: + +def _make_call_tool_fn( + dispatcher: ToolDispatcherBase | None, +) -> typing.Callable[..., typing.Any]: def call_tool(store_ctx: typing.Any, args: typing.Any) -> Variant: name = getattr(args, "name") input_json = getattr(args, "input") @@ -497,10 +509,13 @@ def call_tool(store_ctx: typing.Any, args: typing.Any) -> Variant: return Variant("ok", result) except Exception as exc: return Variant("err", str(exc)) + return call_tool -def _make_call_tools_fn(dispatcher: ToolDispatcherBase | None) -> typing.Callable[..., typing.Any]: +def _make_call_tools_fn( + dispatcher: ToolDispatcherBase | None, +) -> typing.Callable[..., typing.Any]: def call_tools(store_ctx: typing.Any, args: typing.Any) -> list[Variant]: calls = getattr(args, "calls") results: list[Variant] = [] @@ -517,6 +532,7 @@ def call_tools(store_ctx: typing.Any, args: typing.Any) -> list[Variant]: except Exception as exc: results.append(Variant("err", str(exc))) return results + return call_tools @@ -529,18 +545,95 @@ def log_fn(store_ctx: typing.Any, entry: typing.Any) -> None: handler.log(level, message, context) else: logger = logging.getLogger("strands.wasm") - py_level = {"error": 40, "warn": 30, "info": 20, "debug": 10, "trace": 10}.get( - level, 20 - ) + py_level = { + "error": 40, + "warn": 30, + "info": 20, + "debug": 10, + "trace": 10, + }.get(level, 20) msg = f"{message} | {context}" if context else message logger.log(py_level, msg) + return log_fn +# --------------------------------------------------------------------------- +# Hook-provider no-op callbacks (return zero-value decisions) +# --------------------------------------------------------------------------- + + +def _hook_get_capabilities(store_ctx: typing.Any) -> list[str]: + return [] + + +def _cancel_decision() -> Record: + """Default no-op cancellable decision.""" + return _rec(cancel=False, **{"cancel-message": None}) + + +def _hook_before_invocation(store_ctx: typing.Any) -> Record: + return _cancel_decision() + + +def _hook_after_invocation(store_ctx: typing.Any) -> Record: + return _rec(resume=None) + + +def _hook_before_model_call( + store_ctx: typing.Any, projected_input_tokens: typing.Optional[int] +) -> Record: + return _cancel_decision() + + +def _hook_after_model_call( + store_ctx: typing.Any, + stop_reason: typing.Optional[str], + stop_data: typing.Optional[str], + error: typing.Optional[str], +) -> Record: + return _rec(retry=False) + + +def _hook_before_tools(store_ctx: typing.Any, message: str) -> Record: + return _cancel_decision() + + +def _hook_after_tools(store_ctx: typing.Any) -> Record: + return _rec() + + +def _hook_before_tool_call(store_ctx: typing.Any, tool_use: str) -> Record: + return _rec( + cancel=False, + **{"cancel-message": None, "tool-use": None, "selected-tool-name": None}, + ) + + +def _hook_after_tool_call( + store_ctx: typing.Any, tool_use: str, result: str, error: typing.Optional[str] +) -> Record: + return _rec(retry=False, result=None) + + +_HOOK_PROVIDER_FUNCS: list[tuple[str, typing.Callable]] = [ + ("get-capabilities", _hook_get_capabilities), + ("before-invocation", _hook_before_invocation), + ("after-invocation", _hook_after_invocation), + ("before-model-call", _hook_before_model_call), + ("after-model-call", _hook_after_model_call), + ("before-tools", _hook_before_tools), + ("after-tools", _hook_after_tools), + ("before-tool-call", _hook_before_tool_call), + ("after-tool-call", _hook_after_tool_call), +] + + # --------------------------------------------------------------------------- # WasmAgent — drop-in replacement for the former native Agent class # --------------------------------------------------------------------------- + class WasmAgent: """WASM-hosted agent with the same API as the former native ``Agent``.""" @@ -568,6 +661,9 @@ def __init__( tp.add_func("call-tools", _make_call_tools_fn(tool_dispatcher)) with root.add_instance("strands:agent/host-log") as hl: hl.add_func("log", _make_log_fn(log_handler)) + with root.add_instance("strands:agent/hook-provider") as hp: + for name, fn in _HOOK_PROVIDER_FUNCS: + hp.add_func(name, fn) # --- store --- store = Store(engine) @@ -584,7 +680,13 @@ def __init__( self._component = component # --- instantiate + construct agent (async, run synchronously) --- - agent_config = _build_agent_config(model, system_prompt, system_prompt_blocks, tools, conversation_manager_config) + agent_config = _build_agent_config( + model, + system_prompt, + system_prompt_blocks, + tools, + conversation_manager_config, + ) _run_sync(self._init_async(linker, store, component, agent_config)) async def _init_async( @@ -622,9 +724,7 @@ def _fn(name: str) -> Func: async def start_stream(self, input_text: str) -> typing.Any: args = _build_stream_args(input_text, None, None) - return await self._generate_fn.call_async( - self._store, self._agent_handle, args - ) + return await self._generate_fn.call_async(self._store, self._agent_handle, args) async def start_stream_with_options( self, @@ -633,13 +733,9 @@ async def start_stream_with_options( tool_choice: str | None, ) -> typing.Any: args = _build_stream_args(input_text, tools, tool_choice) - return await self._generate_fn.call_async( - self._store, self._agent_handle, args - ) + return await self._generate_fn.call_async(self._store, self._agent_handle, args) - async def next_events( - self, stream_handle: typing.Any - ) -> list[StreamEvent] | None: + async def next_events(self, stream_handle: typing.Any) -> list[StreamEvent] | None: raw = await self._read_next_fn.call_async(self._store, stream_handle) if raw is None: return None diff --git a/strands-py/tests/__init__.py b/strands-py/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/strands-py/tests/test_hook_provider_stubs.py b/strands-py/tests/test_hook_provider_stubs.py new file mode 100644 index 000000000..65ae43852 --- /dev/null +++ b/strands-py/tests/test_hook_provider_stubs.py @@ -0,0 +1,113 @@ +"""Unit tests for hook-provider no-op stubs in _wasm_host.py.""" + +from strands._wasm_host import ( + _cancel_decision, + _hook_get_capabilities, + _hook_before_invocation, + _hook_after_invocation, + _hook_before_model_call, + _hook_after_model_call, + _hook_before_tools, + _hook_after_tools, + _hook_before_tool_call, + _hook_after_tool_call, + _HOOK_PROVIDER_FUNCS, +) + + +class TestCancelDecision: + def test_produces_correct_record_shape(self): + result = _cancel_decision() + assert result.cancel is False + assert getattr(result, "cancel-message") is None + + +class TestHookBeforeToolCall: + def test_returns_all_four_decision_fields(self): + result = _hook_before_tool_call(None, "{}") + assert result.cancel is False + assert getattr(result, "cancel-message") is None + assert getattr(result, "tool-use") is None + assert getattr(result, "selected-tool-name") is None + + +class TestHookAfterToolCall: + def test_returns_retry_and_result_fields(self): + result = _hook_after_tool_call(None, "{}", "{}", None) + assert result.retry is False + assert result.result is None + + +class TestHookAfterModelCall: + def test_accepts_all_none_params(self): + result = _hook_after_model_call(None, None, None, None) + assert result.retry is False + + def test_accepts_string_params(self): + result = _hook_after_model_call( + None, "endTurn", '{"reason":"end-turn"}', "something failed" + ) + assert result.retry is False + + +class TestHookProviderFuncsList: + def test_completeness(self): + assert len(_HOOK_PROVIDER_FUNCS) == 9 + + def test_names(self): + names = [name for name, _ in _HOOK_PROVIDER_FUNCS] + assert names == [ + "get-capabilities", + "before-invocation", + "after-invocation", + "before-model-call", + "after-model-call", + "before-tools", + "after-tools", + "before-tool-call", + "after-tool-call", + ] + + +class TestHookGetCapabilities: + def test_returns_empty_list(self): + result = _hook_get_capabilities(None) + assert result == [] + + +class TestHookBeforeInvocation: + def test_returns_cancel_decision(self): + result = _hook_before_invocation(None) + assert result.cancel is False + assert getattr(result, "cancel-message") is None + + +class TestHookAfterInvocation: + def test_returns_resume_none(self): + result = _hook_after_invocation(None) + assert result.resume is None + + +class TestHookBeforeModelCall: + def test_returns_cancel_decision_with_none(self): + result = _hook_before_model_call(None, None) + assert result.cancel is False + assert getattr(result, "cancel-message") is None + + def test_returns_cancel_decision_with_int(self): + result = _hook_before_model_call(None, 42) + assert result.cancel is False + assert getattr(result, "cancel-message") is None + + +class TestHookBeforeTools: + def test_returns_cancel_decision(self): + result = _hook_before_tools(None, '{"role":"assistant","content":[]}') + assert result.cancel is False + assert getattr(result, "cancel-message") is None + + +class TestHookAfterTools: + def test_returns_empty_record(self): + result = _hook_after_tools(None) + assert result.__dict__ == {} diff --git a/strands-wasm/__fixtures__/hook-provider.ts b/strands-wasm/__fixtures__/hook-provider.ts new file mode 100644 index 000000000..a2b1dcd0c --- /dev/null +++ b/strands-wasm/__fixtures__/hook-provider.ts @@ -0,0 +1,28 @@ +import { vi } from 'vitest' +import { FunctionTool, type FunctionToolCallback } from '@strands-agents/sdk' +export const getCapabilities = vi.fn().mockReturnValue([]) +export const beforeInvocation = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined }) +export const afterInvocation = vi.fn().mockReturnValue({ resume: undefined }) +export const beforeModelCall = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined }) +export const afterModelCall = vi.fn().mockReturnValue({ retry: false }) +export const beforeTools = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined }) +export const afterTools = vi.fn().mockReturnValue({}) +export const beforeToolCall = vi.fn().mockReturnValue({ + cancel: false, + cancelMessage: undefined, + toolUse: undefined, + selectedToolName: undefined, +}) +export const afterToolCall = vi.fn().mockReturnValue({ retry: false, result: undefined }) + +/** Factory for test tools with sensible defaults. */ +export function testTool( + overrides: { name?: string; description?: string; callback?: FunctionToolCallback } = {} +): FunctionTool { + return new FunctionTool({ + name: overrides.name ?? 'test_tool', + description: overrides.description ?? 'test', + inputSchema: { type: 'object', properties: {} }, + callback: overrides.callback ?? (() => [{ text: 'ok' }]), + }) +} diff --git a/strands-wasm/__tests__/hook-provider-bridge.test.ts b/strands-wasm/__tests__/hook-provider-bridge.test.ts new file mode 100644 index 000000000..44fa6f2d1 --- /dev/null +++ b/strands-wasm/__tests__/hook-provider-bridge.test.ts @@ -0,0 +1,642 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest' +import { HookProviderBridge } from '../entry' +import { Agent, ToolResultBlock, TextBlock } from '@strands-agents/sdk' +import { MockMessageModel } from '$/fixtures/mock-message-model' +import * as hookProvider from '../__fixtures__/hook-provider' + +describe('HookProviderBridge', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('capability negotiation', () => { + it('registers no hooks when capabilities list is empty', async () => { + hookProvider.getCapabilities.mockReturnValue([]) + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'hi' }) + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('hello') + + expect(hookProvider.beforeInvocation).not.toHaveBeenCalled() + expect(hookProvider.afterInvocation).not.toHaveBeenCalled() + expect(hookProvider.beforeModelCall).not.toHaveBeenCalled() + expect(hookProvider.afterModelCall).not.toHaveBeenCalled() + }) + + it('registers only declared capabilities', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-model-call', 'after-model-call']) + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'hi' }) + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('hello') + + expect(hookProvider.beforeModelCall).toHaveBeenCalled() + expect(hookProvider.afterModelCall).toHaveBeenCalled() + expect(hookProvider.beforeInvocation).not.toHaveBeenCalled() + expect(hookProvider.afterInvocation).not.toHaveBeenCalled() + }) + + it('calls after-tools hook when capability is declared and tools run', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-tools']) + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + const tool = hookProvider.testTool() + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(hookProvider.afterTools).toHaveBeenCalled() + }) + }) + + describe('cancel decisions', () => { + it('cancels invocation when before-invocation returns cancel=true', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-invocation']) + hookProvider.beforeInvocation.mockReturnValue({ cancel: true, cancelMessage: 'stopped by host' }) + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'hi' }) + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + const result = await agent.invoke('hello') + + expect(hookProvider.beforeInvocation).toHaveBeenCalled() + expect(result.stopReason).toBe('endTurn') + }) + + it('cancels model call when before-model-call returns cancel=true', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-model-call']) + hookProvider.beforeModelCall.mockReturnValue({ cancel: true, cancelMessage: undefined }) + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'hi' }) + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + const result = await agent.invoke('hello') + + expect(hookProvider.beforeModelCall).toHaveBeenCalled() + expect(result.stopReason).toBe('endTurn') + }) + }) + + describe('JSON parse safety', () => { + it('handles invalid JSON in resume field gracefully', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-invocation']) + hookProvider.afterInvocation.mockReturnValue({ resume: 'not-valid-json{' }) + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'hi' }) + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + + // Should not throw — invalid JSON is logged and ignored + await expect(agent.invoke('hello')).resolves.toBeDefined() + }) + + it('handles invalid JSON in toolUse field gracefully', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tool-call']) + hookProvider.beforeToolCall.mockReturnValue({ + cancel: false, + cancelMessage: undefined, + toolUse: 'invalid{json', + selectedToolName: undefined, + }) + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + let receivedInput: unknown + const tool = hookProvider.testTool({ + callback: (input) => { + receivedInput = input + return [{ text: 'ok' }] + }, + }) + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + + await expect(agent.invoke('use tool')).resolves.toBeDefined() + expect(receivedInput).toStrictEqual({}) + }) + + it('handles invalid JSON in result field gracefully', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-tool-call']) + hookProvider.afterToolCall.mockReturnValue({ retry: false, result: 'bad-json{{' }) + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + const tool = hookProvider.testTool() + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + + await expect(agent.invoke('use tool')).resolves.toBeDefined() + const toolResult = agent.messages + .flatMap((m) => m.content) + .find((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1') + expect(toolResult).toBeDefined() + }) + }) + + describe('toolUse rewrite', () => { + it('tool receives rewritten input from before-tool-call decision', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tool-call']) + hookProvider.beforeToolCall.mockReturnValue({ + cancel: false, + cancelMessage: undefined, + toolUse: JSON.stringify({ name: 'test_tool', toolUseId: 'tu-1', input: { rewritten: true } }), + selectedToolName: undefined, + }) + + let receivedInput: unknown + const tool = hookProvider.testTool({ + callback: (input) => { + receivedInput = input + return [{ text: 'ok' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: { original: true } }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(receivedInput).toStrictEqual({ rewritten: true }) + }) + + it('partial rewrite preserves fields not present in host JSON', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tool-call']) + hookProvider.beforeToolCall.mockReturnValue({ + cancel: false, + cancelMessage: undefined, + toolUse: JSON.stringify({ input: { injected: 'context' } }), + selectedToolName: undefined, + }) + + let receivedInput: unknown + const tool = hookProvider.testTool({ + callback: (input) => { + receivedInput = input + return [{ text: 'ok' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-original', input: { original: true } }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(receivedInput).toStrictEqual({ injected: 'context' }) + const toolUseMessage = agent.messages.find((m) => m.content.some((b) => b.type === 'toolUseBlock')) + const toolUseBlock = toolUseMessage!.content.find((b) => b.type === 'toolUseBlock')! + expect(toolUseBlock.name).toBe('test_tool') + expect(toolUseBlock.toolUseId).toBe('tu-original') + }) + + it('empty object rewrite preserves all original fields', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tool-call']) + hookProvider.beforeToolCall.mockReturnValue({ + cancel: false, + cancelMessage: undefined, + toolUse: JSON.stringify({}), + selectedToolName: undefined, + }) + + let receivedInput: unknown + const tool = hookProvider.testTool({ + callback: (input) => { + receivedInput = input + return [{ text: 'ok' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-keep', input: { keep: 'me' } }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(receivedInput).toStrictEqual({ keep: 'me' }) + const toolUseMessage = agent.messages.find((m) => m.content.some((b) => b.type === 'toolUseBlock')) + const toolUseBlock = toolUseMessage!.content.find((b) => b.type === 'toolUseBlock')! + expect(toolUseBlock.name).toBe('test_tool') + expect(toolUseBlock.toolUseId).toBe('tu-keep') + }) + }) + + describe('selectedToolName', () => { + it('executes the replacement tool from registry', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tool-call']) + hookProvider.beforeToolCall.mockReturnValue({ + cancel: false, + cancelMessage: undefined, + toolUse: undefined, + selectedToolName: 'replacement_tool', + }) + + let originalExecuted = false + let replacementExecuted = false + + const originalTool = hookProvider.testTool({ + name: 'original_tool', + description: 'original', + callback: () => { + originalExecuted = true + return [{ text: 'original' }] + }, + }) + const replacementTool = hookProvider.testTool({ + name: 'replacement_tool', + description: 'replacement', + callback: () => { + replacementExecuted = true + return [{ text: 'replaced' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'original_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [originalTool, replacementTool], printer: false }) + await agent.invoke('use tool') + + expect(originalExecuted).toBe(false) + expect(replacementExecuted).toBe(true) + }) + + it('selectedToolName combined with toolUse rewrite applies both', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tool-call']) + hookProvider.beforeToolCall.mockReturnValue({ + cancel: false, + cancelMessage: undefined, + toolUse: JSON.stringify({ input: { redirected: true } }), + selectedToolName: 'replacement_tool', + }) + + let replacementInput: unknown + const originalTool = hookProvider.testTool({ + name: 'original_tool', + description: 'original', + callback: () => [{ text: 'original' }], + }) + const replacementTool = hookProvider.testTool({ + name: 'replacement_tool', + description: 'replacement', + callback: (input) => { + replacementInput = input + return [{ text: 'replaced' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'original_tool', toolUseId: 'tu-1', input: { original: true } }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [originalTool, replacementTool], printer: false }) + await agent.invoke('use tool') + + expect(replacementInput).toStrictEqual({ redirected: true }) + }) + }) + + describe('result replacement', () => { + it('replaces tool result in conversation history via after-tool-call', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-tool-call']) + hookProvider.afterToolCall.mockReturnValue({ + retry: false, + result: JSON.stringify({ + toolResult: { + toolUseId: 'tu-1', + status: 'success', + content: [{ text: '[REDACTED]' }], + }, + }), + }) + + const tool = hookProvider.testTool({ + callback: () => [{ text: 'SECRET_VALUE' }], + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + const toolResultMessage = agent.messages.find((m) => + m.content.some((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1') + ) + expect(toolResultMessage).toBeDefined() + const block = toolResultMessage!.content.find( + (b): b is ToolResultBlock => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1' + ) + expect(block!.content[0]).toStrictEqual(new TextBlock('[REDACTED]')) + }) + + it('handles result JSON missing toolResult wrapper gracefully', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-tool-call']) + // Missing the { toolResult: ... } wrapper — fromJSON will throw, code should catch and keep original + hookProvider.afterToolCall.mockReturnValue({ + retry: false, + result: JSON.stringify({ toolUseId: 'tu-1', status: 'success', content: [{ text: 'flat' }] }), + }) + + const tool = hookProvider.testTool({ + callback: () => [{ text: 'original output' }], + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + + // Should not throw — malformed result is logged and original result preserved + await expect(agent.invoke('use tool')).resolves.toBeDefined() + + const toolResultMessage = agent.messages.find((m) => + m.content.some((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1') + ) + expect(toolResultMessage).toBeDefined() + const block = toolResultMessage!.content.find( + (b): b is ToolResultBlock => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1' + ) + // Original result preserved since replacement JSON was malformed + expect(block).toBeDefined() + }) + }) + + describe('retry', () => { + it('after-tool-call retry re-executes the tool', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-tool-call']) + let callCount = 0 + hookProvider.afterToolCall.mockImplementation(() => { + callCount++ + if (callCount === 1) return { retry: true, result: undefined } + return { retry: false, result: undefined } + }) + + let toolCallCount = 0 + const tool = hookProvider.testTool({ + callback: () => { + toolCallCount++ + return [{ text: `call-${toolCallCount}` }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(toolCallCount).toBe(2) + }) + + it('after-model-call retry re-invokes the model', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-model-call']) + let hookCallCount = 0 + hookProvider.afterModelCall.mockImplementation(() => { + hookCallCount++ + if (hookCallCount === 1) return { retry: true } + return { retry: false } + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'first' }) + model.addTurn({ type: 'textBlock', text: 'second' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('hello') + + expect(model.callCount).toBe(2) + }) + + it('after-model-call consecutive retries all re-invoke the model', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-model-call']) + let hookCallCount = 0 + hookProvider.afterModelCall.mockImplementation(() => { + hookCallCount++ + if (hookCallCount <= 3) return { retry: true } + return { retry: false } + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'attempt-1' }) + model.addTurn({ type: 'textBlock', text: 'attempt-2' }) + model.addTurn({ type: 'textBlock', text: 'attempt-3' }) + model.addTurn({ type: 'textBlock', text: 'final' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('hello') + + expect(model.callCount).toBe(4) + expect(hookCallCount).toBe(4) + }) + }) + + describe('resume', () => { + it('after-invocation resume re-enters the agent loop with new input', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-invocation']) + let invocationCount = 0 + hookProvider.afterInvocation.mockImplementation(() => { + invocationCount++ + if (invocationCount === 1) return { resume: JSON.stringify('follow-up question') } + return { resume: undefined } + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'first response' }) + model.addTurn({ type: 'textBlock', text: 'second response' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('initial') + + expect(model.callCount).toBe(2) + const userMessages = agent.messages.filter((m) => m.role === 'user') + expect(userMessages).toHaveLength(2) + }) + + it('after-invocation resume works with content block array JSON', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-invocation']) + let invocationCount = 0 + const resumePayload = [{ role: 'user', content: [{ type: 'textBlock', text: 'structured follow-up' }] }] + hookProvider.afterInvocation.mockImplementation(() => { + invocationCount++ + if (invocationCount === 1) { + return { resume: JSON.stringify(resumePayload) } + } + return { resume: undefined } + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'first response' }) + model.addTurn({ type: 'textBlock', text: 'second response' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('initial') + + expect(model.callCount).toBe(2) + }) + + it('after-invocation resume works with array JSON', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-invocation']) + let invocationCount = 0 + hookProvider.afterInvocation.mockImplementation(() => { + invocationCount++ + if (invocationCount === 1) return { resume: JSON.stringify([{ type: 'textBlock', text: 'follow-up' }]) } + return { resume: undefined } + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'first response' }) + model.addTurn({ type: 'textBlock', text: 'second response' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('initial') + + expect(model.callCount).toBe(2) + const userMessages = agent.messages.filter((m) => m.role === 'user') + expect(userMessages).toHaveLength(2) + }) + }) + + describe('before-tools cancel', () => { + it('cancels all tool calls in the batch', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tools']) + hookProvider.beforeTools.mockReturnValue({ cancel: true, cancelMessage: 'tools disabled' }) + + let toolExecuted = false + const tool = hookProvider.testTool({ + callback: () => { + toolExecuted = true + return [{ text: 'ok' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(toolExecuted).toBe(false) + }) + + it('cancels multiple tool calls in the same turn', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tools']) + hookProvider.beforeTools.mockReturnValue({ cancel: true, cancelMessage: 'batch blocked' }) + + let toolAExecuted = false + let toolBExecuted = false + const toolA = hookProvider.testTool({ + name: 'tool_a', + description: 'tool A', + callback: () => { + toolAExecuted = true + return [{ text: 'a' }] + }, + }) + const toolB = hookProvider.testTool({ + name: 'tool_b', + description: 'tool B', + callback: () => { + toolBExecuted = true + return [{ text: 'b' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn([ + { type: 'toolUseBlock', name: 'tool_a', toolUseId: 'tu-a', input: {} }, + { type: 'toolUseBlock', name: 'tool_b', toolUseId: 'tu-b', input: {} }, + ]) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [toolA, toolB], printer: false }) + await agent.invoke('use tools') + + expect(toolAExecuted).toBe(false) + expect(toolBExecuted).toBe(false) + }) + }) + + describe('cancel message propagation', () => { + it('before-invocation cancel message appears in response', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-invocation']) + hookProvider.beforeInvocation.mockReturnValue({ cancel: true, cancelMessage: 'blocked by policy' }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'should not reach' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + const result = await agent.invoke('hello') + + const lastAssistant = agent.messages.findLast((m) => m.role === 'assistant') + expect(lastAssistant).toBeDefined() + const textBlock = lastAssistant!.content.find((b) => b.type === 'textBlock') + expect(textBlock).toBeDefined() + expect((textBlock as TextBlock).text).toBe('blocked by policy') + expect(model.callCount).toBe(0) + }) + + it('before-tool-call cancel message becomes tool result error', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tool-call']) + hookProvider.beforeToolCall.mockReturnValue({ + cancel: true, + cancelMessage: 'tool blocked', + toolUse: undefined, + selectedToolName: undefined, + }) + + let toolExecuted = false + const tool = hookProvider.testTool({ + callback: () => { + toolExecuted = true + return [{ text: 'ok' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(toolExecuted).toBe(false) + const toolResultMessage = agent.messages.find((m) => + m.content.some((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1') + ) + expect(toolResultMessage).toBeDefined() + const block = toolResultMessage!.content.find( + (b): b is ToolResultBlock => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1' + ) + expect(block!.status).toBe('error') + }) + }) +}) diff --git a/strands-wasm/__tests__/hook-provider-forwarding.test.ts b/strands-wasm/__tests__/hook-provider-forwarding.test.ts new file mode 100644 index 000000000..eb4b359de --- /dev/null +++ b/strands-wasm/__tests__/hook-provider-forwarding.test.ts @@ -0,0 +1,154 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest' +import { HookProviderBridge } from '../entry' +import { Agent, FunctionTool } from '@strands-agents/sdk' +import { MockMessageModel } from '$/fixtures/mock-message-model' +import * as hookProvider from '../__fixtures__/hook-provider' + +describe('HookProviderBridge argument forwarding', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('before-model-call', () => { + it('forwards projectedInputTokens to host function', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-model-call']) + hookProvider.beforeModelCall.mockReturnValue({ cancel: false, cancelMessage: undefined }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'hi' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('hello') + + expect(hookProvider.beforeModelCall).toHaveBeenCalled() + const firstArg = hookProvider.beforeModelCall.mock.calls[0][0] + expect(firstArg === undefined || typeof firstArg === 'number').toBe(true) + }) + }) + + describe('after-model-call', () => { + it('forwards stopReason on successful completion', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-model-call']) + hookProvider.afterModelCall.mockReturnValue({ retry: false }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'hi' }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('hello') + + expect(hookProvider.afterModelCall).toHaveBeenCalled() + const [stopReason, , error] = hookProvider.afterModelCall.mock.calls[0] + expect(stopReason).toBe('endTurn') + expect(error).toBeUndefined() + }) + + it('forwards error message on model failure', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-model-call']) + hookProvider.afterModelCall.mockReturnValue({ retry: false }) + + const model = new MockMessageModel() + model.addTurn(new Error('model exploded')) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], printer: false }) + + await expect(agent.invoke('hello')).rejects.toThrow() + + expect(hookProvider.afterModelCall).toHaveBeenCalled() + const [, , error] = hookProvider.afterModelCall.mock.calls[0] + expect(error).toBe('model exploded') + }) + }) + + describe('before-tools', () => { + it('forwards serialized message to host function', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tools']) + hookProvider.beforeTools.mockReturnValue({ cancel: false, cancelMessage: undefined }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'my_tool', toolUseId: 'tu-1', input: { x: 1 } }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const tool = new FunctionTool({ + name: 'my_tool', + description: 'test tool', + inputSchema: { type: 'object', properties: {} }, + callback: () => [{ text: 'ok' }], + }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(hookProvider.beforeTools).toHaveBeenCalled() + const firstArg = hookProvider.beforeTools.mock.calls[0][0] + expect(typeof firstArg).toBe('string') + const parsed = JSON.parse(firstArg as string) + expect(parsed).toBeDefined() + expect(typeof parsed).toBe('object') + }) + }) + + describe('after-tool-call', () => { + it('forwards serialized toolUse, result, and error', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-tool-call']) + hookProvider.afterToolCall.mockReturnValue({ retry: false, result: undefined }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'my_tool', toolUseId: 'tu-123', input: { key: 'value' } }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const tool = new FunctionTool({ + name: 'my_tool', + description: 'test tool', + inputSchema: { type: 'object', properties: {} }, + callback: () => [{ text: 'tool output' }], + }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(hookProvider.afterToolCall).toHaveBeenCalled() + const [toolUseJson, resultJson, error] = hookProvider.afterToolCall.mock.calls[0] + + const toolUse = JSON.parse(toolUseJson as string) + expect(toolUse.name).toBe('my_tool') + expect(toolUse.toolUseId).toBe('tu-123') + + const result = JSON.parse(resultJson as string) + expect(typeof result).toBe('object') + + expect(error).toBeUndefined() + }) + + it('forwards error string when tool throws', async () => { + hookProvider.getCapabilities.mockReturnValue(['after-tool-call']) + hookProvider.afterToolCall.mockReturnValue({ retry: false, result: undefined }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'error_tool', toolUseId: 'tu-err', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const tool = new FunctionTool({ + name: 'error_tool', + description: 'tool that throws', + inputSchema: { type: 'object', properties: {} }, + callback: () => { + throw new Error('tool crashed') + }, + }) + + const bridge = new HookProviderBridge() + const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(hookProvider.afterToolCall).toHaveBeenCalled() + const [, , error] = hookProvider.afterToolCall.mock.calls[0] + expect(error).toBe('tool crashed') + }) + }) +}) diff --git a/strands-wasm/__tests__/hook-provider-regression.test.ts b/strands-wasm/__tests__/hook-provider-regression.test.ts new file mode 100644 index 000000000..20f97b92b --- /dev/null +++ b/strands-wasm/__tests__/hook-provider-regression.test.ts @@ -0,0 +1,161 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest' +import { HookProviderBridge, LifecycleBridge } from '../entry' +import { Agent } from '@strands-agents/sdk' +import { MockMessageModel } from '$/fixtures/mock-message-model' +import * as hookProvider from '../__fixtures__/hook-provider' +import { testTool } from '../__fixtures__/hook-provider' + +describe('HookProviderBridge regression', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + function expectNoHooksCalled(): void { + expect(hookProvider.beforeInvocation).not.toHaveBeenCalled() + expect(hookProvider.afterInvocation).not.toHaveBeenCalled() + expect(hookProvider.beforeModelCall).not.toHaveBeenCalled() + expect(hookProvider.afterModelCall).not.toHaveBeenCalled() + expect(hookProvider.beforeTools).not.toHaveBeenCalled() + expect(hookProvider.afterTools).not.toHaveBeenCalled() + expect(hookProvider.beforeToolCall).not.toHaveBeenCalled() + expect(hookProvider.afterToolCall).not.toHaveBeenCalled() + } + + it('text-only invocation unchanged with HookProviderBridge present', async () => { + hookProvider.getCapabilities.mockReturnValue([]) + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'hello' }) + model.addTurn({ type: 'textBlock', text: 'unreachable' }) + const agent = new Agent({ model, plugins: [new HookProviderBridge()], printer: false }) + const result = await agent.invoke('hi') + + expect(hookProvider.getCapabilities).toHaveBeenCalledTimes(1) + expectNoHooksCalled() + expect(model.callCount).toBe(1) + const textBlock = result.lastMessage.content.find((b) => b.type === 'textBlock') + expect(textBlock).toBeDefined() + expect(textBlock!.text).toBe('hello') + }) + + it('tool invocation unchanged with HookProviderBridge present', async () => { + hookProvider.getCapabilities.mockReturnValue([]) + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + let toolCalled = false + const tool = testTool({ + callback: () => { + toolCalled = true + return [{ text: 'ok' }] + }, + }) + + const agent = new Agent({ model, plugins: [new HookProviderBridge()], tools: [tool], printer: false }) + await agent.invoke('use tool') + + expect(toolCalled).toBe(true) + expect(model.callCount).toBe(2) + expect(hookProvider.getCapabilities).toHaveBeenCalledTimes(1) + expectNoHooksCalled() + }) + + it('LifecycleBridge events unaffected by HookProviderBridge coexistence', async () => { + hookProvider.getCapabilities.mockReturnValue([]) + const lifecycle = new LifecycleBridge() + const hookBridge = new HookProviderBridge() + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [lifecycle, hookBridge], printer: false }) + await agent.invoke('hi') + + const events = lifecycle.drain() + const eventTypes = events.map((e) => e.val.eventType) + + expect(eventTypes).toContain('initialized') + expect(eventTypes).toContain('before-invocation') + expect(eventTypes).toContain('before-model-call') + expect(eventTypes).toContain('after-model-call') + expect(eventTypes).toContain('message-added') + expect(eventTypes).toContain('after-invocation') + }) + + it('LifecycleBridge before-tools/after-tools events with both plugins during tool turn', async () => { + hookProvider.getCapabilities.mockReturnValue([]) + const lifecycle = new LifecycleBridge() + const hookBridge = new HookProviderBridge() + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const tool = testTool() + + const agent = new Agent({ model, plugins: [lifecycle, hookBridge], tools: [tool], printer: false }) + await agent.invoke('use tool') + + const events = lifecycle.drain() + const eventTypes = events.map((e) => e.val.eventType) + + expect(eventTypes).toContain('before-tools') + expect(eventTypes).toContain('after-tools') + expect(eventTypes).toContain('before-tool-call') + expect(eventTypes).toContain('after-tool-call') + }) +}) + +describe('HookProviderBridge error paths', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('selectedToolName referencing non-existent tool uses original tool', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-tool-call']) + hookProvider.beforeToolCall.mockReturnValue({ + cancel: false, + cancelMessage: undefined, + toolUse: undefined, + selectedToolName: 'nonexistent_tool', + }) + + let originalExecuted = false + const originalTool = testTool({ + name: 'original_tool', + description: 'original', + callback: () => { + originalExecuted = true + return [{ text: 'original result' }] + }, + }) + + const model = new MockMessageModel() + model.addTurn({ type: 'toolUseBlock', name: 'original_tool', toolUseId: 'tu-1', input: {} }) + model.addTurn({ type: 'textBlock', text: 'done' }) + + const agent = new Agent({ model, plugins: [new HookProviderBridge()], tools: [originalTool], printer: false }) + await agent.invoke('use tool') + + expect(originalExecuted).toBe(true) + expect(hookProvider.beforeToolCall).toHaveBeenCalled() + }) + + it('capabilities are read once at init, not per-invocation', async () => { + hookProvider.getCapabilities.mockReturnValue(['before-invocation']) + hookProvider.beforeInvocation.mockReturnValue({ cancel: false, cancelMessage: undefined }) + + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'first response' }) + model.addTurn({ type: 'textBlock', text: 'second response' }) + + const agent = new Agent({ model, plugins: [new HookProviderBridge()], printer: false }) + await agent.invoke('first') + + hookProvider.getCapabilities.mockClear() + hookProvider.beforeInvocation.mockClear() + hookProvider.beforeInvocation.mockReturnValue({ cancel: false, cancelMessage: undefined }) + + await agent.invoke('second') + + expect(hookProvider.getCapabilities).toHaveBeenCalledTimes(0) + expect(hookProvider.beforeInvocation).toHaveBeenCalledTimes(1) + }) +}) diff --git a/strands-wasm/__tests__/lifecycle.test.ts b/strands-wasm/__tests__/lifecycle.test.ts index 9a1f49a80..b6a77d30e 100644 --- a/strands-wasm/__tests__/lifecycle.test.ts +++ b/strands-wasm/__tests__/lifecycle.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect } from 'vitest' -import { LifecycleBridge } from '../../entry' +import { LifecycleBridge } from '../entry' import { Agent, FunctionTool } from '@strands-agents/sdk' import { MockMessageModel } from '$/fixtures/mock-message-model' @@ -136,6 +136,18 @@ describe('LifecycleBridge', () => { const events = bridge.drain() + const beforeTools = events.find((e) => e.val.eventType === 'before-tools') + expect(beforeTools).toStrictEqual({ + tag: 'lifecycle', + val: { eventType: 'before-tools', toolUse: undefined, toolResult: undefined }, + }) + + const afterTools = events.find((e) => e.val.eventType === 'after-tools') + expect(afterTools).toStrictEqual({ + tag: 'lifecycle', + val: { eventType: 'after-tools', toolUse: undefined, toolResult: undefined }, + }) + const beforeToolCall = events.find((e) => e.val.eventType === 'before-tool-call') expect(beforeToolCall).toStrictEqual({ tag: 'lifecycle', diff --git a/strands-wasm/__tests__/mapping.test.ts b/strands-wasm/__tests__/mapping.test.ts index a71b19ddd..356ffabcb 100644 --- a/strands-wasm/__tests__/mapping.test.ts +++ b/strands-wasm/__tests__/mapping.test.ts @@ -10,7 +10,7 @@ import { mapToolStreamEvent, parseInput, parseSaveLatestStrategy, -} from '../../entry' +} from '../entry' import type { AgentStreamEvent, ModelStreamEvent, StopReason } from '@strands-agents/sdk' import { ToolStreamEvent, ToolUseBlock, ToolResultBlock, TextBlock, ReasoningBlock } from '@strands-agents/sdk' diff --git a/strands-wasm/__tests__/stream.test.ts b/strands-wasm/__tests__/stream.test.ts index 1a20ea9d6..c2ec02513 100644 --- a/strands-wasm/__tests__/stream.test.ts +++ b/strands-wasm/__tests__/stream.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect, vi } from 'vitest' -import { api, LifecycleBridge } from '../../entry' +import { api, LifecycleBridge } from '../entry' import { Agent } from '@strands-agents/sdk' import { MockMessageModel } from '$/fixtures/mock-message-model' diff --git a/strands-wasm/__tests__/tool-bridge.test.ts b/strands-wasm/__tests__/tool-bridge.test.ts index a0f0fb92f..77b9f9d4c 100644 --- a/strands-wasm/__tests__/tool-bridge.test.ts +++ b/strands-wasm/__tests__/tool-bridge.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' -import { createTools } from '../../entry' +import { createTools } from '../entry' import { callTool } from 'strands:agent/tool-provider' const emptyToolContext = { toolUse: { toolUseId: '' } } as any diff --git a/strands-wasm/entry.ts b/strands-wasm/entry.ts index 67e507443..aa6699ecd 100644 --- a/strands-wasm/entry.ts +++ b/strands-wasm/entry.ts @@ -12,6 +12,7 @@ /// /// /// +/// import type { AgentConfig, @@ -29,6 +30,17 @@ import type { import { callTool } from 'strands:agent/tool-provider' import { log as hostLog } from 'strands:agent/host-log' +import { + getCapabilities as hookGetCapabilities, + beforeInvocation as hookBeforeInvocation, + afterInvocation as hookAfterInvocation, + beforeModelCall as hookBeforeModelCall, + afterModelCall as hookAfterModelCall, + beforeTools as hookBeforeTools, + afterTools as hookAfterTools, + beforeToolCall as hookBeforeToolCall, + afterToolCall as hookAfterToolCall, +} from 'strands:agent/hook-provider' import { Agent, FunctionTool, SessionManager, FileStorage } from '@strands-agents/sdk' import { S3Storage } from '@strands-agents/sdk/session/s3-storage' import { AnthropicModel } from '@strands-agents/sdk/models/anthropic' @@ -69,7 +81,10 @@ import { BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent, + BeforeToolsEvent, + AfterToolsEvent, MessageAddedEvent, + ToolResultBlock, } from '@strands-agents/sdk' // All log calls go through `hostLog` (the WIT import). The host can @@ -432,6 +447,8 @@ class LifecycleBridge implements Plugin { agent.addHook(AfterInvocationEvent, () => this.push('after-invocation')) agent.addHook(BeforeModelCallEvent, () => this.push('before-model-call')) agent.addHook(AfterModelCallEvent, () => this.push('after-model-call')) + agent.addHook(BeforeToolsEvent, () => this.push('before-tools')) + agent.addHook(AfterToolsEvent, () => this.push('after-tools')) agent.addHook(MessageAddedEvent, () => this.push('message-added')) agent.addHook(BeforeToolCallEvent, (event) => { @@ -448,6 +465,117 @@ class LifecycleBridge implements Plugin { } } +/** Bridges TS SDK hooks to WIT hook-provider imports for bidirectional host control. */ +class HookProviderBridge implements Plugin { + readonly name = 'strands:hook-provider-bridge' + + /** Apply a cancel decision from a hook-provider response to a cancellable event. */ + private applyCancel( + event: { cancel?: boolean | string }, + decision: { cancel: boolean; cancelMessage?: string } + ): void { + if (decision.cancel) { + event.cancel = decision.cancelMessage ?? true + } + } + + initAgent(agent: LocalAgent): void { + const caps = new Set(hookGetCapabilities()) + + if (caps.has('before-invocation')) { + agent.addHook(BeforeInvocationEvent, (event) => { + const decision = hookBeforeInvocation() + this.applyCancel(event, decision) + }) + } + + if (caps.has('after-invocation')) { + agent.addHook(AfterInvocationEvent, (event) => { + const decision = hookAfterInvocation() + if (decision.resume) { + try { + event.resume = JSON.parse(decision.resume) + } catch { + glog('warn', `hook_name= | invalid JSON in resume field, ignoring`) + } + } + }) + } + + if (caps.has('before-model-call')) { + agent.addHook(BeforeModelCallEvent, (event) => { + const decision = hookBeforeModelCall(event.projectedInputTokens) + this.applyCancel(event, decision) + }) + } + + if (caps.has('after-model-call')) { + agent.addHook(AfterModelCallEvent, (event) => { + const decision = hookAfterModelCall(event.stopData?.stopReason, undefined, event.error?.message) + if (decision.retry) { + event.retry = true + } + }) + } + + if (caps.has('before-tools')) { + agent.addHook(BeforeToolsEvent, (event) => { + const decision = hookBeforeTools(JSON.stringify(event.message)) + this.applyCancel(event, decision) + }) + } + + if (caps.has('after-tools')) { + agent.addHook(AfterToolsEvent, () => void hookAfterTools()) + } + + if (caps.has('before-tool-call')) { + agent.addHook(BeforeToolCallEvent, (event) => { + const decision = hookBeforeToolCall(JSON.stringify(event.toolUse)) + this.applyCancel(event, decision) + if (decision.toolUse) { + try { + const rewritten = JSON.parse(decision.toolUse) + if (rewritten.name !== undefined) event.toolUse.name = rewritten.name + if (rewritten.toolUseId !== undefined) event.toolUse.toolUseId = rewritten.toolUseId + if (rewritten.input !== undefined) event.toolUse.input = rewritten.input + } catch { + glog('warn', `hook_name= | invalid JSON in toolUse field, ignoring`) + } + } + if (decision.selectedToolName) { + const tool = agent.toolRegistry.get(decision.selectedToolName) + if (tool) { + event.selectedTool = tool + } else { + glog( + 'warn', + `hook_name=, tool_name=<${decision.selectedToolName}> | hook selected unknown tool, ignoring` + ) + } + } + }) + } + + if (caps.has('after-tool-call')) { + agent.addHook(AfterToolCallEvent, (event) => { + const error = event.error?.message + const decision = hookAfterToolCall(JSON.stringify(event.toolUse), JSON.stringify(event.result), error) + if (decision.retry) { + event.retry = true + } + if (decision.result) { + try { + event.result = ToolResultBlock.fromJSON(JSON.parse(decision.result)) + } catch { + glog('warn', `hook_name= | invalid JSON in result field, ignoring`) + } + } + }) + } + } +} + /** Parse user input — JSON arrays pass through, plain strings stay as-is. */ function parseInput(input: string): InvokeArgs { try { @@ -554,7 +682,7 @@ class AgentImpl { this.sessionManager = createSessionManager(config) const conversationManager = createConversationManager(config) - const plugins: Plugin[] = [this.lifecycleBridge] + const plugins: Plugin[] = [this.lifecycleBridge, new HookProviderBridge()] this.agent = new Agent({ model, @@ -726,6 +854,7 @@ export { parseInput, createTools, LifecycleBridge, + HookProviderBridge, parseSaveLatestStrategy, createToolChoiceProxy, } diff --git a/strands-wasm/vitest.config.ts b/strands-wasm/vitest.config.ts index a22d5b46e..a3de572db 100644 --- a/strands-wasm/vitest.config.ts +++ b/strands-wasm/vitest.config.ts @@ -13,6 +13,7 @@ export default defineConfig({ alias: { 'strands:agent/tool-provider': resolve(__dirname, '__fixtures__/tool-provider.ts'), 'strands:agent/host-log': resolve(__dirname, '__fixtures__/host-log.ts'), + 'strands:agent/hook-provider': resolve(__dirname, '__fixtures__/hook-provider.ts'), '$/fixtures': resolve(__dirname, '../strands-ts/src/__fixtures__'), }, }, diff --git a/wit/agent.wit b/wit/agent.wit index d6fa57208..77a169a02 100644 --- a/wit/agent.wit +++ b/wit/agent.wit @@ -70,6 +70,8 @@ interface types { after-invocation, before-model-call, after-model-call, + before-tools, + after-tools, before-tool-call, after-tool-call, message-added, @@ -252,6 +254,75 @@ interface host-log { log: func(entry: log-entry); } +/// Bidirectional hook control from host to guest. +/// The guest calls these at lifecycle points and receives decisions +/// that can cancel, retry, redirect, or mutate the operation. +interface hook-provider { + /// Capability negotiation. The guest calls this once during init. + /// The host returns which hooks it supports. For unsupported hooks, + /// the guest skips the WIT import call and uses the default decision. + enum hook-capability { + before-invocation, + after-invocation, + before-model-call, + after-model-call, + before-tools, + after-tools, + before-tool-call, + after-tool-call, + } + + get-capabilities: func() -> list; + + /// Invariant: cancel-message is only meaningful when cancel is true. + record invocation-decision { + cancel: bool, + cancel-message: option, + } + + record after-invocation-decision { + resume: option, + } + + record model-call-decision { + cancel: bool, + cancel-message: option, + } + + record after-model-call-decision { + retry: bool, + } + + /// Dedicated decision type for before-tools (cancel all tool calls in a batch). + record tools-decision { + cancel: bool, + cancel-message: option, + } + + record tool-call-decision { + cancel: bool, + cancel-message: option, + tool-use: option, + selected-tool-name: option, + } + + record after-tool-call-decision { + retry: bool, + %result: option, + } + + record after-tools-decision {} + + before-invocation: func() -> invocation-decision; + after-invocation: func() -> after-invocation-decision; + before-model-call: func(projected-input-tokens: option) -> model-call-decision; + after-model-call: func(stop-reason: option, stop-data: option, error: option) -> after-model-call-decision; + before-tools: func(message: string) -> tools-decision; + after-tools: func() -> after-tools-decision; + before-tool-call: func(tool-use: string) -> tool-call-decision; + after-tool-call: func(tool-use: string, %result: string, error: option) -> after-tool-call-decision; +} + /// The main API exported by the WASM guest. interface api { use types.{agent-config, stream-event, stream-args, respond-args, set-messages-args}; @@ -276,6 +347,7 @@ interface api { world agent { import tool-provider; import host-log; + import hook-provider; export api; }