Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 105 additions & 10 deletions python/packages/core/agent_framework/_workflows/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from typing import Any, cast

from typing_extensions import Never
from typing_extensions import Literal, Never

from agent_framework import Content

Expand Down Expand Up @@ -57,7 +57,71 @@ class AgentExecutorResponse:

executor_id: str
agent_response: AgentResponse
full_conversation: list[Message] | None = None
full_conversation: list[Message]


@dataclass
class ContextMode:
    """Configuration for how AgentExecutor manages conversation context for incoming messages.

    Attributes:
        filter_mode: Determines which incoming messages are included in the context provided to the agent.
            - "full": Include all incoming messages (default).
            - "last_agent": Include only the messages from the prior agent response.
            - "custom": Apply ``messages_filter`` to the incoming messages.
        retain_cache: Whether the executor keeps its internal cache when new messages arrive.
            If False, the cache is cleared before the new messages are processed.
        messages_filter: Callable mapping the incoming message list to the filtered list used as
            agent context. Required — and only consulted — when ``filter_mode`` is "custom".

    Note:
        1. `AgentExecutorRequest` is exempt from the `ContextMode` filtering behavior: its messages are always
           appended to the cache unfiltered, and the `should_respond` flag controls whether those messages
           trigger an agent run.
        2. The cache stages messages received from other executors until the agent runs, after which it is
           cleared and the staged messages move into the full conversation context. Staging matters because
           in many multi-agent workflows one agent may receive messages from several executors before it runs.
    """

    filter_mode: Literal["full", "last_agent", "custom"] = "full"
    retain_cache: bool = True
    # Annotation quoted so the forward reference to Message is not evaluated at class-creation time.
    messages_filter: "Callable[[list[Message]], list[Message]] | None" = None

    def __post_init__(self) -> None:
        # Fail fast: "custom" mode is meaningless without a filter to apply.
        if self.filter_mode == "custom" and self.messages_filter is None:
            raise ValueError("messages_filter must be provided when filter_mode is 'custom'.")

    # Some common context modes
    @staticmethod
    def default() -> "ContextMode":
        """Default context mode that includes all incoming messages and retains cache."""
        return ContextMode(filter_mode="full", retain_cache=True)

    @staticmethod
    def last_agent(retain_cache: bool = True) -> "ContextMode":
        """Context mode that includes only messages from the most recent agent response."""
        return ContextMode(filter_mode="last_agent", retain_cache=retain_cache)

    @staticmethod
    def last_n(n: int, retain_cache: bool = True) -> "ContextMode":
        """Context mode that includes only the last ``n`` messages of the conversation so far."""

        def _keep_tail(messages: "list[Message]") -> "list[Message]":
            # Conversations already within the limit pass through unchanged; for n <= 0 the
            # slice below yields [] on any non-empty list, matching the original behavior.
            if len(messages) <= n:
                return messages
            return messages[len(messages) - n :]

        return ContextMode(filter_mode="custom", retain_cache=retain_cache, messages_filter=_keep_tail)


