|
| 1 | +"""Delivery primitives for context injection. |
| 2 | +
|
| 3 | +These fold just-in-time text into the latest user message *ephemerally* — the model sees the |
| 4 | +augmented input for one call while the agent's durable history is never touched. Reach injection |
| 5 | +through the ``ContextInjector`` plugin or the ``MemoryManager`` rather than these primitives |
| 6 | +directly. |
| 7 | +""" |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import inspect |
| 12 | +import logging |
| 13 | +from collections.abc import Awaitable, Callable |
| 14 | +from dataclasses import replace |
| 15 | +from typing import TYPE_CHECKING, Any, Protocol |
| 16 | + |
| 17 | +from .types import InjectionContext, InjectionTriggerPredicate |
| 18 | + |
| 19 | +if TYPE_CHECKING: |
| 20 | + from .._middleware.stages import InvokeModelContext |
| 21 | + from ..types.content import ContentBlock, Message, Messages |
| 22 | + |
| 23 | +logger = logging.getLogger(__name__) |
| 24 | + |
| 25 | + |
class RenderContentCallback(Protocol):
    """Structural type for the callback that produces the text to inject for one model call.

    A plain function with a compatible signature satisfies it as well. The trailing ``**kwargs``
    is there so the calling convention can gain new keyword arguments later without breaking
    callbacks written against the current signature.
    """

    def __call__(self, context: InjectionContext, **kwargs: Any) -> str | None | Awaitable[str | None]:
        """Produce the text to inject; ``None``/``""`` means skip. An awaitable of either is accepted."""
        ...
| 36 | + |
| 37 | + |
# The text-rendering callback accepted by ``create_injection_middleware``. The bare ``Callable``
# arm keeps the happy path (``lambda context: ...``) ergonomic; the ``RenderContentCallback`` arm
# is the forward-compatible Protocol for callers that opt into future keyword arguments. Either
# arm may return the text directly or an awaitable of it. A callback that raises fails open
# (injection is skipped, the model call proceeds).
RenderContent = Callable[[InjectionContext], "str | None | Awaitable[str | None]"] | RenderContentCallback
| 43 | + |
| 44 | + |
def create_injection_middleware(
    render_content: RenderContent,
    *,
    trigger: InjectionTriggerPredicate | None = None,
) -> Callable[[InvokeModelContext], Awaitable[InvokeModelContext]]:
    """Build an ``InvokeModelStage.Input`` handler that folds injected text into the conversation.

    The handler folds ``render_content``'s text into the latest user message, ephemerally: the
    model sees the augmented input for this one call while the agent's durable history is
    never touched. The handler gates on the resolved trigger, asks ``render_content`` for the
    text, and returns a context with the folded messages. Anything that skips — the trigger
    not firing, ``render_content`` returning ``None``, blank text, or a non-string value, or
    ``render_content`` raising — returns the context unchanged so the model call proceeds
    (fail open). The injected text never enters durable history because the input phase only
    rewrites the per-call context, not the agent's stored messages.

    Args:
        render_content: Renders the text to inject for this call. Sync or async.
        trigger: When to inject. An ``InjectionTrigger`` name selects a built-in policy
            (``"userTurn"`` — default — or ``"everyTurn"``); a predicate over the
            ``InjectionContext`` is the escape hatch. Defaults to ``"userTurn"``.

    Returns:
        An ``InvokeModelStage.Input`` handler that returns a (possibly) folded context.
    """
    resolved_trigger = resolve_trigger(trigger)

    async def handler(context: InvokeModelContext) -> InvokeModelContext:
        agent = context.agent
        # Hand the callback its own list, so a callback that reorders/appends cannot perturb the
        # per-call context. The message dicts are shared, but the upstream InvokeModelContext is
        # already a defensive copy of agent state, so durable history is safe regardless.
        injection_context = InjectionContext(messages=list(context.messages), state=agent.state, agent=agent)

        if not resolved_trigger(injection_context):
            return context

        try:
            text = render_content(injection_context)
            if inspect.isawaitable(text):
                text = await text
        except Exception as error:  # noqa: BLE001 - fail open: a bad callback must not abort the model call.
            logger.warning("reason=<%s> | injection render_content raised | skipping injection", error)
            return context

        if text is None:
            return context
        if not isinstance(text, str):
            # A callback that returns the wrong type is as much a "bad callback" as one that
            # raises. Without this guard ``text.strip()`` would raise outside the ``try`` above
            # and abort the model call, breaking the fail-open contract.
            logger.warning(
                "type=<%s> | injection render_content returned a non-string | skipping injection",
                type(text).__name__,
            )
            return context
        if not text.strip():
            return context

        return replace(context, messages=fold_into_last_user_message(context.messages, text))

    return handler
