Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 113 additions & 17 deletions strands-py/strands/_wasm_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ def _run_sync(coro: typing.Coroutine) -> typing.Any:
# Already inside a running loop — run in a fresh thread to avoid nesting
result = [None]
exc = [None]

def _target():
try:
result[0] = asyncio.run(coro)
except Exception as e:
exc[0] = e

t = threading.Thread(target=_target)
t.start()
t.join()
Expand Down Expand Up @@ -169,6 +171,7 @@ def _get_engine_and_component() -> tuple[Engine, Component]:
# Record / Variant builders (Python → WIT kebab-case)
# ---------------------------------------------------------------------------


def _rec(**kwargs: typing.Any) -> Record:
"""Build a wasmtime-py Record with the given kebab-case fields."""
r = Record.__new__(Record)
Expand All @@ -178,7 +181,9 @@ def _rec(**kwargs: typing.Any) -> Record:


def _build_tool_spec(ts: ToolSpec) -> Record:
return _rec(name=ts.name, description=ts.description, **{"input-schema": ts.input_schema})
return _rec(
name=ts.name, description=ts.description, **{"input-schema": ts.input_schema}
)


def _build_model_config_variant(cfg: ModelConfigInput) -> Variant:
Expand Down Expand Up @@ -266,7 +271,9 @@ def _build_conversation_manager_variant(
"should-truncate-results": False,
"summary-ratio": config.get("summary_ratio"),
"preserve-recent-messages": config.get("preserve_recent_messages"),
"summarization-system-prompt": config.get("summarization_system_prompt"),
"summarization-system-prompt": config.get(
"summarization_system_prompt"
),
"summarization-model-config": config.get("summarization_model_config"),
},
)
Expand Down Expand Up @@ -319,6 +326,7 @@ def _build_stream_args(
# Variant → flat StreamEvent converters (WIT → Python types)
# ---------------------------------------------------------------------------


def _opt_attr(rec: typing.Any, name: str) -> typing.Any:
"""Read an optional attribute from a wasmtime Record (kebab-case)."""
return getattr(rec, name, None) if rec is not None else None
Expand Down Expand Up @@ -413,6 +421,7 @@ def _convert_stream_event(v: Variant) -> StreamEvent:
# AWS credential injection
# ---------------------------------------------------------------------------


def _resolve_aws_credentials() -> tuple[str, str, str | None] | None:
key_id = os.environ.get("AWS_ACCESS_KEY_ID")
secret = os.environ.get("AWS_SECRET_ACCESS_KEY")
Expand Down Expand Up @@ -485,7 +494,10 @@ def _inject_aws_credentials_default() -> Variant | None:
# Import callback factories
# ---------------------------------------------------------------------------

def _make_call_tool_fn(dispatcher: ToolDispatcherBase | None) -> typing.Callable[..., typing.Any]:

def _make_call_tool_fn(
dispatcher: ToolDispatcherBase | None,
) -> typing.Callable[..., typing.Any]:
def call_tool(store_ctx: typing.Any, args: typing.Any) -> Variant:
name = getattr(args, "name")
input_json = getattr(args, "input")
Expand All @@ -497,10 +509,13 @@ def call_tool(store_ctx: typing.Any, args: typing.Any) -> Variant:
return Variant("ok", result)
except Exception as exc:
return Variant("err", str(exc))

return call_tool


def _make_call_tools_fn(dispatcher: ToolDispatcherBase | None) -> typing.Callable[..., typing.Any]:
def _make_call_tools_fn(
dispatcher: ToolDispatcherBase | None,
) -> typing.Callable[..., typing.Any]:
def call_tools(store_ctx: typing.Any, args: typing.Any) -> list[Variant]:
calls = getattr(args, "calls")
results: list[Variant] = []
Expand All @@ -517,6 +532,7 @@ def call_tools(store_ctx: typing.Any, args: typing.Any) -> list[Variant]:
except Exception as exc:
results.append(Variant("err", str(exc)))
return results

return call_tools


Expand All @@ -529,18 +545,95 @@ def log_fn(store_ctx: typing.Any, entry: typing.Any) -> None:
handler.log(level, message, context)
else:
logger = logging.getLogger("strands.wasm")
py_level = {"error": 40, "warn": 30, "info": 20, "debug": 10, "trace": 10}.get(
level, 20
)
py_level = {
"error": 40,
"warn": 30,
"info": 20,
"debug": 10,
"trace": 10,
}.get(level, 20)
msg = f"{message} | {context}" if context else message
logger.log(py_level, msg)

return log_fn


# ---------------------------------------------------------------------------
# Hook-provider no-op callbacks (return zero-value decisions)
# ---------------------------------------------------------------------------


def _hook_get_capabilities(store_ctx: typing.Any) -> list[str]:
return []


def _cancel_decision() -> Record:
"""Default no-op cancellable decision."""
return _rec(cancel=False, **{"cancel-message": None})


def _hook_before_invocation(store_ctx: typing.Any) -> Record:
return _cancel_decision()


def _hook_after_invocation(store_ctx: typing.Any) -> Record:
return _rec(resume=None)


def _hook_before_model_call(
store_ctx: typing.Any, projected_input_tokens: typing.Optional[int]
) -> Record:
return _cancel_decision()


def _hook_after_model_call(
store_ctx: typing.Any,
stop_reason: typing.Optional[str],
stop_data: typing.Optional[str],
error: typing.Optional[str],
) -> Record:
return _rec(retry=False)


def _hook_before_tools(store_ctx: typing.Any, message: str) -> Record:
return _cancel_decision()


def _hook_after_tools(store_ctx: typing.Any) -> Record:
return _rec()


def _hook_before_tool_call(store_ctx: typing.Any, tool_use: str) -> Record:
return _rec(
cancel=False,
**{"cancel-message": None, "tool-use": None, "selected-tool-name": None},
)


def _hook_after_tool_call(
store_ctx: typing.Any, tool_use: str, result: str, error: typing.Optional[str]
) -> Record:
return _rec(retry=False, result=None)


_HOOK_PROVIDER_FUNCS: list[tuple[str, typing.Callable]] = [
("get-capabilities", _hook_get_capabilities),
("before-invocation", _hook_before_invocation),
("after-invocation", _hook_after_invocation),
("before-model-call", _hook_before_model_call),
("after-model-call", _hook_after_model_call),
("before-tools", _hook_before_tools),
("after-tools", _hook_after_tools),
("before-tool-call", _hook_before_tool_call),
("after-tool-call", _hook_after_tool_call),
]


# ---------------------------------------------------------------------------
# WasmAgent — drop-in replacement for the former native Agent class
# ---------------------------------------------------------------------------


class WasmAgent:
"""WASM-hosted agent with the same API as the former native ``Agent``."""

Expand Down Expand Up @@ -568,6 +661,9 @@ def __init__(
tp.add_func("call-tools", _make_call_tools_fn(tool_dispatcher))
with root.add_instance("strands:agent/host-log") as hl:
hl.add_func("log", _make_log_fn(log_handler))
with root.add_instance("strands:agent/hook-provider") as hp:
for name, fn in _HOOK_PROVIDER_FUNCS:
hp.add_func(name, fn)

# --- store ---
store = Store(engine)
Expand All @@ -584,7 +680,13 @@ def __init__(
self._component = component

# --- instantiate + construct agent (async, run synchronously) ---
agent_config = _build_agent_config(model, system_prompt, system_prompt_blocks, tools, conversation_manager_config)
agent_config = _build_agent_config(
model,
system_prompt,
system_prompt_blocks,
tools,
conversation_manager_config,
)
_run_sync(self._init_async(linker, store, component, agent_config))

async def _init_async(
Expand Down Expand Up @@ -622,9 +724,7 @@ def _fn(name: str) -> Func:

async def start_stream(self, input_text: str) -> typing.Any:
args = _build_stream_args(input_text, None, None)
return await self._generate_fn.call_async(
self._store, self._agent_handle, args
)
return await self._generate_fn.call_async(self._store, self._agent_handle, args)

async def start_stream_with_options(
self,
Expand All @@ -633,13 +733,9 @@ async def start_stream_with_options(
tool_choice: str | None,
) -> typing.Any:
args = _build_stream_args(input_text, tools, tool_choice)
return await self._generate_fn.call_async(
self._store, self._agent_handle, args
)
return await self._generate_fn.call_async(self._store, self._agent_handle, args)

async def next_events(
self, stream_handle: typing.Any
) -> list[StreamEvent] | None:
async def next_events(self, stream_handle: typing.Any) -> list[StreamEvent] | None:
raw = await self._read_next_fn.call_async(self._store, stream_handle)
if raw is None:
return None
Expand Down
Empty file added strands-py/tests/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions strands-py/tests/test_hook_provider_stubs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Unit tests for hook-provider no-op stubs in _wasm_host.py."""

from strands._wasm_host import (
_cancel_decision,
_hook_get_capabilities,
_hook_before_invocation,
_hook_after_invocation,
_hook_before_model_call,
_hook_after_model_call,
_hook_before_tools,
_hook_after_tools,
_hook_before_tool_call,
_hook_after_tool_call,
_HOOK_PROVIDER_FUNCS,
)


class TestCancelDecision:
def test_produces_correct_record_shape(self):
result = _cancel_decision()
assert result.cancel is False
assert getattr(result, "cancel-message") is None


class TestHookBeforeToolCall:
def test_returns_all_four_decision_fields(self):
result = _hook_before_tool_call(None, "{}")
assert result.cancel is False
assert getattr(result, "cancel-message") is None
assert getattr(result, "tool-use") is None
assert getattr(result, "selected-tool-name") is None


class TestHookAfterToolCall:
def test_returns_retry_and_result_fields(self):
result = _hook_after_tool_call(None, "{}", "{}", None)
assert result.retry is False
assert result.result is None


class TestHookAfterModelCall:
def test_accepts_all_none_params(self):
result = _hook_after_model_call(None, None, None, None)
assert result.retry is False

def test_accepts_string_params(self):
result = _hook_after_model_call(
None, "endTurn", '{"reason":"end-turn"}', "something failed"
)
assert result.retry is False


class TestHookProviderFuncsList:
def test_completeness(self):
assert len(_HOOK_PROVIDER_FUNCS) == 9

def test_names(self):
names = [name for name, _ in _HOOK_PROVIDER_FUNCS]
assert names == [
"get-capabilities",
"before-invocation",
"after-invocation",
"before-model-call",
"after-model-call",
"before-tools",
"after-tools",
"before-tool-call",
"after-tool-call",
]


class TestHookGetCapabilities:
def test_returns_empty_list(self):
result = _hook_get_capabilities(None)
assert result == []


class TestHookBeforeInvocation:
def test_returns_cancel_decision(self):
result = _hook_before_invocation(None)
assert result.cancel is False
assert getattr(result, "cancel-message") is None


class TestHookAfterInvocation:
def test_returns_resume_none(self):
result = _hook_after_invocation(None)
assert result.resume is None


class TestHookBeforeModelCall:
def test_returns_cancel_decision_with_none(self):
result = _hook_before_model_call(None, None)
assert result.cancel is False
assert getattr(result, "cancel-message") is None

def test_returns_cancel_decision_with_int(self):
result = _hook_before_model_call(None, 42)
assert result.cancel is False
assert getattr(result, "cancel-message") is None


class TestHookBeforeTools:
def test_returns_cancel_decision(self):
result = _hook_before_tools(None, '{"role":"assistant","content":[]}')
assert result.cancel is False
assert getattr(result, "cancel-message") is None


class TestHookAfterTools:
def test_returns_empty_record(self):
result = _hook_after_tools(None)
assert result.__dict__ == {}
28 changes: 28 additions & 0 deletions strands-wasm/__fixtures__/hook-provider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { vi } from 'vitest'
import { FunctionTool, type FunctionToolCallback } from '@strands-agents/sdk'
export const getCapabilities = vi.fn().mockReturnValue([])
export const beforeInvocation = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined })
export const afterInvocation = vi.fn().mockReturnValue({ resume: undefined })
export const beforeModelCall = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined })
export const afterModelCall = vi.fn().mockReturnValue({ retry: false })
export const beforeTools = vi.fn().mockReturnValue({ cancel: false, cancelMessage: undefined })
export const afterTools = vi.fn().mockReturnValue({})
export const beforeToolCall = vi.fn().mockReturnValue({
cancel: false,
cancelMessage: undefined,
toolUse: undefined,
selectedToolName: undefined,
})
export const afterToolCall = vi.fn().mockReturnValue({ retry: false, result: undefined })

/** Factory for test tools with sensible defaults. */
export function testTool(
overrides: { name?: string; description?: string; callback?: FunctionToolCallback } = {}
): FunctionTool {
return new FunctionTool({
name: overrides.name ?? 'test_tool',
description: overrides.description ?? 'test',
inputSchema: { type: 'object', properties: {} },
callback: overrides.callback ?? (() => [{ text: 'ok' }]),
})
}
Loading
Loading