Skip to content

Commit 0cf9b28

Browse files
committed
feat: add memory injection (python)
1 parent 713a6de commit 0cf9b28

16 files changed

Lines changed: 1500 additions & 5 deletions

File tree

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Context injection for Strands Agents.
2+
3+
This package provides the configuration types for context injection — folding just-in-time text
4+
into the model input before a call without touching durable history. The delivery primitives
5+
(in ``message_injection``) are internal; reach injection through the ``ContextInjector`` plugin
6+
or the ``MemoryManager`` rather than using them directly.
7+
"""
8+
9+
from .types import InjectionConfig, InjectionContext, InjectionTrigger
10+
11+
__all__ = [
12+
"InjectionConfig",
13+
"InjectionContext",
14+
"InjectionTrigger",
15+
]
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""Delivery primitives for context injection.
2+
3+
These fold just-in-time text into the latest user message *ephemerally* — the model sees the
4+
augmented input for one call while the agent's durable history is never touched. Reach injection
5+
through the ``ContextInjector`` plugin or the ``MemoryManager`` rather than these primitives
6+
directly.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import inspect
12+
import logging
13+
from collections.abc import Awaitable, Callable
14+
from dataclasses import replace
15+
from typing import TYPE_CHECKING, Any, Protocol
16+
17+
from .types import InjectionContext, InjectionTriggerPredicate
18+
19+
if TYPE_CHECKING:
20+
from .._middleware.stages import InvokeModelContext
21+
from ..types.content import ContentBlock, Message, Messages
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
class RenderContentCallback(Protocol):
    """Structural type for the callback that produces the text to inject.

    Any callable with a compatible signature satisfies this protocol, including a plain
    function. The trailing ``**kwargs`` exists so that new keyword arguments can be added to
    the calling convention later without breaking callbacks written against today's shape.
    """

    def __call__(self, context: InjectionContext, **kwargs: Any) -> str | None | Awaitable[str | None]:
        """Produce the injection text for one model call.

        Args:
            context: The injection context for the current model call.
            **kwargs: Reserved for keyword arguments added in the future.

        Returns:
            The text to inject, ``None`` or ``""`` to skip injection, or an awaitable
            resolving to either.
        """
        ...
36+
37+
38+
# The text-rendering callback. The bare ``Callable`` arm keeps the happy path
39+
# (``lambda context: ...``) ergonomic; the ``RenderContentCallback`` arm is the forward-compatible
40+
# Protocol for callers that opt into future keyword arguments. A callback that raises fails open
41+
# (injection is skipped, the model call proceeds).
42+
RenderContent = Callable[[InjectionContext], "str | None | Awaitable[str | None]"] | RenderContentCallback
43+
44+
45+
def create_injection_middleware(
    render_content: RenderContent,
    *,
    trigger: InjectionTriggerPredicate | None = None,
) -> Callable[[InvokeModelContext], Awaitable[InvokeModelContext]]:
    """Build an ``InvokeModelStage.Input`` handler that folds injected text into the conversation.

    The returned handler folds ``render_content``'s text into the latest user message for one
    model call. It returns a *new* context (via ``dataclasses.replace``) carrying the folded
    messages; the incoming context object is not modified.

    The handler fails open. It returns the incoming context unchanged when:

    - the resolved trigger does not fire,
    - ``render_content`` (or the awaitable it returns) raises,
    - ``render_content`` yields ``None``, an empty/whitespace-only string, or
    - ``render_content`` yields something that is not a string (logged as a warning).

    Args:
        render_content: Renders the text to inject for this call. May be sync or async.
        trigger: When to inject. An ``InjectionTrigger`` name selects a built-in policy
            (``"userTurn"`` or ``"everyTurn"``); a predicate over the ``InjectionContext`` is
            the escape hatch. ``None`` behaves like ``"userTurn"``.

    Returns:
        An async handler that takes an ``InvokeModelContext`` and returns either the same
        context or a copy with the injected text folded into the latest user message.
    """
    # Resolve once at construction so each model call only pays for the gate itself.
    resolved_trigger = resolve_trigger(trigger)

    async def handler(context: InvokeModelContext) -> InvokeModelContext:
        agent = context.agent
        # Hand the callback its own list, so a callback that reorders/appends cannot perturb the
        # per-call context. The message dicts themselves are still shared with ``context``.
        injection_context = InjectionContext(messages=list(context.messages), state=agent.state, agent=agent)

        if not resolved_trigger(injection_context):
            return context

        try:
            text = render_content(injection_context)
            # Support both sync and async callbacks with a single call site.
            if inspect.isawaitable(text):
                text = await text
        except Exception as error:  # noqa: BLE001 - fail open: a bad callback must not abort the model call.
            logger.warning("reason=<%s> | injection render_content raised | skipping injection", error)
            return context

        if text is None:
            return context

        # A callback returning a non-string would otherwise raise on ``.strip()`` below, outside
        # the ``try`` above, and abort the model call. Treat it like any other bad callback.
        if not isinstance(text, str):
            logger.warning(
                "type=<%s> | injection render_content returned a non-string | skipping injection",
                type(text).__name__,
            )
            return context

        if not text.strip():
            return context

        # Fold into the context's own messages, not the copy handed to the callback, so that
        # anything the callback did to its list has no effect on what the model sees.
        return replace(context, messages=fold_into_last_user_message(context.messages, text))

    return handler
96+
97+
98+
def resolve_trigger(trigger: InjectionTriggerPredicate | None) -> Callable[[InjectionContext], bool]:
    """Turn a trigger name or a custom predicate into one gate callable.

    - ``None`` or ``"userTurn"``: the gate applies ``is_user_turn`` to ``context.messages``.
    - ``"everyTurn"``: the gate always returns ``True``.
    - Anything else is treated as a caller-supplied predicate and wrapped so that an exception
      raised by it is logged and reported as ``False`` (injection skipped) instead of
      propagating.

    Args:
        trigger: An ``InjectionTrigger`` name, a predicate, or ``None``.

    Returns:
        A callable that takes the ``InjectionContext`` and says whether to inject this call.
    """
    if trigger is None or trigger == "userTurn":

        def user_turn_gate(context: InjectionContext) -> bool:
            return is_user_turn(context.messages)

        return user_turn_gate

    if trigger == "everyTurn":

        def always_gate(context: InjectionContext) -> bool:
            return True

        return always_gate

    # Whatever remains is the caller's own predicate.
    user_predicate = trigger

    def fail_open_gate(context: InjectionContext) -> bool:
        try:
            decision = user_predicate(context)
        except Exception as error:  # noqa: BLE001 - fail open: a bad predicate must not abort the model call.
            logger.warning("reason=<%s> | injection trigger raised | skipping injection", error)
            return False
        return decision

    return fail_open_gate
126+
127+
128+
def is_user_turn(messages: Messages) -> bool:
    """Report whether the conversation currently ends on a plain user ask.

    A plain user ask is a final message whose role is ``"user"`` and whose content holds no
    block with a ``"toolResult"`` key. This is the check behind the ``"userTurn"`` trigger.

    Args:
        messages: The current conversation, as data.

    Returns:
        ``True`` when the last message is a user message without a tool result; ``False``
        when the conversation is empty, ends on another role, or ends on a tool-result turn.
    """
    if not messages:
        return False

    latest = messages[-1]
    if latest["role"] != "user":
        return False

    # A user-role message that carries a tool result is not a fresh ask.
    for block in latest["content"]:
        if "toolResult" in block:
            return False
    return True
144+
145+
146+
def fold_into_last_user_message(messages: Messages, text: str) -> Messages:
    """Fold ``text`` into the most recent ``user`` message as a text block, returning a NEW list.

    Folding into the existing user message (rather than inserting a standalone message) keeps
    role alternation unchanged. Where the block goes depends on the target message:

    - No tool result in the message: the text block is **prepended**, so the message's own
      content stays last.
    - The message carries a ``toolResult`` block: the text block is **appended**, so the tool
      result keeps its position at the front of the turn.

    The input list and its message dicts are never mutated. The folded message is a shallow
    copy of the target with only ``content`` replaced, so every other key on the message
    (``role``, ``metadata``, and anything else present) is carried over as-is. Existing
    content blocks are shared with the original message, not copied.

    Args:
        messages: The conversation to fold into.
        text: The text to fold into the most recent user message.

    Returns:
        A new list with the folded message, or the input list itself (same object) when the
        conversation contains no ``user`` message.
    """
    # Scan from the end: the target is the *latest* user message, not the first.
    target_index = next(
        (index for index in reversed(range(len(messages))) if messages[index]["role"] == "user"),
        None,
    )
    if target_index is None:
        return messages

    target = messages[target_index]
    injected: ContentBlock = {"text": text}
    existing_blocks = list(target["content"])

    # A tool result must stay the first block in the turn that answers a tool use, so append
    # rather than prepend when the target carries one.
    has_tool_result = any("toolResult" in block for block in existing_blocks)
    content = [*existing_blocks, injected] if has_tool_result else [injected, *existing_blocks]

    # Copy the whole message instead of rebuilding it key by key, so keys beyond
    # role/content/metadata are not silently dropped from the per-call view.
    folded: Message = target.copy()
    folded["content"] = content

    result = list(messages)
    result[target_index] = folded
    return result
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Configuration types shared by injection consumers.
2+
3+
Consumed by the ``ContextInjector`` plugin and the ``MemoryManager``.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
from collections.abc import Callable
9+
from dataclasses import dataclass
10+
from typing import TYPE_CHECKING, Any, Literal, Protocol
11+
12+
if TYPE_CHECKING:
13+
from ..agent.agent import Agent
14+
from ..agent.state import AgentState
15+
from ..types.content import Messages
16+
17+
InjectionTrigger = Literal["userTurn", "everyTurn"]
18+
"""Determines when injection runs before a model call.
19+
20+
- ``"userTurn"``: only when the latest message is a fresh user ask (a ``user`` message with
21+
no tool result) — the common case for chat agents, where it keeps the user's ask the final
22+
message the model sees.
23+
- ``"everyTurn"``: before every model call, including mid-task tool-result turns — for
24+
autonomous agents that should consult injected context at each step.
25+
26+
For finer control, pass a predicate instead of a trigger name.
27+
"""
28+
29+
30+
@dataclass
class InjectionContext:
    """Per-call data handed to injection consumers.

    An instance is passed both to the ``render_content`` callback and to a predicate trigger.

    Attributes:
        messages: The current conversation, as data.
        state: Durable agent state shared across calls, hooks, and tools; lets a consumer read
            what a tool stashed on an earlier turn.
        agent: The agent this injection is attached to, exposed as an escape hatch for
            advanced consumers.
    """

    messages: Messages
    state: AgentState
    agent: Agent
46+
47+
48+
class TriggerCallback(Protocol):
    """Structural type for a predicate that gates injection on a model call.

    Any callable with a compatible signature satisfies this protocol, including a plain
    function. The trailing ``**kwargs`` exists so that new keyword arguments can be added to
    the calling convention later without breaking predicates written against today's shape.
    """

    def __call__(self, context: InjectionContext, **kwargs: Any) -> bool:
        """Decide whether to inject on this model call.

        Args:
            context: The injection context for the current model call.
            **kwargs: Reserved for keyword arguments added in the future.

        Returns:
            ``True`` to inject on this call, ``False`` to skip.
        """
        ...
58+
59+
60+
# A trigger name, or a predicate over the injection context. The bare ``Callable`` arm keeps the
61+
# happy path (``lambda context: ...``) ergonomic; the ``TriggerCallback`` arm is the forward-
62+
# compatible Protocol for callers that opt into future keyword arguments.
63+
InjectionTriggerPredicate = InjectionTrigger | Callable[[InjectionContext], bool] | TriggerCallback
64+
65+
66+
@dataclass
class InjectionConfig:
    """Settings shared by every injection consumer; currently just *when* to inject.

    *What* to inject is left to the consumer and is added by configs that extend this one
    (e.g. ``MemoryInjectionConfig``).

    Attributes:
        trigger: When injection runs. Either an ``InjectionTrigger`` name, which selects a
            built-in policy, or a predicate that receives the ``InjectionContext`` and returns
            whether to inject this call. A predicate that raises fails open: injection is
            skipped and the model call proceeds. The field defaults to ``None``, which is
            treated as ``"userTurn"``.
    """

    trigger: InjectionTriggerPredicate | None = None
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Minimal XML escaping for folding untrusted text into an XML-shaped block.
2+
3+
Memory entries and other injected content are frequently user-derived, so interpolating them
4+
raw into ``<entry>…</entry>`` both breaks the block structurally (a stray ``</entry>`` or
5+
``"``) and opens a stored-prompt-injection surface. These helpers are deliberately tiny —
6+
enough to keep a ``<memory>`` block well-formed, not a general-purpose serializer.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
12+
def escape_xml_text(value: str) -> str:
    """Escape a string for use as XML element content.

    Replaces ``&``, ``<`` and ``>`` with their entity forms. Each input character is mapped
    exactly once, so an ``&`` produced by escaping is never escaped a second time.

    Args:
        value: The raw text to escape.

    Returns:
        The escaped text, safe to place in element content.
    """
    entity_by_codepoint = {
        ord("&"): "&amp;",
        ord("<"): "&lt;",
        ord(">"): "&gt;",
    }
    return value.translate(entity_by_codepoint)
24+
25+
26+
def escape_xml_attr(value: str) -> str:
    """Escape a string for use inside a quoted XML attribute value.

    Runs :func:`escape_xml_text` first, then additionally replaces ``"`` and ``'`` with their
    entity forms.

    Args:
        value: The raw attribute value to escape.

    Returns:
        The escaped value, safe to place inside a quoted attribute.
    """
    escaped = escape_xml_text(value)
    # The quote entities contain "&", so they are applied after "&" has already been escaped.
    for raw, entity in (('"', "&quot;"), ("'", "&#39;")):
        escaped = escaped.replace(raw, entity)
    return escaped

strands-py/src/strands/memory/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
tools, and runs automatic background extraction.
66
"""
77

