Skip to content

Commit 6dd249d

Browse files
authored
feat(strands-py): add optional hook order (#2559)
1 parent ea48c58 commit 6dd249d

13 files changed

Lines changed: 173 additions & 44 deletions

File tree

site/src/content/docs/user-guide/concepts/agents/hooks.mdx

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,10 +392,33 @@ Most event properties are read-only to prevent unintended modifications. However
392392
<Tabs>
393393
<Tab label="Python">
394394

395-
After event callbacks run in reverse registration order for cleanup symmetry:
395+
By default, After event callbacks run in reverse registration order for cleanup symmetry. You can override this with explicit priority using the `order` option — lower values run first.
396+
397+
The SDK exports convenience presets that mark where the SDK's own hooks run, so you can position yours relative to them:
398+
399+
- `HookOrder.SDK_FIRST` (-100) — where the SDK's earliest hooks run
400+
- `HookOrder.DEFAULT` (0) — implicit when no order is specified
401+
- `HookOrder.SDK_LAST` (100) — where the SDK's latest hooks run
402+
403+
These are not enforced bounds — any numeric value works. Use values beyond them (e.g. `SDK_FIRST - 1`) to run before or after the SDK's hooks, or `float('-inf')`/`float('inf')` for guaranteed absolute ordering.
404+
405+
```python
406+
from strands import Agent
407+
from strands.hooks import BeforeModelCallEvent, HookOrder
408+
409+
agent = Agent()
410+
411+
def early_hook(event: BeforeModelCallEvent) -> None:
412+
print("I run first")
396413

397-
- **Before**: A, B, C (registration order)
398-
- **After**: C, B, A (reverse registration order)
414+
def late_hook(event: BeforeModelCallEvent) -> None:
415+
print("I run last")
416+
417+
agent.add_hook(early_hook, order=HookOrder.SDK_FIRST)
418+
agent.add_hook(late_hook, order=HookOrder.SDK_LAST)
419+
```
420+
421+
Within the same order group, Before events preserve registration order and After events reverse it.
399422

400423
</Tab>
401424
<Tab label="TypeScript">

strands-py/src/strands/agent/agent.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
AgentInitializedEvent,
4747
BeforeInvocationEvent,
4848
HookCallback,
49+
HookOrder,
4950
HookProvider,
5051
HookRegistry,
5152
MessageAddedEvent,
@@ -725,7 +726,11 @@ def cleanup(self) -> None:
725726
self.tool_registry.cleanup()
726727

727728
def add_hook(
728-
self, callback: HookCallback[TEvent], event_type: type[TEvent] | list[type[TEvent]] | None = None
729+
self,
730+
callback: HookCallback[TEvent],
731+
event_type: type[TEvent] | list[type[TEvent]] | None = None,
732+
*,
733+
order: float = HookOrder.DEFAULT,
729734
) -> None:
730735
"""Register a callback function for a specific event type.
731736
@@ -745,6 +750,8 @@ def add_hook(
745750
Can be a single type, a list of types, or None to infer from
746751
the callback's first parameter type hint. If a list is provided,
747752
the callback is registered for each type in the list.
753+
order: Execution priority. Lower values execute first.
754+
Use HookOrder.SDK_FIRST (-100), HookOrder.DEFAULT (0), or HookOrder.SDK_LAST (100).
748755
749756
Raises:
750757
ValueError: If event_type is not provided and cannot be inferred from
@@ -776,7 +783,7 @@ def multi_handler(event) -> None:
776783
Docs:
777784
https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/
778785
"""
779-
self.hooks.add_callback(event_type, callback)
786+
self.hooks.add_callback(event_type, callback, order=order)
780787

781788
def __del__(self) -> None:
782789
"""Clean up resources when agent is garbage collected."""

strands-py/src/strands/hooks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
4545
MessageAddedEvent,
4646
MultiAgentInitializedEvent,
4747
)
48-
from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry
48+
from .registry import BaseHookEvent, HookCallback, HookEvent, HookOrder, HookProvider, HookRegistry
4949

