Skip to content

fix: Throw exception for task_id mismatches #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 21, 2025
6 changes: 2 additions & 4 deletions examples/langgraph/agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from agent import CurrencyAgent # type: ignore[import-untyped]
from typing_extensions import override
from agent import CurrencyAgent # type: ignore[import-untyped]

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events.event_queue import EventQueue
from a2a.types import (
Expand All @@ -17,7 +17,6 @@ class CurrencyAgentExecutor(AgentExecutor):
def __init__(self):
self.agent = CurrencyAgent()

@override
async def execute(
self,
context: RequestContext,
Expand Down Expand Up @@ -89,7 +88,6 @@ async def execute(
)
)

@override
async def cancel(
self, context: RequestContext, event_queue: EventQueue
) -> None:
Expand Down
34 changes: 21 additions & 13 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
EventQueue,
InMemoryQueueManager,
QueueManager,
TaskQueueExists,
)
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.tasks import (
Expand Down Expand Up @@ -212,6 +211,15 @@ async def on_message_send(
) = await result_aggregator.consume_and_break_on_interrupt(consumer)
if not result:
raise ServerError(error=InternalError())

if isinstance(result, Task) and task_id != result.id:
logger.error(
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(message='Task ID mismatch in agent response')
)

finally:
if interrupted:
# TODO: Track this disconnected cleanup task.
Expand Down Expand Up @@ -278,27 +286,27 @@ async def on_message_send_stream(
consumer = EventConsumer(queue)
producer_task.add_done_callback(consumer.agent_task_callback)
async for event in result_aggregator.consume_and_emit(consumer):
if isinstance(event, Task) and task_id != event.id:
logger.warning(
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
)
try:
created_task: Task = event
await self._queue_manager.add(created_task.id, queue)
task_id = created_task.id
except TaskQueueExists:
logging.info(
'Multiple Task objects created in event stream.'
if isinstance(event, Task):
if task_id != event.id:
logger.error(
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
)
raise ServerError(
InternalError(
message='Task ID mismatch in agent response'
)
)

if (
self._push_notifier
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_notifier.set_info(
created_task.id,
task_id,
params.configuration.pushNotificationConfig,
)

if self._push_notifier and task_id:
latest_task = await result_aggregator.current_result
if isinstance(latest_task, Task):
Expand Down
Loading