8+
from ..injection import InjectionConfig, InjectionContext, InjectionTrigger
89
from ..types.exceptions import AggregateMemoryError
910
from .extraction.model_extractor import ModelExtractor
1011
from .extraction.triggers import IntervalTrigger, InvocationTrigger
@@ -21,9 +22,12 @@
2122
from .memory_manager import MemoryManager
2223
from .types import (
2324
AddMessagesContext,
25+
InjectionFormatContext,
26+
InjectionQueryContext,
2427
MemoryAddOptions,
2528
MemoryAddToolConfig,
2629
MemoryEntry,
30+
MemoryInjectionConfig,
2731
MemoryManagerConfig,
2832
MemorySearchOptions,
2933
MemoryStore,
@@ -41,12 +45,18 @@
4145
"ExtractionTriggerContext",
4246
"Extractor",
4347
"ExtractorContext",
48+
"InjectionConfig",
49+
"InjectionContext",
50+
"InjectionFormatContext",
51+
"InjectionQueryContext",
52+
"InjectionTrigger",
4453
"IntervalTrigger",
4554
"InvocationTrigger",
4655
"MemoryAddOptions",
4756
"MemoryAddToolConfig",
4857
"MemoryContentBlockType",
4958
"MemoryEntry",
59+
"MemoryInjectionConfig",
5060
"MemoryManager",
5161
"MemoryManagerConfig",
5262
"MemoryMessageFilter",

0 commit comments

Comments
 (0)