diff --git a/python/packages/azurefunctions/tests/test_func_utils.py b/python/packages/azurefunctions/tests/test_func_utils.py index 63f0af0182..9155bad33e 100644 --- a/python/packages/azurefunctions/tests/test_func_utils.py +++ b/python/packages/azurefunctions/tests/test_func_utils.py @@ -232,6 +232,7 @@ def test_roundtrip_agent_executor_response(self) -> None: original = AgentExecutorResponse( executor_id="test_exec", agent_response=AgentResponse(messages=[Message(role="assistant", text="Reply")]), + full_conversation=[Message(role="assistant", text="Reply")], ) encoded = serialize_value(original) decoded = deserialize_value(encoded) diff --git a/python/packages/azurefunctions/tests/test_workflow.py b/python/packages/azurefunctions/tests/test_workflow.py index 4c26c980b2..baba1c2602 100644 --- a/python/packages/azurefunctions/tests/test_workflow.py +++ b/python/packages/azurefunctions/tests/test_workflow.py @@ -212,6 +212,7 @@ def test_extract_from_agent_executor_response_with_text(self) -> None: response = AgentExecutorResponse( executor_id="exec", agent_response=AgentResponse(messages=[Message(role="assistant", text="Response text")]), + full_conversation=[Message(role="assistant", text="Response text")], ) result = _extract_message_content(response) @@ -228,6 +229,10 @@ def test_extract_from_agent_executor_response_with_messages(self) -> None: Message(role="assistant", text="Last message"), ] ), + full_conversation=[ + Message(role="user", text="First"), + Message(role="assistant", text="Last message"), + ], ) result = _extract_message_content(response) diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index ac2ebcf56f..462c3f8c64 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -4,7 +4,7 @@ import sys from collections.abc import Awaitable, Callable, Mapping 
from dataclasses import dataclass -from typing import Any, cast +from typing import Any, Literal, cast from typing_extensions import Never @@ -57,7 +57,7 @@ class AgentExecutorResponse: executor_id: str agent_response: AgentResponse - full_conversation: list[Message] | None = None + full_conversation: list[Message] class AgentExecutor(Executor): @@ -83,6 +83,8 @@ def __init__( *, session: AgentSession | None = None, id: str | None = None, + context_mode: Literal["full", "last_agent", "custom"] | None = None, + context_filter: Callable[[list[Message]], list[Message]] | None = None, ): """Initialize the executor with a unique identifier. @@ -90,6 +92,16 @@ def __init__( agent: The agent to be wrapped by this executor. session: The session to use for running the agent. If None, a new session will be created. id: A unique identifier for the executor. If None, the agent's name will be used if available. + context_mode: Configuration for how the executor should manage conversation context upon + receiving an AgentExecutorResponse as input. Options: + - "full": append the full conversation (all prior messages + latest agent response) to the + cache for the agent run. This is the default mode. + - "last_agent": provide only the messages from the latest agent response as context for + the agent run. + - "custom": use the provided context_filter function to determine which messages to include + as context for the agent run. + context_filter: An optional function for filtering conversation context when context_mode is set + to "custom". 
""" # Prefer provided id; else use agent.name if present; else generate deterministic prefix exec_id = id or resolve_agent_id(agent) @@ -107,6 +119,14 @@ def __init__( # This tracks the full conversation after each run self._full_conversation: list[Message] = [] + # Context mode validation + self._context_mode = context_mode or "full" + self._context_filter = context_filter + if self._context_mode not in {"full", "last_agent", "custom"}: + raise ValueError("context_mode must be one of 'full', 'last_agent', or 'custom'.") + if self._context_mode == "custom" and not self._context_filter: + raise ValueError("context_filter must be provided when context_mode is set to 'custom'.") + @property def agent(self) -> SupportsAgentRun: """Get the underlying agent wrapped by this executor.""" @@ -129,6 +149,7 @@ async def run( run the agent and emit an AgentExecutorResponse downstream. """ self._cache.extend(request.messages) + if request.should_respond: await self._run_agent_and_emit(ctx) @@ -143,19 +164,27 @@ async def from_response( Strategy: treat the prior response's messages as the conversation state and immediately run the agent to produce a new response. """ - # Replace cache with full conversation if available, else fall back to agent_response messages. 
- source_messages = ( - prior.full_conversation if prior.full_conversation is not None else prior.agent_response.messages - ) - self._cache = list(source_messages) + if self._context_mode == "full": + self._cache.extend(prior.full_conversation) + elif self._context_mode == "last_agent": + self._cache.extend(prior.agent_response.messages) + else: + if not self._context_filter: + # This should never happen due to validation in __init__, but mypy doesn't track that well + raise ValueError("context_filter function must be provided for 'custom' context_mode.") + self._cache.extend(self._context_filter(prior.full_conversation)) + await self._run_agent_and_emit(ctx) @handler async def from_str( self, text: str, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate] ) -> None: - """Accept a raw user prompt string and run the agent (one-shot).""" - self._cache = normalize_messages_input(text) + """Accept a raw user prompt string and run the agent. + + The new string input will be added to the cache which is used as the conversation context for the agent run. + """ + self._cache.extend(normalize_messages_input(text)) await self._run_agent_and_emit(ctx) @handler @@ -164,8 +193,11 @@ async def from_message( message: Message, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate], ) -> None: - """Accept a single Message as input.""" - self._cache = normalize_messages_input(message) + """Accept a single Message as input. + + The new message will be added to the cache which is used as the conversation context for the agent run. 
+ """ + self._cache.extend(normalize_messages_input(message)) await self._run_agent_and_emit(ctx) @handler @@ -174,8 +206,11 @@ async def from_messages( messages: list[str | Message], ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate], ) -> None: - """Accept a list of chat inputs (strings or Message) as conversation context.""" - self._cache = normalize_messages_input(messages) + """Accept a list of chat inputs (strings or Message) as conversation context. + + The new messages will be added to the cache which is used as the conversation context for the agent run. + """ + self._cache.extend(normalize_messages_input(messages)) await self._run_agent_and_emit(ctx) @response_handler @@ -249,24 +284,10 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: state: Checkpoint data dict """ cache_payload = state.get("cache") - if cache_payload: - try: - self._cache = cache_payload - except Exception as exc: - logger.warning("Failed to restore cache: %s", exc) - self._cache = [] - else: - self._cache = [] + self._cache = cache_payload or [] full_conversation_payload = state.get("full_conversation") - if full_conversation_payload: - try: - self._full_conversation = full_conversation_payload - except Exception as exc: - logger.warning("Failed to restore full conversation: %s", exc) - self._full_conversation = [] - else: - self._full_conversation = [] + self._full_conversation = full_conversation_payload or [] session_payload = state.get("agent_session") if session_payload: @@ -279,12 +300,10 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: self._session = self._agent.create_session() pending_requests_payload = state.get("pending_agent_requests") - if pending_requests_payload: - self._pending_agent_requests = pending_requests_payload + self._pending_agent_requests = pending_requests_payload or {} pending_responses_payload = state.get("pending_responses_to_agent") - if pending_responses_payload: - 
self._pending_responses_to_agent = pending_responses_payload + self._pending_responses_to_agent = pending_responses_payload or [] def reset(self) -> None: """Reset the internal cache of the executor.""" diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 059e683745..6298a8963d 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -16,12 +16,12 @@ Content, Message, ResponseStream, + WorkflowBuilder, WorkflowEvent, WorkflowRunState, ) from agent_framework._workflows._agent_executor import AgentExecutorResponse from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage -from agent_framework.orchestrations import SequentialBuilder if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture @@ -139,7 +139,7 @@ async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() """AgentExecutor should call get_final_response() so stream result hooks execute.""" agent = _StreamingHookAgent(id="hook_agent", name="HookAgent") executor = AgentExecutor(agent, id="hook_exec") - workflow = SequentialBuilder(participants=[executor]).build() + workflow = WorkflowBuilder(start_executor=executor).build() output_events: list[Any] = [] async for event in workflow.run("run hook test", stream=True): @@ -154,8 +154,9 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: """Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly.""" storage = InMemoryCheckpointStorage() - # Create initial agent with a custom session - initial_agent = _CountingAgent(id="test_agent", name="TestAgent") + # Create two agents to form a two-step workflow + initial_agent_a = _CountingAgent(id="agent_a", name="AgentA") + initial_agent_b = _CountingAgent(id="agent_b", name="AgentB") initial_session = AgentSession() # Add some initial 
messages to the session state to verify session state persistence @@ -165,11 +166,12 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: ] initial_session.state["history"] = {"messages": initial_messages} - # Create AgentExecutor with the session - executor = AgentExecutor(initial_agent, session=initial_session) + # Create AgentExecutors — first executor gets the custom session + exec_a = AgentExecutor(initial_agent_a, id="exec_a", session=initial_session) + exec_b = AgentExecutor(initial_agent_b, id="exec_b") - # Build workflow with checkpointing enabled - wf = SequentialBuilder(participants=[executor], checkpoint_storage=storage).build() + # Build two-executor workflow with checkpointing enabled + wf = WorkflowBuilder(start_executor=exec_a, checkpoint_storage=storage).add_edge(exec_a, exec_b).build() # Run the workflow with a user message first_run_output: AgentExecutorResponse | None = None @@ -180,27 +182,25 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: break assert first_run_output is not None - assert initial_agent.call_count == 1 + assert initial_agent_a.call_count == 1 # Verify checkpoint was created checkpoints = await storage.list_checkpoints(workflow_name=wf.name) - assert len(checkpoints) >= 2, ( - "Expected at least 2 checkpoints. The first one is after the start executor, " - "and the second one is after the agent execution." - ) + assert len(checkpoints) >= 2, "Expected at least 2 checkpoints: one after exec_a and one after exec_b." 
- # Get the second checkpoint which should contain the state after processing - # the first message by the start executor in the sequential workflow + # Get the first checkpoint that contains exec_a's state (taken after exec_a completes, + # before exec_b runs) checkpoints.sort(key=lambda cp: cp.timestamp) - restore_checkpoint = checkpoints[1] + restore_checkpoint = next( + cp for cp in checkpoints if "_executor_state" in cp.state and "exec_a" in cp.state["_executor_state"] + ) # Verify checkpoint contains executor state with both cache and session - assert "_executor_state" in restore_checkpoint.state executor_states = restore_checkpoint.state["_executor_state"] assert isinstance(executor_states, dict) - assert executor.id in executor_states + assert exec_a.id in executor_states - executor_state = executor_states[executor.id] # type: ignore[index] + executor_state = executor_states[exec_a.id] # type: ignore[index] assert "cache" in executor_state, "Checkpoint should store executor cache state" assert "agent_session" in executor_state, "Checkpoint should store executor session state" @@ -213,19 +213,26 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert "pending_agent_requests" in executor_state assert "pending_responses_to_agent" in executor_state - # Create a new agent and executor for restoration + # Create new agents and executors for restoration # This simulates starting from a fresh state and restoring from checkpoint - restored_agent = _CountingAgent(id="test_agent", name="TestAgent") + restored_agent_a = _CountingAgent(id="agent_a", name="AgentA") + restored_agent_b = _CountingAgent(id="agent_b", name="AgentB") restored_session = AgentSession() - restored_executor = AgentExecutor(restored_agent, session=restored_session) - - # Verify the restored agent starts with a fresh state - assert restored_agent.call_count == 0 - - # Build new workflow with the restored executor - wf_resume = 
SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build() + restored_exec_a = AgentExecutor(restored_agent_a, id="exec_a", session=restored_session) + restored_exec_b = AgentExecutor(restored_agent_b, id="exec_b") + + # Verify the restored agents start with a fresh state + assert restored_agent_a.call_count == 0 + assert restored_agent_b.call_count == 0 + + # Build new workflow with the restored executors + wf_resume = ( + WorkflowBuilder(start_executor=restored_exec_a, checkpoint_storage=storage) + .add_edge(restored_exec_a, restored_exec_b) + .build() + ) - # Resume from checkpoint + # Resume from checkpoint — exec_a already ran, so exec_b should run and produce output resumed_output: AgentExecutorResponse | None = None async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True): if ev.type == "output": @@ -239,7 +246,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert resumed_output is not None # Verify the restored executor's session state was restored - restored_session_obj = restored_executor._session # type: ignore[reportPrivateUsage] + restored_session_obj = restored_exec_a._session # type: ignore[reportPrivateUsage] assert restored_session_obj is not None assert restored_session_obj.session_id == initial_session.session_id @@ -306,7 +313,7 @@ async def test_agent_executor_run_with_session_kwarg_does_not_raise() -> None: """Passing session= via workflow.run() should not cause a duplicate-keyword TypeError (#4295).""" agent = _CountingAgent(id="session_kwarg_agent", name="SessionKwargAgent") executor = AgentExecutor(agent, id="session_kwarg_exec") - workflow = SequentialBuilder(participants=[executor]).build() + workflow = WorkflowBuilder(start_executor=executor).build() # This previously raised: TypeError: run() got multiple values for keyword argument 'session' result = await workflow.run("hello", session="user-supplied-value") @@ -318,7 +325,7 @@ async def 
test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() - """Passing stream= via workflow.run() kwargs should not cause a duplicate-keyword TypeError.""" agent = _CountingAgent(id="stream_kwarg_agent", name="StreamKwargAgent") executor = AgentExecutor(agent, id="stream_kwarg_exec") - workflow = SequentialBuilder(participants=[executor]).build() + workflow = WorkflowBuilder(start_executor=executor).build() # stream=True at workflow level triggers streaming mode (returns async iterable) events: list[WorkflowEvent] = [] @@ -378,7 +385,7 @@ async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None: """Passing messages= via workflow.run() kwargs should not cause a duplicate-keyword TypeError.""" agent = _CountingAgent(id="messages_kwarg_agent", name="MessagesKwargAgent") executor = AgentExecutor(agent, id="messages_kwarg_exec") - workflow = SequentialBuilder(participants=[executor]).build() + workflow = WorkflowBuilder(start_executor=executor).build() result = await workflow.run("hello", messages=["stale"]) assert result is not None @@ -426,7 +433,7 @@ async def test_agent_executor_workflow_with_non_copyable_raw_representation() -> exec_a = AgentExecutor(agent_a, id="exec_a") exec_b = AgentExecutor(agent_b, id="exec_b") - workflow = SequentialBuilder(participants=[exec_a, exec_b]).build() + workflow = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build() events = await workflow.run("hello") completed = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] @@ -440,3 +447,194 @@ async def test_agent_executor_workflow_with_non_copyable_raw_representation() -> assert len(agent_responses) > 0 assert agent_responses[0].text == "reply from AgentA" assert agent_responses[0].raw_representation is raw + + +# --------------------------------------------------------------------------- +# Context mode tests +# --------------------------------------------------------------------------- + + +class 
_MessageCapturingAgent(BaseAgent): + """Agent that records the messages it received and returns a configurable reply.""" + + def __init__(self, *, reply_text: str = "reply", **kwargs: Any): + super().__init__(**kwargs) + self.reply_text = reply_text + self.last_messages: list[Message] = [] + + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + captured: list[Message] = [] + if messages: + for m in messages: # type: ignore[union-attr] + if isinstance(m, Message): + captured.append(m) + elif isinstance(m, str): + captured.append(Message("user", [m])) + self.last_messages = captured + + if stream: + + async def _stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[Content.from_text(text=self.reply_text)]) + + return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) + + async def _run() -> AgentResponse: + return AgentResponse(messages=[Message("assistant", [self.reply_text])]) + + return _run() + + +def test_context_mode_custom_requires_context_filter() -> None: + """context_mode='custom' without context_filter must raise ValueError.""" + agent = _CountingAgent(id="a", name="A") + with pytest.raises(ValueError, match="context_filter must be provided"): + AgentExecutor(agent, context_mode="custom") + + +def test_context_mode_custom_with_filter_succeeds() -> None: + """context_mode='custom' with a context_filter 
should not raise.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, context_mode="custom", context_filter=lambda msgs: msgs[-1:]) + assert executor._context_mode == "custom" # pyright: ignore[reportPrivateUsage] + assert executor._context_filter is not None # pyright: ignore[reportPrivateUsage] + + +def test_context_mode_defaults_to_full() -> None: + """Default context_mode should be 'full'.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent) + assert executor._context_mode == "full" # pyright: ignore[reportPrivateUsage] + + +def test_context_mode_invalid_value_raises() -> None: + """Invalid context_mode value should raise ValueError.""" + agent = _CountingAgent(id="a", name="A") + with pytest.raises(ValueError, match="context_mode must be one of"): + AgentExecutor(agent, context_mode="invalid_mode") # type: ignore + + +async def test_from_response_context_mode_full_passes_full_conversation() -> None: + """context_mode='full' (default) should pass full_conversation to the second agent.""" + first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply") + second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply") + + exec_a = AgentExecutor(first, id="exec_a") + exec_b = AgentExecutor(second, id="exec_b", context_mode="full") + + wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build() + + async for ev in wf.run("hello", stream=True): + if ev.type == "status" and ev.state == WorkflowRunState.IDLE: + break + + # Second agent should see full conversation: [user("hello"), assistant("first reply")] + seen = second.last_messages + assert len(seen) == 2 + assert seen[0].role == "user" and "hello" in (seen[0].text or "") + assert seen[1].role == "assistant" and "first reply" in (seen[1].text or "") + + +async def test_from_response_context_mode_last_agent_passes_only_agent_messages() -> None: + """context_mode='last_agent' should pass only the 
previous agent's response messages.""" + first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply") + second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply") + + exec_a = AgentExecutor(first, id="exec_a") + exec_b = AgentExecutor(second, id="exec_b", context_mode="last_agent") + + wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build() + + async for ev in wf.run("hello", stream=True): + if ev.type == "status" and ev.state == WorkflowRunState.IDLE: + break + + # Second agent should see only the assistant message from first: [assistant("first reply")] + seen = second.last_messages + assert len(seen) == 1 + assert seen[0].role == "assistant" and "first reply" in (seen[0].text or "") + + +async def test_from_response_context_mode_custom_uses_filter() -> None: + """context_mode='custom' should invoke context_filter on full_conversation.""" + first = _MessageCapturingAgent(id="first", name="First", reply_text="first reply") + second = _MessageCapturingAgent(id="second", name="Second", reply_text="second reply") + + # Custom filter: keep only user messages + def only_user_messages(msgs: list[Message]) -> list[Message]: + return [m for m in msgs if m.role == "user"] + + exec_a = AgentExecutor(first, id="exec_a") + exec_b = AgentExecutor(second, id="exec_b", context_mode="custom", context_filter=only_user_messages) + + wf = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build() + + async for ev in wf.run("hello", stream=True): + if ev.type == "status" and ev.state == WorkflowRunState.IDLE: + break + + # Second agent should see only user messages: [user("hello")] + seen = second.last_messages + assert len(seen) == 1 + assert seen[0].role == "user" and "hello" in (seen[0].text or "") + + +async def test_checkpoint_save_does_not_include_context_mode() -> None: + """on_checkpoint_save should not include context_mode in the saved state.""" + agent = _CountingAgent(id="a", name="A") + 
executor = AgentExecutor(agent, context_mode="last_agent") + + state = await executor.on_checkpoint_save() + + assert "context_mode" not in state + assert "cache" in state + assert "agent_session" in state + + +async def test_checkpoint_restore_works_without_context_mode_in_state() -> None: + """on_checkpoint_restore should succeed when state does not contain context_mode.""" + agent = _CountingAgent(id="a", name="A") + executor = AgentExecutor(agent, context_mode="last_agent") + + # Simulate a checkpoint state without context_mode (as saved by the new code) + state: dict[str, Any] = { + "cache": [Message(role="user", text="cached msg")], + "full_conversation": [], + "agent_session": AgentSession().to_dict(), + "pending_agent_requests": {}, + "pending_responses_to_agent": [], + } + + await executor.on_checkpoint_restore(state) + + cache = executor._cache # pyright: ignore[reportPrivateUsage] + assert len(cache) == 1 + assert cache[0].text == "cached msg" + # context_mode should remain as configured in the constructor, not changed by restore + assert executor._context_mode == "last_agent" # pyright: ignore[reportPrivateUsage] diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index db6dccd9fa..a42e94f39d 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -341,7 +341,7 @@ async def test_runner_emits_runner_completion_for_agent_response_without_targets await ctx.send_message( WorkflowMessage( - data=AgentExecutorResponse("agent", AgentResponse()), + data=AgentExecutorResponse("agent", AgentResponse(), []), source_id="agent", ) ) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py b/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py index 062e87806c..d73b7e322b 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py +++ 
b/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py @@ -100,24 +100,21 @@ def _is_role(msg: Any, role: str) -> bool: assistant_replies: list[Message] = [] for r in results: - resp_messages = list(getattr(r.agent_response, "messages", []) or []) - conv = r.full_conversation if r.full_conversation is not None else resp_messages + resp_messages = list(r.agent_response.messages) logger.debug( f"Aggregating executor {getattr(r, 'executor_id', '')}: " - f"{len(resp_messages)} response msgs, {len(conv)} conversation msgs" + f"{len(resp_messages)} response msgs, {len(r.full_conversation)} conversation msgs" ) # Capture a single user prompt (first encountered across any conversation) if prompt_message is None: - found_user = next((m for m in conv if _is_role(m, "user")), None) - if found_user is not None: - prompt_message = found_user + prompt_message = next((m for m in r.full_conversation if _is_role(m, "user")), None) # Pick the final assistant message from the response; fallback to conversation search final_assistant = next((m for m in reversed(resp_messages) if _is_role(m, "assistant")), None) if final_assistant is None: - final_assistant = next((m for m in reversed(conv) if _is_role(m, "assistant")), None) + final_assistant = next((m for m in reversed(r.full_conversation) if _is_role(m, "assistant")), None) if final_assistant is not None: assistant_replies.append(final_assistant) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_orchestration_request_info.py b/python/packages/orchestrations/agent_framework_orchestrations/_orchestration_request_info.py index 5e4a5d6a28..16950606dc 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_orchestration_request_info.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_orchestration_request_info.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. 
from dataclasses import dataclass +from typing import Literal from agent_framework._agents import SupportsAgentRun from agent_framework._types import Message @@ -117,18 +118,25 @@ class AgentApprovalExecutor(WorkflowExecutor): agent's output or send the final response to down stream executors in the orchestration. """ - def __init__(self, agent: SupportsAgentRun) -> None: + def __init__( + self, + agent: SupportsAgentRun, + context_mode: Literal["full", "last_agent", "custom"] | None = None, + ) -> None: """Initialize the AgentApprovalExecutor. Args: agent: The agent protocol to use for generating responses. + context_mode: The mode for providing context to the agent. """ - super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True) + self._context_mode: Literal["full", "last_agent", "custom"] | None = context_mode self._description = agent.description + super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True) + def _build_workflow(self, agent: SupportsAgentRun) -> Workflow: """Build the internal workflow for the AgentApprovalExecutor.""" - agent_executor = AgentExecutor(agent) + agent_executor = AgentExecutor(agent, context_mode=self._context_mode) request_info_executor = AgentRequestInfoExecutor(id="agent_request_info_executor") return ( diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_sequential.py b/python/packages/orchestrations/agent_framework_orchestrations/_sequential.py index 5ef4f7fe8c..1ccfed8f49 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_sequential.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_sequential.py @@ -38,7 +38,7 @@ import logging from collections.abc import Sequence -from typing import Any +from typing import Any, Literal from agent_framework import Message, SupportsAgentRun from agent_framework._workflows._agent_executor import ( @@ -143,6 +143,7 @@ def __init__( 
*, participants: Sequence[SupportsAgentRun | Executor], checkpoint_storage: CheckpointStorage | None = None, + chain_only_agent_responses: bool = False, intermediate_outputs: bool = False, ) -> None: """Initialize the SequentialBuilder. @@ -150,10 +151,14 @@ def __init__( Args: participants: Sequence of agent or executor instances to run sequentially. checkpoint_storage: Optional checkpoint storage for enabling workflow state persistence. + chain_only_agent_responses: If True, only agent responses are chained between agents. + By default, the full conversation context is passed to the next agent. This also applies + to Executor -> Agent transitions if the executor sends `AgentExecutorResponse`. intermediate_outputs: If True, enables intermediate outputs from agent participants. """ self._participants: list[SupportsAgentRun | Executor] = [] self._checkpoint_storage: CheckpointStorage | None = checkpoint_storage + self._chain_only_agent_responses: bool = chain_only_agent_responses self._request_info_enabled: bool = False self._request_info_filter: set[str] | None = None self._intermediate_outputs: bool = intermediate_outputs @@ -225,6 +230,10 @@ def _resolve_participants(self) -> list[Executor]: participants: list[Executor | SupportsAgentRun] = self._participants + context_mode: Literal["full", "last_agent", "custom"] | None = ( + "last_agent" if self._chain_only_agent_responses else None + ) + executors: list[Executor] = [] for p in participants: if isinstance(p, Executor): @@ -234,9 +243,9 @@ def _resolve_participants(self) -> list[Executor]: not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter ): # Handle request info enabled agents - executors.append(AgentApprovalExecutor(p)) + executors.append(AgentApprovalExecutor(p, context_mode=context_mode)) else: - executors.append(AgentExecutor(p)) + executors.append(AgentExecutor(p, context_mode=context_mode)) else: raise TypeError(f"Participants must be SupportsAgentRun or Executor instances. 
Got {type(p).__name__}.") diff --git a/python/packages/orchestrations/tests/test_orchestration_request_info.py b/python/packages/orchestrations/tests/test_orchestration_request_info.py index 7d0acbc945..d618efcbf1 100644 --- a/python/packages/orchestrations/tests/test_orchestration_request_info.py +++ b/python/packages/orchestrations/tests/test_orchestration_request_info.py @@ -117,6 +117,7 @@ async def test_request_info_handler(self): agent_response = AgentExecutorResponse( executor_id="test_agent", agent_response=agent_response, + full_conversation=agent_response.messages, ) ctx = MagicMock(spec=WorkflowContext) @@ -135,6 +136,7 @@ async def test_handle_request_info_response_with_messages(self): original_request = AgentExecutorResponse( executor_id="test_agent", agent_response=agent_response, + full_conversation=agent_response.messages, ) response = AgentRequestInfoResponse.from_strings(["Additional input"]) @@ -161,6 +163,7 @@ async def test_handle_request_info_response_approval(self): original_request = AgentExecutorResponse( executor_id="test_agent", agent_response=agent_response, + full_conversation=agent_response.messages, ) response = AgentRequestInfoResponse.approve() diff --git a/python/packages/orchestrations/tests/test_sequential.py b/python/packages/orchestrations/tests/test_sequential.py index 67bcc1bb9e..0f000ef254 100644 --- a/python/packages/orchestrations/tests/test_sequential.py +++ b/python/packages/orchestrations/tests/test_sequential.py @@ -1,18 +1,20 @@ # Copyright (c) Microsoft. All rights reserved. 
from collections.abc import AsyncIterable, Awaitable -from typing import Any +from typing import Any, Literal, overload import pytest from agent_framework import ( AgentExecutorResponse, AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, Content, Executor, Message, + ResponseStream, TypeCompatibilityError, WorkflowContext, WorkflowRunState, @@ -25,26 +27,45 @@ class _EchoAgent(BaseAgent): """Simple agent that appends a single assistant message with its name.""" - def run( # type: ignore[override] + @overload + def run( self, - messages: str | Message | list[str] | list[Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
+ + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: - return self._run_stream() + + async def _stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} reply")]) + + return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) async def _run() -> AgentResponse: return AgentResponse(messages=[Message("assistant", [f"{self.name} reply"])]) return _run() - async def _run_stream(self) -> AsyncIterable[AgentResponseUpdate]: - # Minimal async generator with one assistant update - yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} reply")]) - class _SummarizerExec(Executor): """Custom executor that summarizes by appending a short assistant message.""" @@ -251,3 +272,121 @@ async def test_sequential_builder_reusable_after_build_with_participants() -> No assert builder._participants[0] is a1 # type: ignore assert builder._participants[1] is a2 # type: ignore + + +# --------------------------------------------------------------------------- +# chain_only_agent_responses tests +# --------------------------------------------------------------------------- + + +class _CapturingAgent(BaseAgent): + """Agent that records the messages it received and returns a configurable reply.""" + + def __init__(self, *, reply_text: str = "reply", **kwargs: Any): + super().__init__(**kwargs) + self.reply_text = reply_text + self.last_messages: list[Message] = [] + + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... 
+ @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + captured: list[Message] = [] + if messages: + for m in messages: # type: ignore[union-attr] + if isinstance(m, Message): + captured.append(m) + elif isinstance(m, str): + captured.append(Message("user", [m])) + self.last_messages = captured + + if stream: + + async def _stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate(contents=[Content.from_text(text=self.reply_text)]) + + return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) + + async def _run() -> AgentResponse: + return AgentResponse(messages=[Message("assistant", [self.reply_text])]) + + return _run() + + +async def test_chain_only_agent_responses_false_passes_full_conversation() -> None: + """Default (chain_only_agent_responses=False) passes full conversation to the second agent.""" + a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply") + a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply") + + wf = SequentialBuilder(participants=[a1, a2], chain_only_agent_responses=False).build() + + async for ev in wf.run("hello", stream=True): + if ev.type == "status" and ev.state == WorkflowRunState.IDLE: + break + + # Second agent should see full conversation: [user("hello"), assistant("A1 reply")] + seen = a2.last_messages + assert len(seen) == 2 + assert seen[0].role == "user" and "hello" in (seen[0].text or "") + assert seen[1].role == "assistant" and "A1 reply" in (seen[1].text or "") + + +async def test_chain_only_agent_responses_true_passes_only_agent_messages() -> None: + 
"""chain_only_agent_responses=True passes only the previous agent's response messages.""" + a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply") + a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply") + + wf = SequentialBuilder(participants=[a1, a2], chain_only_agent_responses=True).build() + + async for ev in wf.run("hello", stream=True): + if ev.type == "status" and ev.state == WorkflowRunState.IDLE: + break + + # Second agent should see only the assistant message: [assistant("A1 reply")] + seen = a2.last_messages + assert len(seen) == 1 + assert seen[0].role == "assistant" and "A1 reply" in (seen[0].text or "") + + +async def test_chain_only_agent_responses_three_agents() -> None: + """chain_only_agent_responses=True with three agents: each sees only the prior agent's reply.""" + a1 = _CapturingAgent(id="agent1", name="A1", reply_text="A1 reply") + a2 = _CapturingAgent(id="agent2", name="A2", reply_text="A2 reply") + a3 = _CapturingAgent(id="agent3", name="A3", reply_text="A3 reply") + + wf = SequentialBuilder(participants=[a1, a2, a3], chain_only_agent_responses=True).build() + + async for ev in wf.run("hello", stream=True): + if ev.type == "status" and ev.state == WorkflowRunState.IDLE: + break + + # a2 should see only A1's reply + assert len(a2.last_messages) == 1 + assert a2.last_messages[0].role == "assistant" and "A1 reply" in (a2.last_messages[0].text or "") + + # a3 should see only A2's reply + assert len(a3.last_messages) == 1 + assert a3.last_messages[0].role == "assistant" and "A2 reply" in (a3.last_messages[0].text or "") diff --git a/python/samples/03-workflows/orchestrations/sequential_chain_only_agent_responses.py b/python/samples/03-workflows/orchestrations/sequential_chain_only_agent_responses.py new file mode 100644 index 0000000000..2bef81ebe5 --- /dev/null +++ b/python/samples/03-workflows/orchestrations/sequential_chain_only_agent_responses.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import asyncio +import os + +from agent_framework import AgentResponseUpdate +from agent_framework.azure import AzureOpenAIResponsesClient +from agent_framework.orchestrations import SequentialBuilder +from azure.identity import AzureCliCredential +from dotenv import load_dotenv + +""" +Sample: Sequential workflow with chain_only_agent_responses=True + +Demonstrates SequentialBuilder with `chain_only_agent_responses=True`, which passes +only the previous agent's response (not the full conversation history) to the next +agent. This is useful when each agent should focus solely on refining or transforming +the prior agent's output without being influenced by earlier turns. + +In this sample, a writer agent produces a draft tagline, a translator agent translates +it into French (seeing only the writer's output, not the original user prompt), and a +reviewer agent evaluates the translation (seeing only the translator's output). + +Compare with `sequential_agents.py`, which uses the default behavior where the full +conversation context is passed to each agent. + +Prerequisites: +- AZURE_AI_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint, and AZURE_AI_MODEL_DEPLOYMENT_NAME must be set to the model deployment name. +- Azure OpenAI configured for AzureOpenAIResponsesClient with required environment variables. +- Authentication via azure-identity. Use AzureCliCredential and run az login before executing the sample. +""" + +# Load environment variables from .env file +load_dotenv() + + +async def main() -> None: + # 1) Create agents + client = AzureOpenAIResponsesClient( + project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], + deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], + credential=AzureCliCredential(), + ) + + writer = client.as_agent( + instructions="You are a concise copywriter. Provide a single, punchy marketing sentence based on the prompt.", + name="writer", + ) + + translator = client.as_agent( + instructions="You are a translator. Translate the given text into French. 
Output only the translation.", + name="translator", + ) + + reviewer = client.as_agent( + instructions="You are a reviewer. Evaluate the quality of the marketing tagline.", + name="reviewer", + ) + + # 2) Build sequential workflow: writer -> translator -> reviewer + # chain_only_agent_responses=True means each agent sees only the previous agent's reply, + # not the full conversation history. + workflow = SequentialBuilder( + participants=[writer, translator, reviewer], + chain_only_agent_responses=True, + intermediate_outputs=True, + ).build() + + # 3) Run and collect outputs + last_agent: str | None = None + async for event in workflow.run("Write a tagline for a budget-friendly eBike.", stream=True): + if event.type == "output" and isinstance(event.data, AgentResponseUpdate): + if event.data.author_name != last_agent: + last_agent = event.data.author_name + print() + print(f"{last_agent}: ", end="", flush=True) + print(event.data.text, end="", flush=True) + + +if __name__ == "__main__": + asyncio.run(main())