diff --git a/examples/langgraph/agent_executor.py b/examples/langgraph/agent_executor.py index add3c04..4bac9d3 100644 --- a/examples/langgraph/agent_executor.py +++ b/examples/langgraph/agent_executor.py @@ -1,5 +1,5 @@ -from agent import CurrencyAgent # type: ignore[import-untyped] -from typing_extensions import override +from agent import CurrencyAgent # type: ignore[import-untyped] + from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events.event_queue import EventQueue from a2a.types import ( @@ -17,7 +17,6 @@ class CurrencyAgentExecutor(AgentExecutor): def __init__(self): self.agent = CurrencyAgent() - @override async def execute( self, context: RequestContext, @@ -89,7 +88,6 @@ async def execute( ) ) - @override async def cancel( self, context: RequestContext, event_queue: EventQueue ) -> None: diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index eb8de0a..fa6d7ea 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -16,7 +16,6 @@ EventQueue, InMemoryQueueManager, QueueManager, - TaskQueueExists, ) from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.tasks import ( @@ -212,6 +211,15 @@ async def on_message_send( ) = await result_aggregator.consume_and_break_on_interrupt(consumer) if not result: raise ServerError(error=InternalError()) + + if isinstance(result, Task) and task_id != result.id: + logger.error( + f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.' + ) + raise ServerError( + InternalError(message='Task ID mismatch in agent response') + ) + finally: if interrupted: # TODO: Track this disconnected cleanup task. @@ -278,27 +286,27 @@ async def on_message_send_stream( consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) async for event in result_aggregator.consume_and_emit(consumer): - if isinstance(event, Task) and task_id != event.id: - logger.warning( - f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.' - ) - try: - created_task: Task = event - await self._queue_manager.add(created_task.id, queue) - task_id = created_task.id - except TaskQueueExists: - logging.info( - 'Multiple Task objects created in event stream.' + if isinstance(event, Task): + if task_id != event.id: + logger.error( + f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.' ) + raise ServerError( + InternalError( + message='Task ID mismatch in agent response' + ) + ) + if ( self._push_notifier and params.configuration and params.configuration.pushNotificationConfig ): await self._push_notifier.set_info( - created_task.id, + task_id, params.configuration.pushNotificationConfig, ) + if self._push_notifier and task_id: latest_task = await result_aggregator.current_result if isinstance(latest_task, Task): diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 6eb680a..431c54b 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -8,7 +8,8 @@ import httpx import pytest -from a2a.server.agent_execution import AgentExecutor + +from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.agent_execution.request_context_builder import ( RequestContextBuilder, ) @@ -59,6 +60,7 @@ TaskStatusUpdateEvent, TextPart, UnsupportedOperationError, + InternalError, ) from a2a.utils.errors import ServerError @@ -188,7 +190,12 @@ async def test_on_cancel_task_not_found(self) -> None: mock_task_store.get.assert_called_once_with('nonexistent_id') mock_agent_executor.cancel.assert_not_called() - async def test_on_message_new_message_success(self) -> None: + @patch( + 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' + ) + async def test_on_message_new_message_success( + self, _mock_builder_build: AsyncMock + ) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( @@ -199,6 +206,14 @@ async def test_on_message_new_message_success(self) -> None: mock_task_store.get.return_value = mock_task mock_agent_executor.execute.return_value = None + _mock_builder_build.return_value = RequestContext( + request=MagicMock(), + task_id='task_123', + context_id='session-xyz', + task=None, + related_tasks=None, + ) + async def streaming_coro(): yield mock_task @@ -284,15 +299,28 @@ async def streaming_coro(): assert response.root.error == UnsupportedOperationError() # type: ignore mock_agent_executor.execute.assert_called_once() - async def test_on_message_stream_new_message_success(self) -> None: + @patch( + 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' + ) + async def test_on_message_stream_new_message_success( + self, _mock_builder_build: AsyncMock + ) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( mock_agent_executor, mock_task_store ) - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) + self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) handler = JSONRPCHandler(self.mock_agent_card, request_handler) + _mock_builder_build.return_value = RequestContext( + request=MagicMock(), + task_id='task_123', + context_id='session-xyz', + task=None, + related_tasks=None, + ) + events: list[Any] = [ Task(**MINIMAL_TASK), TaskArtifactUpdateEvent( @@ -467,8 +495,11 @@ async def test_get_push_notification_success(self) -> None: ) assert get_response.root.result == task_push_config # type: ignore + @patch( + 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' + ) async def test_on_message_stream_new_message_send_push_notification_success( - self, + self, _mock_builder_build: AsyncMock ) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) @@ -480,6 +511,13 @@ async def test_on_message_stream_new_message_send_push_notification_success( self.mock_agent_card.capabilities = AgentCapabilities( streaming=True, pushNotifications=True ) + _mock_builder_build.return_value = RequestContext( + request=MagicMock(), + task_id='task_123', + context_id='session-xyz', + task=None, + related_tasks=None, + ) handler = JSONRPCHandler(self.mock_agent_card, request_handler) events: list[Any] = [ @@ -738,7 +776,8 @@ async def test_on_get_push_notification_no_push_notifier(self) -> None: # Assert self.assertIsInstance(response.root, JSONRPCErrorResponse) - self.assertEqual(response.root.error, UnsupportedOperationError()) + self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore + async def test_on_set_push_notification_no_push_notifier(self) -> None: """Test set_push_notification with no push notifier configured.""" @@ -771,7 +810,8 @@ async def test_on_set_push_notification_no_push_notifier(self) -> None: # Assert self.assertIsInstance(response.root, JSONRPCErrorResponse) - self.assertEqual(response.root.error, UnsupportedOperationError()) + self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore + async def test_on_message_send_internal_error(self) -> None: """Test on_message_send with an internal error.""" @@ -800,7 +840,8 @@ async def raise_server_error(*args, **kwargs): # Assert self.assertIsInstance(response.root, JSONRPCErrorResponse) - self.assertIsInstance(response.root.error, InternalError) + self.assertIsInstance(response.root.error, InternalError) # type: ignore + async def test_on_message_stream_internal_error(self) -> None: """Test on_message_send_stream with an internal error.""" @@ -906,3 +947,66 @@ async def consume_raises_error(*args, **kwargs): # Assert self.assertIsInstance(response.root, JSONRPCErrorResponse) self.assertEqual(response.root.error, UnsupportedOperationError()) + + async def test_on_message_send_task_id_mismatch(self) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + mock_agent_executor.execute.return_value = None + + async def streaming_coro(): + yield mock_task + + with patch( + 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', + return_value=streaming_coro(), + ): + request = SendMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + response = await handler.on_message_send(request) + assert mock_agent_executor.execute.call_count == 1 + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertIsInstance(response.root.error, InternalError) # type: ignore + + async def test_on_message_stream_task_id_mismatch(self) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + + self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + events: list[Any] = [Task(**MINIMAL_TASK)] + + async def streaming_coro(): + for event in events: + yield event + + with patch( + 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', + return_value=streaming_coro(), + ): + mock_task_store.get.return_value = None + mock_agent_executor.execute.return_value = None + request = SendStreamingMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + response = handler.on_message_send_stream(request) + assert isinstance(response, AsyncGenerator) + collected_events: list[Any] = [] + async for event in response: + collected_events.append(event) + assert len(collected_events) == 1 + self.assertIsInstance( + collected_events[0].root, JSONRPCErrorResponse + ) + self.assertIsInstance(collected_events[0].root.error, InternalError)