diff --git a/strands-py/strands/_wasm_host.py b/strands-py/strands/_wasm_host.py
index 696b685b5..ac6448044 100644
--- a/strands-py/strands/_wasm_host.py
+++ b/strands-py/strands/_wasm_host.py
@@ -101,11 +101,13 @@ def _run_sync(coro: typing.Coroutine) -> typing.Any:
# Already inside a running loop — run in a fresh thread to avoid nesting
result = [None]
exc = [None]
+
def _target():
try:
result[0] = asyncio.run(coro)
except Exception as e:
exc[0] = e
+
t = threading.Thread(target=_target)
t.start()
t.join()
@@ -169,6 +171,7 @@ def _get_engine_and_component() -> tuple[Engine, Component]:
# Record / Variant builders (Python → WIT kebab-case)
# ---------------------------------------------------------------------------
+
def _rec(**kwargs: typing.Any) -> Record:
"""Build a wasmtime-py Record with the given kebab-case fields."""
r = Record.__new__(Record)
@@ -178,7 +181,9 @@ def _rec(**kwargs: typing.Any) -> Record:
def _build_tool_spec(ts: ToolSpec) -> Record:
- return _rec(name=ts.name, description=ts.description, **{"input-schema": ts.input_schema})
+ return _rec(
+ name=ts.name, description=ts.description, **{"input-schema": ts.input_schema}
+ )
def _build_model_config_variant(cfg: ModelConfigInput) -> Variant:
@@ -266,7 +271,9 @@ def _build_conversation_manager_variant(
"should-truncate-results": False,
"summary-ratio": config.get("summary_ratio"),
"preserve-recent-messages": config.get("preserve_recent_messages"),
- "summarization-system-prompt": config.get("summarization_system_prompt"),
+ "summarization-system-prompt": config.get(
+ "summarization_system_prompt"
+ ),
"summarization-model-config": config.get("summarization_model_config"),
},
)
@@ -319,6 +326,7 @@ def _build_stream_args(
# Variant → flat StreamEvent converters (WIT → Python types)
# ---------------------------------------------------------------------------
+
def _opt_attr(rec: typing.Any, name: str) -> typing.Any:
"""Read an optional attribute from a wasmtime Record (kebab-case)."""
return getattr(rec, name, None) if rec is not None else None
@@ -413,6 +421,7 @@ def _convert_stream_event(v: Variant) -> StreamEvent:
# AWS credential injection
# ---------------------------------------------------------------------------
+
def _resolve_aws_credentials() -> tuple[str, str, str | None] | None:
key_id = os.environ.get("AWS_ACCESS_KEY_ID")
secret = os.environ.get("AWS_SECRET_ACCESS_KEY")
@@ -485,7 +494,10 @@ def _inject_aws_credentials_default() -> Variant | None:
# Import callback factories
# ---------------------------------------------------------------------------
-def _make_call_tool_fn(dispatcher: ToolDispatcherBase | None) -> typing.Callable[..., typing.Any]:
+
+def _make_call_tool_fn(
+ dispatcher: ToolDispatcherBase | None,
+) -> typing.Callable[..., typing.Any]:
def call_tool(store_ctx: typing.Any, args: typing.Any) -> Variant:
name = getattr(args, "name")
input_json = getattr(args, "input")
@@ -497,10 +509,13 @@ def call_tool(store_ctx: typing.Any, args: typing.Any) -> Variant:
return Variant("ok", result)
except Exception as exc:
return Variant("err", str(exc))
+
return call_tool
-def _make_call_tools_fn(dispatcher: ToolDispatcherBase | None) -> typing.Callable[..., typing.Any]:
+def _make_call_tools_fn(
+ dispatcher: ToolDispatcherBase | None,
+) -> typing.Callable[..., typing.Any]:
def call_tools(store_ctx: typing.Any, args: typing.Any) -> list[Variant]:
calls = getattr(args, "calls")
results: list[Variant] = []
@@ -517,6 +532,7 @@ def call_tools(store_ctx: typing.Any, args: typing.Any) -> list[Variant]:
except Exception as exc:
results.append(Variant("err", str(exc)))
return results
+
return call_tools
@@ -529,18 +545,95 @@ def log_fn(store_ctx: typing.Any, entry: typing.Any) -> None:
handler.log(level, message, context)
else:
logger = logging.getLogger("strands.wasm")
- py_level = {"error": 40, "warn": 30, "info": 20, "debug": 10, "trace": 10}.get(
- level, 20
- )
+ py_level = {
+ "error": 40,
+ "warn": 30,
+ "info": 20,
+ "debug": 10,
+ "trace": 10,
+ }.get(level, 20)
msg = f"{message} | {context}" if context else message
logger.log(py_level, msg)
+
return log_fn
+# ---------------------------------------------------------------------------
+# Hook-provider no-op callbacks (return zero-value decisions)
+# ---------------------------------------------------------------------------
+
+
+def _hook_get_capabilities(store_ctx: typing.Any) -> list[str]:
+ return []
+
+
+def _cancel_decision() -> Record:
+ """Default no-op cancellable decision."""
+ return _rec(cancel=False, **{"cancel-message": None})
+
+
+def _hook_before_invocation(store_ctx: typing.Any) -> Record:
+ return _cancel_decision()
+
+
+def _hook_after_invocation(store_ctx: typing.Any) -> Record:
+ return _rec(resume=None)
+
+
+def _hook_before_model_call(
+ store_ctx: typing.Any, projected_input_tokens: typing.Optional[int]
+) -> Record:
+ return _cancel_decision()
+
+
+def _hook_after_model_call(
+ store_ctx: typing.Any,
+ stop_reason: typing.Optional[str],
+ stop_data: typing.Optional[str],
+ error: typing.Optional[str],
+) -> Record:
+ return _rec(retry=False)
+
+
+def _hook_before_tools(store_ctx: typing.Any, message: str) -> Record:
+ return _cancel_decision()
+
+
+def _hook_after_tools(store_ctx: typing.Any) -> Record:
+ return _rec()
+
+
+def _hook_before_tool_call(store_ctx: typing.Any, tool_use: str) -> Record:
+ return _rec(
+ cancel=False,
+ **{"cancel-message": None, "tool-use": None, "selected-tool-name": None},
+ )
+
+
+def _hook_after_tool_call(
+ store_ctx: typing.Any, tool_use: str, result: str, error: typing.Optional[str]
+) -> Record:
+ return _rec(retry=False, result=None)
+
+
+_HOOK_PROVIDER_FUNCS: list[tuple[str, typing.Callable]] = [
+ ("get-capabilities", _hook_get_capabilities),
+ ("before-invocation", _hook_before_invocation),
+ ("after-invocation", _hook_after_invocation),
+ ("before-model-call", _hook_before_model_call),
+ ("after-model-call", _hook_after_model_call),
+ ("before-tools", _hook_before_tools),
+ ("after-tools", _hook_after_tools),
+ ("before-tool-call", _hook_before_tool_call),
+ ("after-tool-call", _hook_after_tool_call),
+]
+
+
# ---------------------------------------------------------------------------
# WasmAgent — drop-in replacement for the former native Agent class
# ---------------------------------------------------------------------------
+
class WasmAgent:
"""WASM-hosted agent with the same API as the former native ``Agent``."""
@@ -568,6 +661,9 @@ def __init__(
tp.add_func("call-tools", _make_call_tools_fn(tool_dispatcher))
with root.add_instance("strands:agent/host-log") as hl:
hl.add_func("log", _make_log_fn(log_handler))
+ with root.add_instance("strands:agent/hook-provider") as hp:
+ for name, fn in _HOOK_PROVIDER_FUNCS:
+ hp.add_func(name, fn)
# --- store ---
store = Store(engine)
@@ -584,7 +680,13 @@ def __init__(
self._component = component
# --- instantiate + construct agent (async, run synchronously) ---
- agent_config = _build_agent_config(model, system_prompt, system_prompt_blocks, tools, conversation_manager_config)
+ agent_config = _build_agent_config(
+ model,
+ system_prompt,
+ system_prompt_blocks,
+ tools,
+ conversation_manager_config,
+ )
_run_sync(self._init_async(linker, store, component, agent_config))
async def _init_async(
@@ -622,9 +724,7 @@ def _fn(name: str) -> Func:
async def start_stream(self, input_text: str) -> typing.Any:
args = _build_stream_args(input_text, None, None)
- return await self._generate_fn.call_async(
- self._store, self._agent_handle, args
- )
+ return await self._generate_fn.call_async(self._store, self._agent_handle, args)
async def start_stream_with_options(
self,
@@ -633,13 +733,9 @@ async def start_stream_with_options(
tool_choice: str | None,
) -> typing.Any:
args = _build_stream_args(input_text, tools, tool_choice)
- return await self._generate_fn.call_async(
- self._store, self._agent_handle, args
- )
+ return await self._generate_fn.call_async(self._store, self._agent_handle, args)
- async def next_events(
- self, stream_handle: typing.Any
- ) -> list[StreamEvent] | None:
+ async def next_events(self, stream_handle: typing.Any) -> list[StreamEvent] | None:
raw = await self._read_next_fn.call_async(self._store, stream_handle)
if raw is None:
return None
diff --git a/strands-py/tests/__init__.py b/strands-py/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/strands-py/tests/test_hook_provider_stubs.py b/strands-py/tests/test_hook_provider_stubs.py
new file mode 100644
index 000000000..65ae43852
--- /dev/null
+++ b/strands-py/tests/test_hook_provider_stubs.py
@@ -0,0 +1,113 @@
+"""Unit tests for hook-provider no-op stubs in _wasm_host.py."""
+
+from strands._wasm_host import (
+ _cancel_decision,
+ _hook_get_capabilities,
+ _hook_before_invocation,
+ _hook_after_invocation,
+ _hook_before_model_call,
+ _hook_after_model_call,
+ _hook_before_tools,
+ _hook_after_tools,
+ _hook_before_tool_call,
+ _hook_after_tool_call,
+ _HOOK_PROVIDER_FUNCS,
+)
+
+
+class TestCancelDecision:
+ def test_produces_correct_record_shape(self):
+ result = _cancel_decision()
+ assert result.cancel is False
+ assert getattr(result, "cancel-message") is None
+
+
+class TestHookBeforeToolCall:
+ def test_returns_all_four_decision_fields(self):
+ result = _hook_before_tool_call(None, "{}")
+ assert result.cancel is False
+ assert getattr(result, "cancel-message") is None
+ assert getattr(result, "tool-use") is None
+ assert getattr(result, "selected-tool-name") is None
+
+
+class TestHookAfterToolCall:
+ def test_returns_retry_and_result_fields(self):
+ result = _hook_after_tool_call(None, "{}", "{}", None)
+ assert result.retry is False
+ assert result.result is None
+
+
+class TestHookAfterModelCall:
+ def test_accepts_all_none_params(self):
+ result = _hook_after_model_call(None, None, None, None)
+ assert result.retry is False
+
+ def test_accepts_string_params(self):
+ result = _hook_after_model_call(
+ None, "endTurn", '{"reason":"end-turn"}', "something failed"
+ )
+ assert result.retry is False
+
+
+class TestHookProviderFuncsList:
+ def test_completeness(self):
+ assert len(_HOOK_PROVIDER_FUNCS) == 9
+
+ def test_names(self):
+ names = [name for name, _ in _HOOK_PROVIDER_FUNCS]
+ assert names == [
+ "get-capabilities",
+ "before-invocation",
+ "after-invocation",
+ "before-model-call",
+ "after-model-call",
+ "before-tools",
+ "after-tools",
+ "before-tool-call",
+ "after-tool-call",
+ ]
+
+
+class TestHookGetCapabilities:
+ def test_returns_empty_list(self):
+ result = _hook_get_capabilities(None)
+ assert result == []
+
+
+class TestHookBeforeInvocation:
+ def test_returns_cancel_decision(self):
+ result = _hook_before_invocation(None)
+ assert result.cancel is False
+ assert getattr(result, "cancel-message") is None
+
+
+class TestHookAfterInvocation:
+ def test_returns_resume_none(self):
+ result = _hook_after_invocation(None)
+ assert result.resume is None
+
+
+class TestHookBeforeModelCall:
+ def test_returns_cancel_decision_with_none(self):
+ result = _hook_before_model_call(None, None)
+ assert result.cancel is False
+ assert getattr(result, "cancel-message") is None
+
+ def test_returns_cancel_decision_with_int(self):
+ result = _hook_before_model_call(None, 42)
+ assert result.cancel is False
+ assert getattr(result, "cancel-message") is None
+
+
+class TestHookBeforeTools:
+ def test_returns_cancel_decision(self):
+ result = _hook_before_tools(None, '{"role":"assistant","content":[]}')
+ assert result.cancel is False
+ assert getattr(result, "cancel-message") is None
+
+
+class TestHookAfterTools:
+ def test_returns_empty_record(self):
+ result = _hook_after_tools(None)
+ assert result.__dict__ == {}
diff --git a/strands-wasm/__fixtures__/hook-provider.ts b/strands-wasm/__fixtures__/hook-provider.ts
new file mode 100644
index 000000000..a2b1dcd0c
--- /dev/null
+++ b/strands-wasm/__fixtures__/hook-provider.ts
@@ -0,0 +1,28 @@
+import { vi } from 'vitest'
+import { FunctionTool, type FunctionToolCallback } from '@strands-agents/sdk'
+export const getCapabilities = vi.fn().mockReturnValue([])
+export const beforeInvocation = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined })
+export const afterInvocation = vi.fn().mockReturnValue({ resume: undefined })
+export const beforeModelCall = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined })
+export const afterModelCall = vi.fn().mockReturnValue({ retry: false })
+export const beforeTools = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined })
+export const afterTools = vi.fn().mockReturnValue({})
+export const beforeToolCall = vi.fn().mockReturnValue({
+ cancel: false,
+ cancelMessage: undefined,
+ toolUse: undefined,
+ selectedToolName: undefined,
+})
+export const afterToolCall = vi.fn().mockReturnValue({ retry: false, result: undefined })
+
+/** Factory for test tools with sensible defaults. */
+export function testTool(
+ overrides: { name?: string; description?: string; callback?: FunctionToolCallback } = {}
+): FunctionTool {
+ return new FunctionTool({
+ name: overrides.name ?? 'test_tool',
+ description: overrides.description ?? 'test',
+ inputSchema: { type: 'object', properties: {} },
+ callback: overrides.callback ?? (() => [{ text: 'ok' }]),
+ })
+}
diff --git a/strands-wasm/__tests__/hook-provider-bridge.test.ts b/strands-wasm/__tests__/hook-provider-bridge.test.ts
new file mode 100644
index 000000000..44fa6f2d1
--- /dev/null
+++ b/strands-wasm/__tests__/hook-provider-bridge.test.ts
@@ -0,0 +1,642 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest'
+import { HookProviderBridge } from '../entry'
+import { Agent, ToolResultBlock, TextBlock } from '@strands-agents/sdk'
+import { MockMessageModel } from '$/fixtures/mock-message-model'
+import * as hookProvider from '../__fixtures__/hook-provider'
+
+describe('HookProviderBridge', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ describe('capability negotiation', () => {
+ it('registers no hooks when capabilities list is empty', async () => {
+ hookProvider.getCapabilities.mockReturnValue([])
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'hi' })
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('hello')
+
+ expect(hookProvider.beforeInvocation).not.toHaveBeenCalled()
+ expect(hookProvider.afterInvocation).not.toHaveBeenCalled()
+ expect(hookProvider.beforeModelCall).not.toHaveBeenCalled()
+ expect(hookProvider.afterModelCall).not.toHaveBeenCalled()
+ })
+
+ it('registers only declared capabilities', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-model-call', 'after-model-call'])
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'hi' })
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('hello')
+
+ expect(hookProvider.beforeModelCall).toHaveBeenCalled()
+ expect(hookProvider.afterModelCall).toHaveBeenCalled()
+ expect(hookProvider.beforeInvocation).not.toHaveBeenCalled()
+ expect(hookProvider.afterInvocation).not.toHaveBeenCalled()
+ })
+
+ it('calls after-tools hook when capability is declared and tools run', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-tools'])
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+ const tool = hookProvider.testTool()
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(hookProvider.afterTools).toHaveBeenCalled()
+ })
+ })
+
+ describe('cancel decisions', () => {
+ it('cancels invocation when before-invocation returns cancel=true', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-invocation'])
+ hookProvider.beforeInvocation.mockReturnValue({ cancel: true, cancelMessage: 'stopped by host' })
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'hi' })
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ const result = await agent.invoke('hello')
+
+ expect(hookProvider.beforeInvocation).toHaveBeenCalled()
+ expect(result.stopReason).toBe('endTurn')
+ })
+
+ it('cancels model call when before-model-call returns cancel=true', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-model-call'])
+ hookProvider.beforeModelCall.mockReturnValue({ cancel: true, cancelMessage: undefined })
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'hi' })
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ const result = await agent.invoke('hello')
+
+ expect(hookProvider.beforeModelCall).toHaveBeenCalled()
+ expect(result.stopReason).toBe('endTurn')
+ })
+ })
+
+ describe('JSON parse safety', () => {
+ it('handles invalid JSON in resume field gracefully', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-invocation'])
+ hookProvider.afterInvocation.mockReturnValue({ resume: 'not-valid-json{' })
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'hi' })
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+
+ // Should not throw — invalid JSON is logged and ignored
+ await expect(agent.invoke('hello')).resolves.toBeDefined()
+ })
+
+ it('handles invalid JSON in toolUse field gracefully', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tool-call'])
+ hookProvider.beforeToolCall.mockReturnValue({
+ cancel: false,
+ cancelMessage: undefined,
+ toolUse: 'invalid{json',
+ selectedToolName: undefined,
+ })
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+ let receivedInput: unknown
+ const tool = hookProvider.testTool({
+ callback: (input) => {
+ receivedInput = input
+ return [{ text: 'ok' }]
+ },
+ })
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+
+ await expect(agent.invoke('use tool')).resolves.toBeDefined()
+ expect(receivedInput).toStrictEqual({})
+ })
+
+ it('handles invalid JSON in result field gracefully', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-tool-call'])
+ hookProvider.afterToolCall.mockReturnValue({ retry: false, result: 'bad-json{{' })
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+ const tool = hookProvider.testTool()
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+
+ await expect(agent.invoke('use tool')).resolves.toBeDefined()
+ const toolResult = agent.messages
+ .flatMap((m) => m.content)
+ .find((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1')
+ expect(toolResult).toBeDefined()
+ })
+ })
+
+ describe('toolUse rewrite', () => {
+ it('tool receives rewritten input from before-tool-call decision', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tool-call'])
+ hookProvider.beforeToolCall.mockReturnValue({
+ cancel: false,
+ cancelMessage: undefined,
+ toolUse: JSON.stringify({ name: 'test_tool', toolUseId: 'tu-1', input: { rewritten: true } }),
+ selectedToolName: undefined,
+ })
+
+ let receivedInput: unknown
+ const tool = hookProvider.testTool({
+ callback: (input) => {
+ receivedInput = input
+ return [{ text: 'ok' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: { original: true } })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(receivedInput).toStrictEqual({ rewritten: true })
+ })
+
+ it('partial rewrite preserves fields not present in host JSON', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tool-call'])
+ hookProvider.beforeToolCall.mockReturnValue({
+ cancel: false,
+ cancelMessage: undefined,
+ toolUse: JSON.stringify({ input: { injected: 'context' } }),
+ selectedToolName: undefined,
+ })
+
+ let receivedInput: unknown
+ const tool = hookProvider.testTool({
+ callback: (input) => {
+ receivedInput = input
+ return [{ text: 'ok' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-original', input: { original: true } })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(receivedInput).toStrictEqual({ injected: 'context' })
+ const toolUseMessage = agent.messages.find((m) => m.content.some((b) => b.type === 'toolUseBlock'))
+ const toolUseBlock = toolUseMessage!.content.find((b) => b.type === 'toolUseBlock')!
+ expect(toolUseBlock.name).toBe('test_tool')
+ expect(toolUseBlock.toolUseId).toBe('tu-original')
+ })
+
+ it('empty object rewrite preserves all original fields', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tool-call'])
+ hookProvider.beforeToolCall.mockReturnValue({
+ cancel: false,
+ cancelMessage: undefined,
+ toolUse: JSON.stringify({}),
+ selectedToolName: undefined,
+ })
+
+ let receivedInput: unknown
+ const tool = hookProvider.testTool({
+ callback: (input) => {
+ receivedInput = input
+ return [{ text: 'ok' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-keep', input: { keep: 'me' } })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(receivedInput).toStrictEqual({ keep: 'me' })
+ const toolUseMessage = agent.messages.find((m) => m.content.some((b) => b.type === 'toolUseBlock'))
+ const toolUseBlock = toolUseMessage!.content.find((b) => b.type === 'toolUseBlock')!
+ expect(toolUseBlock.name).toBe('test_tool')
+ expect(toolUseBlock.toolUseId).toBe('tu-keep')
+ })
+ })
+
+ describe('selectedToolName', () => {
+ it('executes the replacement tool from registry', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tool-call'])
+ hookProvider.beforeToolCall.mockReturnValue({
+ cancel: false,
+ cancelMessage: undefined,
+ toolUse: undefined,
+ selectedToolName: 'replacement_tool',
+ })
+
+ let originalExecuted = false
+ let replacementExecuted = false
+
+ const originalTool = hookProvider.testTool({
+ name: 'original_tool',
+ description: 'original',
+ callback: () => {
+ originalExecuted = true
+ return [{ text: 'original' }]
+ },
+ })
+ const replacementTool = hookProvider.testTool({
+ name: 'replacement_tool',
+ description: 'replacement',
+ callback: () => {
+ replacementExecuted = true
+ return [{ text: 'replaced' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'original_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [originalTool, replacementTool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(originalExecuted).toBe(false)
+ expect(replacementExecuted).toBe(true)
+ })
+
+ it('selectedToolName combined with toolUse rewrite applies both', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tool-call'])
+ hookProvider.beforeToolCall.mockReturnValue({
+ cancel: false,
+ cancelMessage: undefined,
+ toolUse: JSON.stringify({ input: { redirected: true } }),
+ selectedToolName: 'replacement_tool',
+ })
+
+ let replacementInput: unknown
+ const originalTool = hookProvider.testTool({
+ name: 'original_tool',
+ description: 'original',
+ callback: () => [{ text: 'original' }],
+ })
+ const replacementTool = hookProvider.testTool({
+ name: 'replacement_tool',
+ description: 'replacement',
+ callback: (input) => {
+ replacementInput = input
+ return [{ text: 'replaced' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'original_tool', toolUseId: 'tu-1', input: { original: true } })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [originalTool, replacementTool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(replacementInput).toStrictEqual({ redirected: true })
+ })
+ })
+
+ describe('result replacement', () => {
+ it('replaces tool result in conversation history via after-tool-call', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-tool-call'])
+ hookProvider.afterToolCall.mockReturnValue({
+ retry: false,
+ result: JSON.stringify({
+ toolResult: {
+ toolUseId: 'tu-1',
+ status: 'success',
+ content: [{ text: '[REDACTED]' }],
+ },
+ }),
+ })
+
+ const tool = hookProvider.testTool({
+ callback: () => [{ text: 'SECRET_VALUE' }],
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ const toolResultMessage = agent.messages.find((m) =>
+ m.content.some((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1')
+ )
+ expect(toolResultMessage).toBeDefined()
+ const block = toolResultMessage!.content.find(
+ (b): b is ToolResultBlock => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1'
+ )
+ expect(block!.content[0]).toStrictEqual(new TextBlock('[REDACTED]'))
+ })
+
+ it('handles result JSON missing toolResult wrapper gracefully', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-tool-call'])
+ // Missing the { toolResult: ... } wrapper — fromJSON will throw, code should catch and keep original
+ hookProvider.afterToolCall.mockReturnValue({
+ retry: false,
+ result: JSON.stringify({ toolUseId: 'tu-1', status: 'success', content: [{ text: 'flat' }] }),
+ })
+
+ const tool = hookProvider.testTool({
+ callback: () => [{ text: 'original output' }],
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+
+ // Should not throw — malformed result is logged and original result preserved
+ await expect(agent.invoke('use tool')).resolves.toBeDefined()
+
+ const toolResultMessage = agent.messages.find((m) =>
+ m.content.some((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1')
+ )
+ expect(toolResultMessage).toBeDefined()
+ const block = toolResultMessage!.content.find(
+ (b): b is ToolResultBlock => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1'
+ )
+ // Original result preserved since replacement JSON was malformed
+ expect(block).toBeDefined()
+ })
+ })
+
+ describe('retry', () => {
+ it('after-tool-call retry re-executes the tool', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-tool-call'])
+ let callCount = 0
+ hookProvider.afterToolCall.mockImplementation(() => {
+ callCount++
+ if (callCount === 1) return { retry: true, result: undefined }
+ return { retry: false, result: undefined }
+ })
+
+ let toolCallCount = 0
+ const tool = hookProvider.testTool({
+ callback: () => {
+ toolCallCount++
+ return [{ text: `call-${toolCallCount}` }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(toolCallCount).toBe(2)
+ })
+
+ it('after-model-call retry re-invokes the model', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-model-call'])
+ let hookCallCount = 0
+ hookProvider.afterModelCall.mockImplementation(() => {
+ hookCallCount++
+ if (hookCallCount === 1) return { retry: true }
+ return { retry: false }
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'first' })
+ model.addTurn({ type: 'textBlock', text: 'second' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('hello')
+
+ expect(model.callCount).toBe(2)
+ })
+
+ it('after-model-call consecutive retries all re-invoke the model', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-model-call'])
+ let hookCallCount = 0
+ hookProvider.afterModelCall.mockImplementation(() => {
+ hookCallCount++
+ if (hookCallCount <= 3) return { retry: true }
+ return { retry: false }
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'attempt-1' })
+ model.addTurn({ type: 'textBlock', text: 'attempt-2' })
+ model.addTurn({ type: 'textBlock', text: 'attempt-3' })
+ model.addTurn({ type: 'textBlock', text: 'final' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('hello')
+
+ expect(model.callCount).toBe(4)
+ expect(hookCallCount).toBe(4)
+ })
+ })
+
+ describe('resume', () => {
+ it('after-invocation resume re-enters the agent loop with new input', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-invocation'])
+ let invocationCount = 0
+ hookProvider.afterInvocation.mockImplementation(() => {
+ invocationCount++
+ if (invocationCount === 1) return { resume: JSON.stringify('follow-up question') }
+ return { resume: undefined }
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'first response' })
+ model.addTurn({ type: 'textBlock', text: 'second response' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('initial')
+
+ expect(model.callCount).toBe(2)
+ const userMessages = agent.messages.filter((m) => m.role === 'user')
+ expect(userMessages).toHaveLength(2)
+ })
+
+ it('after-invocation resume works with content block array JSON', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-invocation'])
+ let invocationCount = 0
+ const resumePayload = [{ role: 'user', content: [{ type: 'textBlock', text: 'structured follow-up' }] }]
+ hookProvider.afterInvocation.mockImplementation(() => {
+ invocationCount++
+ if (invocationCount === 1) {
+ return { resume: JSON.stringify(resumePayload) }
+ }
+ return { resume: undefined }
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'first response' })
+ model.addTurn({ type: 'textBlock', text: 'second response' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('initial')
+
+ expect(model.callCount).toBe(2)
+ })
+
+ it('after-invocation resume works with array JSON', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-invocation'])
+ let invocationCount = 0
+ hookProvider.afterInvocation.mockImplementation(() => {
+ invocationCount++
+ if (invocationCount === 1) return { resume: JSON.stringify([{ type: 'textBlock', text: 'follow-up' }]) }
+ return { resume: undefined }
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'first response' })
+ model.addTurn({ type: 'textBlock', text: 'second response' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('initial')
+
+ expect(model.callCount).toBe(2)
+ const userMessages = agent.messages.filter((m) => m.role === 'user')
+ expect(userMessages).toHaveLength(2)
+ })
+ })
+
+ describe('before-tools cancel', () => {
+ it('cancels all tool calls in the batch', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tools'])
+ hookProvider.beforeTools.mockReturnValue({ cancel: true, cancelMessage: 'tools disabled' })
+
+ let toolExecuted = false
+ const tool = hookProvider.testTool({
+ callback: () => {
+ toolExecuted = true
+ return [{ text: 'ok' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(toolExecuted).toBe(false)
+ })
+
+ it('cancels multiple tool calls in the same turn', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tools'])
+ hookProvider.beforeTools.mockReturnValue({ cancel: true, cancelMessage: 'batch blocked' })
+
+ let toolAExecuted = false
+ let toolBExecuted = false
+ const toolA = hookProvider.testTool({
+ name: 'tool_a',
+ description: 'tool A',
+ callback: () => {
+ toolAExecuted = true
+ return [{ text: 'a' }]
+ },
+ })
+ const toolB = hookProvider.testTool({
+ name: 'tool_b',
+ description: 'tool B',
+ callback: () => {
+ toolBExecuted = true
+ return [{ text: 'b' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn([
+ { type: 'toolUseBlock', name: 'tool_a', toolUseId: 'tu-a', input: {} },
+ { type: 'toolUseBlock', name: 'tool_b', toolUseId: 'tu-b', input: {} },
+ ])
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [toolA, toolB], printer: false })
+ await agent.invoke('use tools')
+
+ expect(toolAExecuted).toBe(false)
+ expect(toolBExecuted).toBe(false)
+ })
+ })
+
+ describe('cancel message propagation', () => {
+ it('before-invocation cancel message appears in response', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-invocation'])
+ hookProvider.beforeInvocation.mockReturnValue({ cancel: true, cancelMessage: 'blocked by policy' })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'should not reach' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ const result = await agent.invoke('hello')
+
+ const lastAssistant = agent.messages.findLast((m) => m.role === 'assistant')
+ expect(lastAssistant).toBeDefined()
+ const textBlock = lastAssistant!.content.find((b) => b.type === 'textBlock')
+ expect(textBlock).toBeDefined()
+ expect((textBlock as TextBlock).text).toBe('blocked by policy')
+ expect(model.callCount).toBe(0)
+ })
+
+ it('before-tool-call cancel message becomes tool result error', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tool-call'])
+ hookProvider.beforeToolCall.mockReturnValue({
+ cancel: true,
+ cancelMessage: 'tool blocked',
+ toolUse: undefined,
+ selectedToolName: undefined,
+ })
+
+ let toolExecuted = false
+ const tool = hookProvider.testTool({
+ callback: () => {
+ toolExecuted = true
+ return [{ text: 'ok' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(toolExecuted).toBe(false)
+ const toolResultMessage = agent.messages.find((m) =>
+ m.content.some((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1')
+ )
+ expect(toolResultMessage).toBeDefined()
+ const block = toolResultMessage!.content.find(
+ (b): b is ToolResultBlock => b.type === 'toolResultBlock' && b.toolUseId === 'tu-1'
+ )
+ expect(block!.status).toBe('error')
+ })
+ })
+})
diff --git a/strands-wasm/__tests__/hook-provider-forwarding.test.ts b/strands-wasm/__tests__/hook-provider-forwarding.test.ts
new file mode 100644
index 000000000..eb4b359de
--- /dev/null
+++ b/strands-wasm/__tests__/hook-provider-forwarding.test.ts
@@ -0,0 +1,154 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest'
+import { HookProviderBridge } from '../entry'
+import { Agent, FunctionTool } from '@strands-agents/sdk'
+import { MockMessageModel } from '$/fixtures/mock-message-model'
+import * as hookProvider from '../__fixtures__/hook-provider'
+
+describe('HookProviderBridge argument forwarding', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ describe('before-model-call', () => {
+ it('forwards projectedInputTokens to host function', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-model-call'])
+ hookProvider.beforeModelCall.mockReturnValue({ cancel: false, cancelMessage: undefined })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'hi' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('hello')
+
+ expect(hookProvider.beforeModelCall).toHaveBeenCalled()
+ const firstArg = hookProvider.beforeModelCall.mock.calls[0][0]
+ expect(firstArg === undefined || typeof firstArg === 'number').toBe(true)
+ })
+ })
+
+ describe('after-model-call', () => {
+ it('forwards stopReason on successful completion', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-model-call'])
+ hookProvider.afterModelCall.mockReturnValue({ retry: false })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'hi' })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+ await agent.invoke('hello')
+
+ expect(hookProvider.afterModelCall).toHaveBeenCalled()
+ const [stopReason, , error] = hookProvider.afterModelCall.mock.calls[0]
+ expect(stopReason).toBe('endTurn')
+ expect(error).toBeUndefined()
+ })
+
+ it('forwards error message on model failure', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-model-call'])
+ hookProvider.afterModelCall.mockReturnValue({ retry: false })
+
+ const model = new MockMessageModel()
+ model.addTurn(new Error('model exploded'))
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], printer: false })
+
+ await expect(agent.invoke('hello')).rejects.toThrow()
+
+ expect(hookProvider.afterModelCall).toHaveBeenCalled()
+ const [, , error] = hookProvider.afterModelCall.mock.calls[0]
+ expect(error).toBe('model exploded')
+ })
+ })
+
+ describe('before-tools', () => {
+ it('forwards serialized message to host function', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tools'])
+ hookProvider.beforeTools.mockReturnValue({ cancel: false, cancelMessage: undefined })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'my_tool', toolUseId: 'tu-1', input: { x: 1 } })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const tool = new FunctionTool({
+ name: 'my_tool',
+ description: 'test tool',
+ inputSchema: { type: 'object', properties: {} },
+ callback: () => [{ text: 'ok' }],
+ })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(hookProvider.beforeTools).toHaveBeenCalled()
+ const firstArg = hookProvider.beforeTools.mock.calls[0][0]
+ expect(typeof firstArg).toBe('string')
+ const parsed = JSON.parse(firstArg as string)
+ expect(parsed).toBeDefined()
+ expect(typeof parsed).toBe('object')
+ })
+ })
+
+ describe('after-tool-call', () => {
+ it('forwards serialized toolUse, result, and error', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-tool-call'])
+ hookProvider.afterToolCall.mockReturnValue({ retry: false, result: undefined })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'my_tool', toolUseId: 'tu-123', input: { key: 'value' } })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const tool = new FunctionTool({
+ name: 'my_tool',
+ description: 'test tool',
+ inputSchema: { type: 'object', properties: {} },
+ callback: () => [{ text: 'tool output' }],
+ })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(hookProvider.afterToolCall).toHaveBeenCalled()
+ const [toolUseJson, resultJson, error] = hookProvider.afterToolCall.mock.calls[0]
+
+ const toolUse = JSON.parse(toolUseJson as string)
+ expect(toolUse.name).toBe('my_tool')
+ expect(toolUse.toolUseId).toBe('tu-123')
+
+ const result = JSON.parse(resultJson as string)
+ expect(typeof result).toBe('object')
+
+ expect(error).toBeUndefined()
+ })
+
+ it('forwards error string when tool throws', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['after-tool-call'])
+ hookProvider.afterToolCall.mockReturnValue({ retry: false, result: undefined })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'error_tool', toolUseId: 'tu-err', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const tool = new FunctionTool({
+ name: 'error_tool',
+ description: 'tool that throws',
+ inputSchema: { type: 'object', properties: {} },
+ callback: () => {
+ throw new Error('tool crashed')
+ },
+ })
+
+ const bridge = new HookProviderBridge()
+ const agent = new Agent({ model, plugins: [bridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(hookProvider.afterToolCall).toHaveBeenCalled()
+ const [, , error] = hookProvider.afterToolCall.mock.calls[0]
+ expect(error).toBe('tool crashed')
+ })
+ })
+})
diff --git a/strands-wasm/__tests__/hook-provider-regression.test.ts b/strands-wasm/__tests__/hook-provider-regression.test.ts
new file mode 100644
index 000000000..20f97b92b
--- /dev/null
+++ b/strands-wasm/__tests__/hook-provider-regression.test.ts
@@ -0,0 +1,161 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest'
+import { HookProviderBridge, LifecycleBridge } from '../entry'
+import { Agent } from '@strands-agents/sdk'
+import { MockMessageModel } from '$/fixtures/mock-message-model'
+import * as hookProvider from '../__fixtures__/hook-provider'
+import { testTool } from '../__fixtures__/hook-provider'
+
+describe('HookProviderBridge regression', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ function expectNoHooksCalled(): void {
+ expect(hookProvider.beforeInvocation).not.toHaveBeenCalled()
+ expect(hookProvider.afterInvocation).not.toHaveBeenCalled()
+ expect(hookProvider.beforeModelCall).not.toHaveBeenCalled()
+ expect(hookProvider.afterModelCall).not.toHaveBeenCalled()
+ expect(hookProvider.beforeTools).not.toHaveBeenCalled()
+ expect(hookProvider.afterTools).not.toHaveBeenCalled()
+ expect(hookProvider.beforeToolCall).not.toHaveBeenCalled()
+ expect(hookProvider.afterToolCall).not.toHaveBeenCalled()
+ }
+
+ it('text-only invocation unchanged with HookProviderBridge present', async () => {
+ hookProvider.getCapabilities.mockReturnValue([])
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'hello' })
+ model.addTurn({ type: 'textBlock', text: 'unreachable' })
+ const agent = new Agent({ model, plugins: [new HookProviderBridge()], printer: false })
+ const result = await agent.invoke('hi')
+
+ expect(hookProvider.getCapabilities).toHaveBeenCalledTimes(1)
+ expectNoHooksCalled()
+ expect(model.callCount).toBe(1)
+ const textBlock = result.lastMessage.content.find((b) => b.type === 'textBlock')
+ expect(textBlock).toBeDefined()
+ expect(textBlock!.text).toBe('hello')
+ })
+
+ it('tool invocation unchanged with HookProviderBridge present', async () => {
+ hookProvider.getCapabilities.mockReturnValue([])
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ let toolCalled = false
+ const tool = testTool({
+ callback: () => {
+ toolCalled = true
+ return [{ text: 'ok' }]
+ },
+ })
+
+ const agent = new Agent({ model, plugins: [new HookProviderBridge()], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(toolCalled).toBe(true)
+ expect(model.callCount).toBe(2)
+ expect(hookProvider.getCapabilities).toHaveBeenCalledTimes(1)
+ expectNoHooksCalled()
+ })
+
+ it('LifecycleBridge events unaffected by HookProviderBridge coexistence', async () => {
+ hookProvider.getCapabilities.mockReturnValue([])
+ const lifecycle = new LifecycleBridge()
+ const hookBridge = new HookProviderBridge()
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'Hello' })
+ const agent = new Agent({ model, plugins: [lifecycle, hookBridge], printer: false })
+ await agent.invoke('hi')
+
+ const events = lifecycle.drain()
+ const eventTypes = events.map((e) => e.val.eventType)
+
+ expect(eventTypes).toContain('initialized')
+ expect(eventTypes).toContain('before-invocation')
+ expect(eventTypes).toContain('before-model-call')
+ expect(eventTypes).toContain('after-model-call')
+ expect(eventTypes).toContain('message-added')
+ expect(eventTypes).toContain('after-invocation')
+ })
+
+ it('LifecycleBridge before-tools/after-tools events with both plugins during tool turn', async () => {
+ hookProvider.getCapabilities.mockReturnValue([])
+ const lifecycle = new LifecycleBridge()
+ const hookBridge = new HookProviderBridge()
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'test_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const tool = testTool()
+
+ const agent = new Agent({ model, plugins: [lifecycle, hookBridge], tools: [tool], printer: false })
+ await agent.invoke('use tool')
+
+ const events = lifecycle.drain()
+ const eventTypes = events.map((e) => e.val.eventType)
+
+ expect(eventTypes).toContain('before-tools')
+ expect(eventTypes).toContain('after-tools')
+ expect(eventTypes).toContain('before-tool-call')
+ expect(eventTypes).toContain('after-tool-call')
+ })
+})
+
+describe('HookProviderBridge error paths', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ it('selectedToolName referencing non-existent tool uses original tool', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-tool-call'])
+ hookProvider.beforeToolCall.mockReturnValue({
+ cancel: false,
+ cancelMessage: undefined,
+ toolUse: undefined,
+ selectedToolName: 'nonexistent_tool',
+ })
+
+ let originalExecuted = false
+ const originalTool = testTool({
+ name: 'original_tool',
+ description: 'original',
+ callback: () => {
+ originalExecuted = true
+ return [{ text: 'original result' }]
+ },
+ })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'toolUseBlock', name: 'original_tool', toolUseId: 'tu-1', input: {} })
+ model.addTurn({ type: 'textBlock', text: 'done' })
+
+ const agent = new Agent({ model, plugins: [new HookProviderBridge()], tools: [originalTool], printer: false })
+ await agent.invoke('use tool')
+
+ expect(originalExecuted).toBe(true)
+ expect(hookProvider.beforeToolCall).toHaveBeenCalled()
+ })
+
+ it('capabilities are read once at init, not per-invocation', async () => {
+ hookProvider.getCapabilities.mockReturnValue(['before-invocation'])
+ hookProvider.beforeInvocation.mockReturnValue({ cancel: false, cancelMessage: undefined })
+
+ const model = new MockMessageModel()
+ model.addTurn({ type: 'textBlock', text: 'first response' })
+ model.addTurn({ type: 'textBlock', text: 'second response' })
+
+ const agent = new Agent({ model, plugins: [new HookProviderBridge()], printer: false })
+ await agent.invoke('first')
+
+ hookProvider.getCapabilities.mockClear()
+ hookProvider.beforeInvocation.mockClear()
+ hookProvider.beforeInvocation.mockReturnValue({ cancel: false, cancelMessage: undefined })
+
+ await agent.invoke('second')
+
+ expect(hookProvider.getCapabilities).toHaveBeenCalledTimes(0)
+ expect(hookProvider.beforeInvocation).toHaveBeenCalledTimes(1)
+ })
+})
diff --git a/strands-wasm/__tests__/lifecycle.test.ts b/strands-wasm/__tests__/lifecycle.test.ts
index 9a1f49a80..b6a77d30e 100644
--- a/strands-wasm/__tests__/lifecycle.test.ts
+++ b/strands-wasm/__tests__/lifecycle.test.ts
@@ -1,5 +1,5 @@
import { describe, it, expect } from 'vitest'
-import { LifecycleBridge } from '../../entry'
+import { LifecycleBridge } from '../entry'
import { Agent, FunctionTool } from '@strands-agents/sdk'
import { MockMessageModel } from '$/fixtures/mock-message-model'
@@ -136,6 +136,18 @@ describe('LifecycleBridge', () => {
const events = bridge.drain()
+ const beforeTools = events.find((e) => e.val.eventType === 'before-tools')
+ expect(beforeTools).toStrictEqual({
+ tag: 'lifecycle',
+ val: { eventType: 'before-tools', toolUse: undefined, toolResult: undefined },
+ })
+
+ const afterTools = events.find((e) => e.val.eventType === 'after-tools')
+ expect(afterTools).toStrictEqual({
+ tag: 'lifecycle',
+ val: { eventType: 'after-tools', toolUse: undefined, toolResult: undefined },
+ })
+
const beforeToolCall = events.find((e) => e.val.eventType === 'before-tool-call')
expect(beforeToolCall).toStrictEqual({
tag: 'lifecycle',
diff --git a/strands-wasm/__tests__/mapping.test.ts b/strands-wasm/__tests__/mapping.test.ts
index a71b19ddd..356ffabcb 100644
--- a/strands-wasm/__tests__/mapping.test.ts
+++ b/strands-wasm/__tests__/mapping.test.ts
@@ -10,7 +10,7 @@ import {
mapToolStreamEvent,
parseInput,
parseSaveLatestStrategy,
-} from '../../entry'
+} from '../entry'
import type { AgentStreamEvent, ModelStreamEvent, StopReason } from '@strands-agents/sdk'
import { ToolStreamEvent, ToolUseBlock, ToolResultBlock, TextBlock, ReasoningBlock } from '@strands-agents/sdk'
diff --git a/strands-wasm/__tests__/stream.test.ts b/strands-wasm/__tests__/stream.test.ts
index 1a20ea9d6..c2ec02513 100644
--- a/strands-wasm/__tests__/stream.test.ts
+++ b/strands-wasm/__tests__/stream.test.ts
@@ -1,5 +1,5 @@
import { describe, it, expect, vi } from 'vitest'
-import { api, LifecycleBridge } from '../../entry'
+import { api, LifecycleBridge } from '../entry'
import { Agent } from '@strands-agents/sdk'
import { MockMessageModel } from '$/fixtures/mock-message-model'
diff --git a/strands-wasm/__tests__/tool-bridge.test.ts b/strands-wasm/__tests__/tool-bridge.test.ts
index a0f0fb92f..77b9f9d4c 100644
--- a/strands-wasm/__tests__/tool-bridge.test.ts
+++ b/strands-wasm/__tests__/tool-bridge.test.ts
@@ -1,5 +1,5 @@
import { describe, it, expect, vi, beforeEach } from 'vitest'
-import { createTools } from '../../entry'
+import { createTools } from '../entry'
import { callTool } from 'strands:agent/tool-provider'
const emptyToolContext = { toolUse: { toolUseId: '' } } as any
diff --git a/strands-wasm/entry.ts b/strands-wasm/entry.ts
index 67e507443..aa6699ecd 100644
--- a/strands-wasm/entry.ts
+++ b/strands-wasm/entry.ts
@@ -12,6 +12,7 @@
///
///
///
+///
import type {
AgentConfig,
@@ -29,6 +30,17 @@ import type {
import { callTool } from 'strands:agent/tool-provider'
import { log as hostLog } from 'strands:agent/host-log'
+import {
+ getCapabilities as hookGetCapabilities,
+ beforeInvocation as hookBeforeInvocation,
+ afterInvocation as hookAfterInvocation,
+ beforeModelCall as hookBeforeModelCall,
+ afterModelCall as hookAfterModelCall,
+ beforeTools as hookBeforeTools,
+ afterTools as hookAfterTools,
+ beforeToolCall as hookBeforeToolCall,
+ afterToolCall as hookAfterToolCall,
+} from 'strands:agent/hook-provider'
import { Agent, FunctionTool, SessionManager, FileStorage } from '@strands-agents/sdk'
import { S3Storage } from '@strands-agents/sdk/session/s3-storage'
import { AnthropicModel } from '@strands-agents/sdk/models/anthropic'
@@ -69,7 +81,10 @@ import {
BeforeInvocationEvent,
BeforeModelCallEvent,
BeforeToolCallEvent,
+ BeforeToolsEvent,
+ AfterToolsEvent,
MessageAddedEvent,
+ ToolResultBlock,
} from '@strands-agents/sdk'
// All log calls go through `hostLog` (the WIT import). The host can
@@ -432,6 +447,8 @@ class LifecycleBridge implements Plugin {
agent.addHook(AfterInvocationEvent, () => this.push('after-invocation'))
agent.addHook(BeforeModelCallEvent, () => this.push('before-model-call'))
agent.addHook(AfterModelCallEvent, () => this.push('after-model-call'))
+ agent.addHook(BeforeToolsEvent, () => this.push('before-tools'))
+ agent.addHook(AfterToolsEvent, () => this.push('after-tools'))
agent.addHook(MessageAddedEvent, () => this.push('message-added'))
agent.addHook(BeforeToolCallEvent, (event) => {
@@ -448,6 +465,117 @@ class LifecycleBridge implements Plugin {
}
}
+/** Bridges TS SDK hooks to WIT hook-provider imports for bidirectional host control. */
+class HookProviderBridge implements Plugin {
+ readonly name = 'strands:hook-provider-bridge'
+
+ /** Apply a cancel decision from a hook-provider response to a cancellable event. */
+ private applyCancel(
+ event: { cancel?: boolean | string },
+ decision: { cancel: boolean; cancelMessage?: string }
+ ): void {
+ if (decision.cancel) {
+ event.cancel = decision.cancelMessage ?? true
+ }
+ }
+
+ initAgent(agent: LocalAgent): void {
+ const caps = new Set(hookGetCapabilities())
+
+ if (caps.has('before-invocation')) {
+ agent.addHook(BeforeInvocationEvent, (event) => {
+ const decision = hookBeforeInvocation()
+ this.applyCancel(event, decision)
+ })
+ }
+
+ if (caps.has('after-invocation')) {
+ agent.addHook(AfterInvocationEvent, (event) => {
+ const decision = hookAfterInvocation()
+ if (decision.resume) {
+ try {
+ event.resume = JSON.parse(decision.resume)
+ } catch {
+ glog('warn', `hook_name= | invalid JSON in resume field, ignoring`)
+ }
+ }
+ })
+ }
+
+ if (caps.has('before-model-call')) {
+ agent.addHook(BeforeModelCallEvent, (event) => {
+ const decision = hookBeforeModelCall(event.projectedInputTokens)
+ this.applyCancel(event, decision)
+ })
+ }
+
+ if (caps.has('after-model-call')) {
+ agent.addHook(AfterModelCallEvent, (event) => {
+ const decision = hookAfterModelCall(event.stopData?.stopReason, undefined, event.error?.message)
+ if (decision.retry) {
+ event.retry = true
+ }
+ })
+ }
+
+ if (caps.has('before-tools')) {
+ agent.addHook(BeforeToolsEvent, (event) => {
+ const decision = hookBeforeTools(JSON.stringify(event.message))
+ this.applyCancel(event, decision)
+ })
+ }
+
+ if (caps.has('after-tools')) {
+ agent.addHook(AfterToolsEvent, () => void hookAfterTools())
+ }
+
+ if (caps.has('before-tool-call')) {
+ agent.addHook(BeforeToolCallEvent, (event) => {
+ const decision = hookBeforeToolCall(JSON.stringify(event.toolUse))
+ this.applyCancel(event, decision)
+ if (decision.toolUse) {
+ try {
+ const rewritten = JSON.parse(decision.toolUse)
+ if (rewritten.name !== undefined) event.toolUse.name = rewritten.name
+ if (rewritten.toolUseId !== undefined) event.toolUse.toolUseId = rewritten.toolUseId
+ if (rewritten.input !== undefined) event.toolUse.input = rewritten.input
+ } catch {
+ glog('warn', `hook_name= | invalid JSON in toolUse field, ignoring`)
+ }
+ }
+ if (decision.selectedToolName) {
+ const tool = agent.toolRegistry.get(decision.selectedToolName)
+ if (tool) {
+ event.selectedTool = tool
+ } else {
+ glog(
+ 'warn',
+ `hook_name=, tool_name=<${decision.selectedToolName}> | hook selected unknown tool, ignoring`
+ )
+ }
+ }
+ })
+ }
+
+ if (caps.has('after-tool-call')) {
+ agent.addHook(AfterToolCallEvent, (event) => {
+ const error = event.error?.message
+ const decision = hookAfterToolCall(JSON.stringify(event.toolUse), JSON.stringify(event.result), error)
+ if (decision.retry) {
+ event.retry = true
+ }
+ if (decision.result) {
+ try {
+ event.result = ToolResultBlock.fromJSON(JSON.parse(decision.result))
+ } catch {
+ glog('warn', `hook_name= | invalid JSON in result field, ignoring`)
+ }
+ }
+ })
+ }
+ }
+}
+
/** Parse user input — JSON arrays pass through, plain strings stay as-is. */
function parseInput(input: string): InvokeArgs {
try {
@@ -554,7 +682,7 @@ class AgentImpl {
this.sessionManager = createSessionManager(config)
const conversationManager = createConversationManager(config)
- const plugins: Plugin[] = [this.lifecycleBridge]
+ const plugins: Plugin[] = [this.lifecycleBridge, new HookProviderBridge()]
this.agent = new Agent({
model,
@@ -726,6 +854,7 @@ export {
parseInput,
createTools,
LifecycleBridge,
+ HookProviderBridge,
parseSaveLatestStrategy,
createToolChoiceProxy,
}
diff --git a/strands-wasm/vitest.config.ts b/strands-wasm/vitest.config.ts
index a22d5b46e..a3de572db 100644
--- a/strands-wasm/vitest.config.ts
+++ b/strands-wasm/vitest.config.ts
@@ -13,6 +13,7 @@ export default defineConfig({
alias: {
'strands:agent/tool-provider': resolve(__dirname, '__fixtures__/tool-provider.ts'),
'strands:agent/host-log': resolve(__dirname, '__fixtures__/host-log.ts'),
+ 'strands:agent/hook-provider': resolve(__dirname, '__fixtures__/hook-provider.ts'),
'$/fixtures': resolve(__dirname, '../strands-ts/src/__fixtures__'),
},
},
diff --git a/wit/agent.wit b/wit/agent.wit
index d6fa57208..77a169a02 100644
--- a/wit/agent.wit
+++ b/wit/agent.wit
@@ -70,6 +70,8 @@ interface types {
after-invocation,
before-model-call,
after-model-call,
+ before-tools,
+ after-tools,
before-tool-call,
after-tool-call,
message-added,
@@ -252,6 +254,75 @@ interface host-log {
log: func(entry: log-entry);
}
+/// Bidirectional hook control from host to guest.
+/// The guest calls these at lifecycle points and receives decisions
+/// that can cancel, retry, redirect, or mutate the operation.
+interface hook-provider {
+ /// Capability negotiation. The guest calls this once during init.
+ /// The host returns which hooks it supports. For unsupported hooks,
+ /// the guest skips the WIT import call and uses the default decision.
+ enum hook-capability {
+ before-invocation,
+ after-invocation,
+ before-model-call,
+ after-model-call,
+ before-tools,
+ after-tools,
+ before-tool-call,
+ after-tool-call,
+ }
+
+ get-capabilities: func() -> list;
+
+ /// Invariant: cancel-message is only meaningful when cancel is true.
+ record invocation-decision {
+ cancel: bool,
+ cancel-message: option,
+ }
+
+ record after-invocation-decision {
+ resume: option,
+ }
+
+ record model-call-decision {
+ cancel: bool,
+ cancel-message: option,
+ }
+
+ record after-model-call-decision {
+ retry: bool,
+ }
+
+ /// Dedicated decision type for before-tools (cancel all tool calls in a batch).
+ record tools-decision {
+ cancel: bool,
+ cancel-message: option,
+ }
+
+ record tool-call-decision {
+ cancel: bool,
+ cancel-message: option,
+ tool-use: option,
+ selected-tool-name: option,
+ }
+
+ record after-tool-call-decision {
+ retry: bool,
+ %result: option,
+ }
+
+ record after-tools-decision {}
+
+ before-invocation: func() -> invocation-decision;
+ after-invocation: func() -> after-invocation-decision;
+ before-model-call: func(projected-input-tokens: option) -> model-call-decision;
+ after-model-call: func(stop-reason: option, stop-data: option, error: option) -> after-model-call-decision;
+ before-tools: func(message: string) -> tools-decision;
+ after-tools: func() -> after-tools-decision;
+ before-tool-call: func(tool-use: string) -> tool-call-decision;
+ after-tool-call: func(tool-use: string, %result: string, error: option) -> after-tool-call-decision;
+}
+
/// The main API exported by the WASM guest.
interface api {
use types.{agent-config, stream-event, stream-args, respond-args, set-messages-args};
@@ -276,6 +347,7 @@ interface api {
world agent {
import tool-provider;
import host-log;
+ import hook-provider;
export api;
}