Skip to content

fix: Throw exception for task_id mismatches #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 21, 2025
Merged
6 changes: 2 additions & 4 deletions examples/langgraph/agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from agent import CurrencyAgent # type: ignore[import-untyped]
from typing_extensions import override
from agent import CurrencyAgent # type: ignore[import-untyped]

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events.event_queue import EventQueue
from a2a.types import (
Expand All @@ -17,7 +17,6 @@ class CurrencyAgentExecutor(AgentExecutor):
def __init__(self):
self.agent = CurrencyAgent()

@override
async def execute(
self,
context: RequestContext,
Expand Down Expand Up @@ -89,7 +88,6 @@ async def execute(
)
)

@override
async def cancel(
self, context: RequestContext, event_queue: EventQueue
) -> None:
Expand Down
34 changes: 21 additions & 13 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
EventQueue,
InMemoryQueueManager,
QueueManager,
TaskQueueExists,
)
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.tasks import (
Expand Down Expand Up @@ -212,6 +211,15 @@ async def on_message_send(
) = await result_aggregator.consume_and_break_on_interrupt(consumer)
if not result:
raise ServerError(error=InternalError())

if isinstance(result, Task) and task_id != result.id:
logger.error(
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(message='Task ID mismatch in agent response')
)

finally:
if interrupted:
# TODO: Track this disconnected cleanup task.
Expand Down Expand Up @@ -278,27 +286,27 @@ async def on_message_send_stream(
consumer = EventConsumer(queue)
producer_task.add_done_callback(consumer.agent_task_callback)
async for event in result_aggregator.consume_and_emit(consumer):
if isinstance(event, Task) and task_id != event.id:
logger.warning(
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
)
try:
created_task: Task = event
await self._queue_manager.add(created_task.id, queue)
task_id = created_task.id
except TaskQueueExists:
logging.info(
'Multiple Task objects created in event stream.'
if isinstance(event, Task):
if task_id != event.id:
logger.error(
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(
message='Task ID mismatch in agent response'
)
)

if (
self._push_notifier
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_notifier.set_info(
created_task.id,
task_id,
params.configuration.pushNotificationConfig,
)

if self._push_notifier and task_id:
latest_task = await result_aggregator.current_result
if isinstance(latest_task, Task):
Expand Down
120 changes: 112 additions & 8 deletions tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import httpx
import pytest

from a2a.server.agent_execution import AgentExecutor

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.agent_execution.request_context_builder import (
RequestContextBuilder,
)
Expand Down Expand Up @@ -59,6 +60,7 @@
TaskStatusUpdateEvent,
TextPart,
UnsupportedOperationError,
InternalError,
)
from a2a.utils.errors import ServerError

Expand Down Expand Up @@ -188,7 +190,12 @@ async def test_on_cancel_task_not_found(self) -> None:
mock_task_store.get.assert_called_once_with('nonexistent_id')
mock_agent_executor.cancel.assert_not_called()

async def test_on_message_new_message_success(self) -> None:
@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
)
async def test_on_message_new_message_success(
self, _mock_builder_build: AsyncMock
) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
Expand All @@ -199,6 +206,14 @@ async def test_on_message_new_message_success(self) -> None:
mock_task_store.get.return_value = mock_task
mock_agent_executor.execute.return_value = None

_mock_builder_build.return_value = RequestContext(
request=MagicMock(),
task_id='task_123',
context_id='session-xyz',
task=None,
related_tasks=None,
)

async def streaming_coro():
yield mock_task

Expand Down Expand Up @@ -284,15 +299,28 @@ async def streaming_coro():
assert response.root.error == UnsupportedOperationError() # type: ignore
mock_agent_executor.execute.assert_called_once()

async def test_on_message_stream_new_message_success(self) -> None:
@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
)
async def test_on_message_stream_new_message_success(
self, _mock_builder_build: AsyncMock
) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)

self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
_mock_builder_build.return_value = RequestContext(
request=MagicMock(),
task_id='task_123',
context_id='session-xyz',
task=None,
related_tasks=None,
)

events: list[Any] = [
Task(**MINIMAL_TASK),
TaskArtifactUpdateEvent(
Expand Down Expand Up @@ -467,8 +495,11 @@ async def test_get_push_notification_success(self) -> None:
)
assert get_response.root.result == task_push_config # type: ignore

@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
)
async def test_on_message_stream_new_message_send_push_notification_success(
self,
self, _mock_builder_build: AsyncMock
) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
Expand All @@ -480,6 +511,13 @@ async def test_on_message_stream_new_message_send_push_notification_success(
self.mock_agent_card.capabilities = AgentCapabilities(
streaming=True, pushNotifications=True
)
_mock_builder_build.return_value = RequestContext(
request=MagicMock(),
task_id='task_123',
context_id='session-xyz',
task=None,
related_tasks=None,
)

handler = JSONRPCHandler(self.mock_agent_card, request_handler)
events: list[Any] = [
Expand Down Expand Up @@ -738,7 +776,8 @@ async def test_on_get_push_notification_no_push_notifier(self) -> None:

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())
self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore


async def test_on_set_push_notification_no_push_notifier(self) -> None:
"""Test set_push_notification with no push notifier configured."""
Expand Down Expand Up @@ -771,7 +810,8 @@ async def test_on_set_push_notification_no_push_notifier(self) -> None:

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())
self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore


async def test_on_message_send_internal_error(self) -> None:
"""Test on_message_send with an internal error."""
Expand Down Expand Up @@ -800,7 +840,8 @@ async def raise_server_error(*args, **kwargs):

# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertIsInstance(response.root.error, InternalError)
self.assertIsInstance(response.root.error, InternalError) # type: ignore


async def test_on_message_stream_internal_error(self) -> None:
"""Test on_message_send_stream with an internal error."""
Expand Down Expand Up @@ -906,3 +947,66 @@ async def consume_raises_error(*args, **kwargs):
# Assert
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertEqual(response.root.error, UnsupportedOperationError())

async def test_on_message_send_task_id_mismatch(self) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
mock_task = Task(**MINIMAL_TASK)
mock_task_store.get.return_value = mock_task
mock_agent_executor.execute.return_value = None

async def streaming_coro():
yield mock_task

with patch(
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
return_value=streaming_coro(),
):
request = SendMessageRequest(
id='1',
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
)
response = await handler.on_message_send(request)
assert mock_agent_executor.execute.call_count == 1
self.assertIsInstance(response.root, JSONRPCErrorResponse)
self.assertIsInstance(response.root.error, InternalError) # type: ignore

async def test_on_message_stream_task_id_mismatch(self) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
request_handler = DefaultRequestHandler(
mock_agent_executor, mock_task_store
)

self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
events: list[Any] = [Task(**MINIMAL_TASK)]

async def streaming_coro():
for event in events:
yield event

with patch(
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
return_value=streaming_coro(),
):
mock_task_store.get.return_value = None
mock_agent_executor.execute.return_value = None
request = SendStreamingMessageRequest(
id='1',
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
)
response = handler.on_message_send_stream(request)
assert isinstance(response, AsyncGenerator)
collected_events: list[Any] = []
async for event in response:
collected_events.append(event)
assert len(collected_events) == 1
self.assertIsInstance(
collected_events[0].root, JSONRPCErrorResponse
)
self.assertIsInstance(collected_events[0].root.error, InternalError)