| 96 | + |
| 97 | + |
def resolve_trigger(trigger: InjectionTriggerPredicate | None) -> Callable[[InjectionContext], bool]:
    """Resolve an ``InjectionTrigger`` name or predicate into a single gate predicate.

    ``"userTurn"`` maps to ``is_user_turn`` (over ``context.messages``); ``"everyTurn"`` to an
    always-true gate; a user-supplied predicate is wrapped so that a raise fails open (logs and
    skips injection rather than aborting the model call).

    A value that is neither a known name nor callable (for example a mistyped name) can never
    fire. It is reported once, here, and resolved to a never-firing gate, instead of failing
    with a ``TypeError`` and logging on every model call.

    Args:
        trigger: An ``InjectionTrigger`` name, a predicate, or ``None`` (defaults to ``"userTurn"``).

    Returns:
        A predicate that, given the ``InjectionContext``, returns whether to inject this call.
    """
    if trigger is None or trigger == "userTurn":
        return lambda context: is_user_turn(context.messages)
    if trigger == "everyTurn":
        return lambda context: True

    if not callable(trigger):
        # Same outcome as before (injection never happens), but diagnosed once with the actual
        # cause rather than as a per-call "'str' object is not callable" warning.
        logger.warning(
            "trigger=<%s> | injection trigger is neither a known name nor a callable | injection disabled",
            trigger,
        )
        return lambda context: False

    predicate = trigger

    def guarded(context: InjectionContext) -> bool:
        try:
            return predicate(context)
        except Exception as error:  # noqa: BLE001 - fail open: a bad predicate must not abort the model call.
            logger.warning("reason=<%s> | injection trigger raised | skipping injection", error)
            return False

    return guarded
| 126 | + |
| 127 | + |
def is_user_turn(messages: Messages) -> bool:
    """Report whether the conversation currently ends on a fresh user ask.

    A fresh ask is a trailing ``user`` message none of whose content blocks is a tool result.
    This implements the ``"userTurn"`` policy, separating a new chat ask from an autonomous
    tool-result turn.

    Args:
        messages: The current conversation, as data.

    Returns:
        ``True`` when the latest message is a plain user ask, otherwise ``False`` (including
        when there are no messages at all).
    """
    if not messages:
        return False

    latest = messages[-1]
    if latest["role"] != "user":
        return False

    # A user-role message that carries a tool result is the tool loop answering, not a new ask.
    for block in latest["content"]:
        if "toolResult" in block:
            return False
    return True
| 144 | + |
| 145 | + |
def fold_into_last_user_message(messages: Messages, text: str) -> Messages:
    """Return a copy of ``messages`` whose most recent ``user`` message also carries ``text``.

    The text is added as a text block to the existing user message instead of being inserted
    as a standalone message, which keeps role alternation valid in both chat and the
    autonomous tool loop. Where the block goes depends on the target message:

    - A plain user ask: the text is **prepended**, so the user's own ask stays in the recency
      slot — the last thing the model reads.
    - A tool-result turn (the message carries a tool result block): the text is **appended**,
      because providers require the tool result to be the first content block in the turn that
      answers a tool use.

    Neither the input list nor any of its messages is mutated. When no ``user`` message
    exists, the input list itself is returned unchanged.

    Args:
        messages: The conversation to fold into.
        text: The text to fold into the most recent user message.

    Returns:
        A new list with the folded message, or the input list when there is no user message.
    """
    # Walk backwards so the first hit is the most recent user message.
    position = next(
        (index for index in reversed(range(len(messages))) if messages[index]["role"] == "user"),
        None,
    )
    if position is None:
        return messages

    original = messages[position]
    blocks = list(original["content"])
    addition: ContentBlock = {"text": text}
    if any("toolResult" in block for block in blocks):
        # The tool result has to remain the leading block, so the injected text goes last.
        blocks.append(addition)
    else:
        blocks.insert(0, addition)

    merged: Message = {"role": original["role"], "content": blocks}
    if "metadata" in original:
        merged["metadata"] = original["metadata"]

    updated = list(messages)
    updated[position] = merged
    return updated
0 commit comments