class AgentExecutor(Executor):
Expand All @@ -83,13 +147,16 @@ def __init__(
*,
session: AgentSession | None = None,
id: str | None = None,
context_mode: ContextMode | None = None,
):
"""Initialize the executor with a unique identifier.

Args:
agent: The agent to be wrapped by this executor.
session: The session to use for running the agent. If None, a new session will be created.
id: A unique identifier for the executor. If None, the agent's name will be used if available.
context_mode: Configuration for how the executor should manage conversation context. If None,
defaults to ContextMode.default().
"""
# Prefer provided id; else use agent.name if present; else generate deterministic prefix
exec_id = id or resolve_agent_id(agent)
Expand All @@ -98,6 +165,7 @@ def __init__(
super().__init__(exec_id)
self._agent = agent
self._session = session or self._agent.create_session()
self._context_mode = context_mode or ContextMode.default()

self._pending_agent_requests: dict[str, Content] = {}
self._pending_responses_to_agent: list[Content] = []
Expand Down Expand Up @@ -143,19 +211,26 @@ async def from_response(
Strategy: treat the prior response's messages as the conversation state and
immediately run the agent to produce a new response.
"""
# Replace cache with full conversation if available, else fall back to agent_response messages.
source_messages = (
prior.full_conversation if prior.full_conversation is not None else prior.agent_response.messages
)
self._cache = list(source_messages)
if not self._context_mode.retain_cache:
self._cache.clear()

if self._context_mode.filter_mode == "full":
self._cache.extend(list(prior.full_conversation))
elif self._context_mode.filter_mode == "last_agent":
self._cache.extend(list(prior.agent_response.messages))
elif self._context_mode.filter_mode == "custom":
self._cache.extend(list(self._context_mode.messages_filter(prior.full_conversation))) # type: ignore

await self._run_agent_and_emit(ctx)

    @handler
    async def from_str(
        self, text: str, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate]
    ) -> None:
        """Accept a raw user prompt string and run the agent (one-shot)."""
        # NOTE(review): this hunk is rendered without diff markers — the assignment below and
        # the retain_cache-aware block after it are the pre-/post-change variants of the same
        # cache population step; confirm only the retain_cache-aware version survives the merge.
        self._cache = normalize_messages_input(text)
        if not self._context_mode.retain_cache:
            self._cache.clear()
        self._cache.extend(normalize_messages_input(text))
        await self._run_agent_and_emit(ctx)

@handler
Expand All @@ -165,7 +240,9 @@ async def from_message(
ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate],
) -> None:
"""Accept a single Message as input."""
self._cache = normalize_messages_input(message)
if not self._context_mode.retain_cache:
self._cache.clear()
self._cache.extend(normalize_messages_input(message))
await self._run_agent_and_emit(ctx)

@handler
Expand All @@ -175,7 +252,14 @@ async def from_messages(
ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate],
) -> None:
"""Accept a list of chat inputs (strings or Message) as conversation context."""
self._cache = normalize_messages_input(messages)
if not self._context_mode.retain_cache:
self._cache.clear()

normalized_messages = normalize_messages_input(messages)
if self._context_mode.filter_mode == "custom":
normalized_messages = self._context_mode.messages_filter(normalized_messages) # type: ignore

self._cache.extend(normalized_messages)
await self._run_agent_and_emit(ctx)

@response_handler
Expand Down Expand Up @@ -237,6 +321,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]:
"cache": self._cache,
"full_conversation": self._full_conversation,
"agent_session": serialized_session,
"context_mode": self._context_mode,
"pending_agent_requests": self._pending_agent_requests,
"pending_responses_to_agent": self._pending_responses_to_agent,
}
Expand Down Expand Up @@ -278,6 +363,16 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
else:
self._session = self._agent.create_session()

context_mode_payload = state.get("context_mode")
if context_mode_payload:
try:
self._context_mode = context_mode_payload
except Exception as exc:
logger.warning("Failed to restore context mode: %s", exc)
self._context_mode = ContextMode.default()
else:
self._context_mode = ContextMode.default()

pending_requests_payload = state.get("pending_agent_requests")
if pending_requests_payload:
self._pending_agent_requests = pending_requests_payload
Expand Down
2 changes: 1 addition & 1 deletion python/packages/core/tests/workflow/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets

await ctx.send_message(
WorkflowMessage(
data=AgentExecutorResponse("agent", AgentResponse()),
data=AgentExecutorResponse("agent", AgentResponse(), []),
source_id="agent",
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ async def test_request_info_handler(self):
agent_response = AgentExecutorResponse(
executor_id="test_agent",
agent_response=agent_response,
full_conversation=agent_response.messages,
)

ctx = MagicMock(spec=WorkflowContext)
Expand All @@ -135,6 +136,7 @@ async def test_handle_request_info_response_with_messages(self):
original_request = AgentExecutorResponse(
executor_id="test_agent",
agent_response=agent_response,
full_conversation=agent_response.messages,
)

response = AgentRequestInfoResponse.from_strings(["Additional input"])
Expand All @@ -161,6 +163,7 @@ async def test_handle_request_info_response_approval(self):
original_request = AgentExecutorResponse(
executor_id="test_agent",
agent_response=agent_response,
full_conversation=agent_response.messages,
)

response = AgentRequestInfoResponse.approve()
Expand Down
Loading