Skip to content

Commit 70b0989

Browse files
feat(hooks): accept callable hook callbacks in Agent constructor (#1992)
Co-authored-by: agent-of-mkmeral <agent-of-mkmeral@users.noreply.github.com> Co-authored-by: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com>
1 parent cd5da4f commit 70b0989

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

src/strands/agent/agent.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(
132132
description: str | None = None,
133133
state: AgentState | dict | None = None,
134134
plugins: list[Plugin] | None = None,
135-
hooks: list[HookProvider] | None = None,
135+
hooks: list[HookProvider | HookCallback] | None = None,
136136
session_manager: SessionManager | None = None,
137137
structured_output_prompt: str | None = None,
138138
tool_executor: ToolExecutor | None = None,
@@ -187,7 +187,8 @@ def __init__(
187187
Plugins are initialized with the agent instance after construction and can register hooks,
188188
modify agent attributes, or perform other setup tasks.
189189
Defaults to None.
190-
hooks: hooks to be added to the agent hook registry
190+
hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances
191+
or plain callable hook callbacks (functions with typed event parameters).
191192
Defaults to None.
192193
session_manager: Manager for handling agent sessions including conversation history and state.
193194
If provided, enables session-based persistence and state management.
@@ -341,7 +342,14 @@ def __init__(
341342

342343
if hooks:
343344
for hook in hooks:
344-
self.hooks.add_hook(hook)
345+
if isinstance(hook, HookProvider):
346+
self.hooks.add_hook(hook)
347+
elif callable(hook):
348+
self.hooks.add_callback(None, hook)
349+
else:
350+
raise ValueError(
351+
f"Invalid hook: {hook!r}. Must be a HookProvider instance or a callable hook callback."
352+
)
345353

346354
# Register built-in plugins
347355
self._plugin_registry.add_and_init(_ModelPlugin())

tests/strands/agent/test_agent_hooks.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,3 +1021,55 @@ def interrupt_tool(event: BeforeToolCallEvent):
10211021
assert result.stop_reason == "end_turn"
10221022
assert result.message["content"][0]["text"] == "Final response"
10231023
assert agent._interrupt_state.activated is False
1024+
1025+
1026+
def test_hooks_param_accepts_callable():
1027+
"""Verify that a plain callable can be passed via hooks parameter."""
1028+
events_received = []
1029+
1030+
def my_callback(event: AgentInitializedEvent) -> None:
1031+
events_received.append(event)
1032+
1033+
agent = Agent(hooks=[my_callback], callback_handler=None)
1034+
1035+
assert len(events_received) == 1
1036+
assert isinstance(events_received[0], AgentInitializedEvent)
1037+
assert events_received[0].agent is agent
1038+
1039+
1040+
def test_hooks_param_accepts_mixed_list():
1041+
"""Verify that a mix of HookProviders and callables can be passed."""
1042+
callback_events = []
1043+
1044+
def my_callback(event: AgentInitializedEvent) -> None:
1045+
callback_events.append(event)
1046+
1047+
provider = MockHookProvider(event_types=[AgentInitializedEvent])
1048+
1049+
agent = Agent(hooks=[provider, my_callback], callback_handler=None)
1050+
1051+
assert len(callback_events) == 1
1052+
assert callback_events[0].agent is agent
1053+
length, _ = provider.get_events()
1054+
assert length == 1
1055+
1056+
1057+
def test_hooks_param_invalid_hook_raises_error():
1058+
"""Verify that passing an invalid hook raises ValueError."""
1059+
with pytest.raises(ValueError, match="Invalid hook"):
1060+
Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore
1061+
1062+
1063+
def test_hooks_param_callable_invoked_during_lifecycle():
1064+
"""Verify callable hooks fire during agent lifecycle."""
1065+
before_events = []
1066+
1067+
def on_before(event: BeforeInvocationEvent) -> None:
1068+
before_events.append(event)
1069+
1070+
mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}])
1071+
agent = Agent(model=mock_model, hooks=[on_before], callback_handler=None)
1072+
agent("test")
1073+
1074+
assert len(before_events) == 1
1075+
assert isinstance(before_events[0], BeforeInvocationEvent)

0 commit comments

Comments
 (0)