5050
__all__ = [
5151
"AgentInitializedEvent",
@@ -57,6 +57,7 @@ def log_end(self, event: AfterInvocationEvent) -> None:
5757
"AfterInvocationEvent",
5858
"MessageAddedEvent",
5959
"HookEvent",
60+
"HookOrder",
6061
"HookProvider",
6162
"HookCallback",
6263
"HookRegistry",

strands-py/src/strands/hooks/registry.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
via hook provider objects.
88
"""
99

10+
import bisect
1011
import inspect
1112
import logging
1213
from collections.abc import Awaitable, Generator
1314
from dataclasses import dataclass
15+
from itertools import groupby
1416
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable
1517

1618
from ..interrupt import Interrupt, InterruptException
@@ -22,6 +24,25 @@
2224
logger = logging.getLogger(__name__)
2325

2426

27+
class HookOrder:
28+
"""Named constants for hook execution priority.
29+
30+
Lower values execute first. Hooks with the same order preserve registration order.
31+
"""
32+
33+
SDK_FIRST: int = -100
34+
DEFAULT: int = 0
35+
SDK_LAST: int = 100
36+
37+
38+
@dataclass
39+
class _CallbackEntry:
40+
"""Internal entry pairing a callback with its execution order."""
41+
42+
callback: "HookCallback"
43+
order: float
44+
45+
2546
@dataclass
2647
class BaseHookEvent:
2748
"""Base class for all hook events."""
@@ -156,12 +177,14 @@ class HookRegistry:
156177

157178
def __init__(self) -> None:
158179
"""Initialize an empty hook registry."""
159-
self._registered_callbacks: dict[type, list[HookCallback]] = {}
180+
self._registered_callbacks: dict[type, list[_CallbackEntry]] = {}
160181

161182
def add_callback(
162183
self,
163184
event_type: type[TEvent] | list[type[TEvent]] | None,
164185
callback: HookCallback[TEvent],
186+
*,
187+
order: float = HookOrder.DEFAULT,
165188
) -> None:
166189
"""Register a callback function for a specific event type.
167190
@@ -176,6 +199,7 @@ def add_callback(
176199
event_type: The lifecycle event type(s) this callback should handle.
177200
Can be a single type, a list of types, or None to infer from type hints.
178201
callback: The callback function to invoke when events of this type occur.
202+
order: Execution priority. Lower values execute first.
179203
180204
Raises:
181205
ValueError: If event_type is not provided and cannot be inferred from
@@ -227,8 +251,9 @@ def multi_handler(event):
227251
if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
228252
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")
229253

230-
callbacks = self._registered_callbacks.setdefault(resolved_event_type, [])
231-
callbacks.append(callback)
254+
entries = self._registered_callbacks.setdefault(resolved_event_type, [])
255+
entry = _CallbackEntry(callback=callback, order=order)
256+
bisect.insort(entries, entry, key=lambda e: e.order)
232257

233258
def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[type[TEvent]]:
234259
"""Validate that all types in a list are valid BaseHookEvent subclasses.
@@ -381,9 +406,12 @@ def has_callbacks(self) -> bool:
381406
def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
382407
"""Get callbacks registered for the given event in the appropriate order.
383408
384-
This method returns callbacks in registration order for normal events,
385-
or reverse registration order for events that have should_reverse_callbacks=True.
386-
This enables proper cleanup ordering for teardown events.
409+
For normal events, callbacks are returned in order priority (lower first),
410+
with registration order preserved within the same priority.
411+
412+
For reversed events (should_reverse_callbacks=True), order priority still
413+
applies (lower first), but within the same priority group, registration
414+
order is reversed.
387415
388416
Args:
389417
event: The event to get callbacks for.
@@ -400,8 +428,11 @@ def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], No
400428
"""
401429
event_type = type(event)
402430

403-
callbacks = self._registered_callbacks.get(event_type, [])
431+
entries = self._registered_callbacks.get(event_type, [])
404432
if event.should_reverse_callbacks:
405-
yield from reversed(callbacks)
433+
for _order, group in groupby(entries, key=lambda e: e.order):
434+
for entry in reversed(list(group)):
435+
yield entry.callback
406436
else:
407-
yield from callbacks
437+
for entry in entries:
438+
yield entry.callback

strands-py/src/strands/multiagent/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from .._async import run_async
1515
from ..agent import AgentResult
16-
from ..hooks.registry import HookCallback
16+
from ..hooks.registry import HookCallback, HookOrder
1717
from ..interrupt import Interrupt
1818
from ..types.event_loop import Metrics, Usage
1919
from ..types.multiagent import MultiAgentInput
@@ -255,7 +255,9 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
255255
"""Restore orchestrator state from a session dict."""
256256
raise NotImplementedError
257257

258-
def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None:
258+
def add_hook(
259+
self, callback: HookCallback, event_type: type | list[type] | None = None, *, order: float = HookOrder.DEFAULT
260+
) -> None:
259261
"""Register a hook callback with the orchestrator.
260262
261263
Subclasses that support hooks should override this method to register
@@ -266,6 +268,7 @@ def add_hook(self, callback: HookCallback, event_type: type | list[type] | None
266268
event_type: The class type(s) of events this callback should handle.
267269
Can be a single type, a list of types, or None to infer from
268270
the callback's first parameter type hint.
271+
order: Execution priority. Lower values execute first.
269272
"""
270273
raise NotImplementedError(f"{type(self).__name__} must implement add_hook() to support plugins")
271274

strands-py/src/strands/multiagent/graph.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
BeforeNodeCallEvent,
3636
MultiAgentInitializedEvent,
3737
)
38-
from ..hooks.registry import HookCallback, HookProvider, HookRegistry
38+
from ..hooks.registry import HookCallback, HookOrder, HookProvider, HookRegistry
3939
from ..interrupt import Interrupt, _InterruptState
4040
from ..plugins.multiagent_plugin import MultiAgentPlugin
4141
from ..plugins.multiagent_registry import _MultiAgentPluginRegistry
@@ -495,16 +495,19 @@ def __init__(
495495

496496
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
497497

498-
def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None:
498+
def add_hook(
499+
self, callback: HookCallback, event_type: type | list[type] | None = None, *, order: float = HookOrder.DEFAULT
500+
) -> None:
499501
"""Register a hook callback with the graph.
500502
501503
Args:
502504
callback: The callback function to invoke when events of this type occur.
503505
event_type: The class type(s) of events this callback should handle.
504506
Can be a single type, a list of types, or None to infer from
505507
the callback's first parameter type hint.
508+
order: Execution priority. Lower values execute first.
506509
"""
507-
self.hooks.add_callback(event_type, callback)
510+
self.hooks.add_callback(event_type, callback, order=order)
508511

509512
def __call__(
510513
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any

strands-py/src/strands/multiagent/swarm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
BeforeNodeCallEvent,
3636
MultiAgentInitializedEvent,
3737
)
38-
from ..hooks.registry import HookCallback, HookProvider, HookRegistry
38+
from ..hooks.registry import HookCallback, HookOrder, HookProvider, HookRegistry
3939
from ..interrupt import Interrupt, _InterruptState
4040
from ..plugins.multiagent_plugin import MultiAgentPlugin
4141
from ..plugins.multiagent_registry import _MultiAgentPluginRegistry
@@ -315,16 +315,19 @@ def __init__(
315315
self._inject_swarm_tools()
316316
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
317317

318-
def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None:
318+
def add_hook(
319+
self, callback: HookCallback, event_type: type | list[type] | None = None, *, order: float = HookOrder.DEFAULT
320+
) -> None:
319321
"""Register a hook callback with the swarm.
320322
321323
Args:
322324
callback: The callback function to invoke when events of this type occur.
323325
event_type: The class type(s) of events this callback should handle.
324326
Can be a single type, a list of types, or None to infer from
325327
the callback's first parameter type hint.
328+
order: Execution priority. Lower values execute first.
326329
"""
327-
self.hooks.add_callback(event_type, callback)
330+
self.hooks.add_callback(event_type, callback, order=order)
328331

329332
def __call__(
330333
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any

strands-py/tests/strands/agent/hooks/test_hook_registry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_add_callback(hook_registry, normal_event):
5858
hook_registry.add_callback(NormalTestEvent, callback)
5959

6060
assert NormalTestEvent in hook_registry._registered_callbacks
61-
assert callback in hook_registry._registered_callbacks[NormalTestEvent]
61+
assert any(e.callback is callback for e in hook_registry._registered_callbacks[NormalTestEvent])
6262

6363

6464
def test_add_multiple_callbacks_same_event(hook_registry, normal_event):
@@ -70,8 +70,8 @@ def test_add_multiple_callbacks_same_event(hook_registry, normal_event):
7070
hook_registry.add_callback(NormalTestEvent, callback2)
7171

7272
assert len(hook_registry._registered_callbacks[NormalTestEvent]) == 2
73-
assert callback1 in hook_registry._registered_callbacks[NormalTestEvent]
74-
assert callback2 in hook_registry._registered_callbacks[NormalTestEvent]
73+
assert any(e.callback is callback1 for e in hook_registry._registered_callbacks[NormalTestEvent])
74+
assert any(e.callback is callback2 for e in hook_registry._registered_callbacks[NormalTestEvent])
7575

7676

7777
def test_add_hook(hook_registry):

strands-py/tests/strands/agent/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2629,7 +2629,7 @@ def test_agent_add_hook_delegates_to_hooks_add_callback():
26292629
# Spy on the hooks.add_callback method
26302630
with unittest.mock.patch.object(agent.hooks, "add_callback") as mock_add_callback:
26312631
agent.add_hook(callback, BeforeInvocationEvent)
2632-
mock_add_callback.assert_called_once_with(BeforeInvocationEvent, callback)
2632+
mock_add_callback.assert_called_once_with(BeforeInvocationEvent, callback, order=0)
26332633

26342634

26352635
@pytest.mark.asyncio

0 commit comments

Comments
 (0)