Skip to content

Commit 95fe755

Browse files
kthota-gholtskinnerrelease-please[bot]mindpower
authored andcommitted
fix: Throw exception for task_id mismatches (google-a2a#70)
* fix: Throw exception for task_id mismatches * Add tests * ci: Create GitHub Action to generate `types.py` from specification JSON (google-a2a#67) * ci: Remove update-a2a-types.yml workflow * chore: Regenerate types.py from spec (google-a2a#71) * chore(main): release 0.2.3 (google-a2a#68) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> * test: Adding 8 tests for jsonrpc_handler.py and also fix minor waring in test_integration.py (google-a2a#72) * test: Adding 8 tests for jsonrpc_handler.py and also fix minor waring in test_integration.py * test: remove comments * Add tests --------- Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: holtskinner <[email protected]> Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> Co-authored-by: Junjie Bu <[email protected]>
1 parent 7bac8b9 commit 95fe755

File tree

3 files changed

+135
-25
lines changed

3 files changed

+135
-25
lines changed

examples/langgraph/agent_executor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from agent import CurrencyAgent # type: ignore[import-untyped]
2-
from typing_extensions import override
1+
from agent import CurrencyAgent # type: ignore[import-untyped]
2+
33
from a2a.server.agent_execution import AgentExecutor, RequestContext
44
from a2a.server.events.event_queue import EventQueue
55
from a2a.types import (
@@ -17,7 +17,6 @@ class CurrencyAgentExecutor(AgentExecutor):
1717
def __init__(self):
1818
self.agent = CurrencyAgent()
1919

20-
@override
2120
async def execute(
2221
self,
2322
context: RequestContext,
@@ -89,7 +88,6 @@ async def execute(
8988
)
9089
)
9190

92-
@override
9391
async def cancel(
9492
self, context: RequestContext, event_queue: EventQueue
9593
) -> None:

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
EventQueue,
1717
InMemoryQueueManager,
1818
QueueManager,
19-
TaskQueueExists,
2019
)
2120
from a2a.server.request_handlers.request_handler import RequestHandler
2221
from a2a.server.tasks import (
@@ -212,6 +211,15 @@ async def on_message_send(
212211
) = await result_aggregator.consume_and_break_on_interrupt(consumer)
213212
if not result:
214213
raise ServerError(error=InternalError())
214+
215+
if isinstance(result, Task) and task_id != result.id:
216+
logger.error(
217+
f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.'
218+
)
219+
raise ServerError(
220+
InternalError(message='Task ID mismatch in agent response')
221+
)
222+
215223
finally:
216224
if interrupted:
217225
# TODO: Track this disconnected cleanup task.
@@ -278,27 +286,27 @@ async def on_message_send_stream(
278286
consumer = EventConsumer(queue)
279287
producer_task.add_done_callback(consumer.agent_task_callback)
280288
async for event in result_aggregator.consume_and_emit(consumer):
281-
if isinstance(event, Task) and task_id != event.id:
282-
logger.warning(
283-
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
284-
)
285-
try:
286-
created_task: Task = event
287-
await self._queue_manager.add(created_task.id, queue)
288-
task_id = created_task.id
289-
except TaskQueueExists:
290-
logging.info(
291-
'Multiple Task objects created in event stream.'
289+
if isinstance(event, Task):
290+
if task_id != event.id:
291+
logger.error(
292+
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
292293
)
294+
raise ServerError(
295+
InternalError(
296+
message='Task ID mismatch in agent response'
297+
)
298+
)
299+
293300
if (
294301
self._push_notifier
295302
and params.configuration
296303
and params.configuration.pushNotificationConfig
297304
):
298305
await self._push_notifier.set_info(
299-
created_task.id,
306+
task_id,
300307
params.configuration.pushNotificationConfig,
301308
)
309+
302310
if self._push_notifier and task_id:
303311
latest_task = await result_aggregator.current_result
304312
if isinstance(latest_task, Task):

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 112 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import httpx
99
import pytest
1010

11-
from a2a.server.agent_execution import AgentExecutor
11+
12+
from a2a.server.agent_execution import AgentExecutor, RequestContext
1213
from a2a.server.agent_execution.request_context_builder import (
1314
RequestContextBuilder,
1415
)
@@ -59,6 +60,7 @@
5960
TaskStatusUpdateEvent,
6061
TextPart,
6162
UnsupportedOperationError,
63+
InternalError,
6264
)
6365
from a2a.utils.errors import ServerError
6466

@@ -188,7 +190,12 @@ async def test_on_cancel_task_not_found(self) -> None:
188190
mock_task_store.get.assert_called_once_with('nonexistent_id')
189191
mock_agent_executor.cancel.assert_not_called()
190192

191-
async def test_on_message_new_message_success(self) -> None:
193+
@patch(
194+
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
195+
)
196+
async def test_on_message_new_message_success(
197+
self, _mock_builder_build: AsyncMock
198+
) -> None:
192199
mock_agent_executor = AsyncMock(spec=AgentExecutor)
193200
mock_task_store = AsyncMock(spec=TaskStore)
194201
request_handler = DefaultRequestHandler(
@@ -199,6 +206,14 @@ async def test_on_message_new_message_success(self) -> None:
199206
mock_task_store.get.return_value = mock_task
200207
mock_agent_executor.execute.return_value = None
201208

209+
_mock_builder_build.return_value = RequestContext(
210+
request=MagicMock(),
211+
task_id='task_123',
212+
context_id='session-xyz',
213+
task=None,
214+
related_tasks=None,
215+
)
216+
202217
async def streaming_coro():
203218
yield mock_task
204219

@@ -284,15 +299,28 @@ async def streaming_coro():
284299
assert response.root.error == UnsupportedOperationError() # type: ignore
285300
mock_agent_executor.execute.assert_called_once()
286301

287-
async def test_on_message_stream_new_message_success(self) -> None:
302+
@patch(
303+
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
304+
)
305+
async def test_on_message_stream_new_message_success(
306+
self, _mock_builder_build: AsyncMock
307+
) -> None:
288308
mock_agent_executor = AsyncMock(spec=AgentExecutor)
289309
mock_task_store = AsyncMock(spec=TaskStore)
290310
request_handler = DefaultRequestHandler(
291311
mock_agent_executor, mock_task_store
292312
)
293-
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
294313

314+
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
295315
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
316+
_mock_builder_build.return_value = RequestContext(
317+
request=MagicMock(),
318+
task_id='task_123',
319+
context_id='session-xyz',
320+
task=None,
321+
related_tasks=None,
322+
)
323+
296324
events: list[Any] = [
297325
Task(**MINIMAL_TASK),
298326
TaskArtifactUpdateEvent(
@@ -467,8 +495,11 @@ async def test_get_push_notification_success(self) -> None:
467495
)
468496
assert get_response.root.result == task_push_config # type: ignore
469497

498+
@patch(
499+
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
500+
)
470501
async def test_on_message_stream_new_message_send_push_notification_success(
471-
self,
502+
self, _mock_builder_build: AsyncMock
472503
) -> None:
473504
mock_agent_executor = AsyncMock(spec=AgentExecutor)
474505
mock_task_store = AsyncMock(spec=TaskStore)
@@ -480,6 +511,13 @@ async def test_on_message_stream_new_message_send_push_notification_success(
480511
self.mock_agent_card.capabilities = AgentCapabilities(
481512
streaming=True, pushNotifications=True
482513
)
514+
_mock_builder_build.return_value = RequestContext(
515+
request=MagicMock(),
516+
task_id='task_123',
517+
context_id='session-xyz',
518+
task=None,
519+
related_tasks=None,
520+
)
483521

484522
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
485523
events: list[Any] = [
@@ -738,7 +776,8 @@ async def test_on_get_push_notification_no_push_notifier(self) -> None:
738776

739777
# Assert
740778
self.assertIsInstance(response.root, JSONRPCErrorResponse)
741-
self.assertEqual(response.root.error, UnsupportedOperationError())
779+
self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore
780+
742781

743782
async def test_on_set_push_notification_no_push_notifier(self) -> None:
744783
"""Test set_push_notification with no push notifier configured."""
@@ -771,7 +810,8 @@ async def test_on_set_push_notification_no_push_notifier(self) -> None:
771810

772811
# Assert
773812
self.assertIsInstance(response.root, JSONRPCErrorResponse)
774-
self.assertEqual(response.root.error, UnsupportedOperationError())
813+
self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore
814+
775815

776816
async def test_on_message_send_internal_error(self) -> None:
777817
"""Test on_message_send with an internal error."""
@@ -800,7 +840,8 @@ async def raise_server_error(*args, **kwargs):
800840

801841
# Assert
802842
self.assertIsInstance(response.root, JSONRPCErrorResponse)
803-
self.assertIsInstance(response.root.error, InternalError)
843+
self.assertIsInstance(response.root.error, InternalError) # type: ignore
844+
804845

805846
async def test_on_message_stream_internal_error(self) -> None:
806847
"""Test on_message_send_stream with an internal error."""
@@ -906,3 +947,66 @@ async def consume_raises_error(*args, **kwargs):
906947
# Assert
907948
self.assertIsInstance(response.root, JSONRPCErrorResponse)
908949
self.assertEqual(response.root.error, UnsupportedOperationError())
950+
951+
async def test_on_message_send_task_id_mismatch(self) -> None:
952+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
953+
mock_task_store = AsyncMock(spec=TaskStore)
954+
request_handler = DefaultRequestHandler(
955+
mock_agent_executor, mock_task_store
956+
)
957+
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
958+
mock_task = Task(**MINIMAL_TASK)
959+
mock_task_store.get.return_value = mock_task
960+
mock_agent_executor.execute.return_value = None
961+
962+
async def streaming_coro():
963+
yield mock_task
964+
965+
with patch(
966+
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
967+
return_value=streaming_coro(),
968+
):
969+
request = SendMessageRequest(
970+
id='1',
971+
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
972+
)
973+
response = await handler.on_message_send(request)
974+
assert mock_agent_executor.execute.call_count == 1
975+
self.assertIsInstance(response.root, JSONRPCErrorResponse)
976+
self.assertIsInstance(response.root.error, InternalError) # type: ignore
977+
978+
async def test_on_message_stream_task_id_mismatch(self) -> None:
979+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
980+
mock_task_store = AsyncMock(spec=TaskStore)
981+
request_handler = DefaultRequestHandler(
982+
mock_agent_executor, mock_task_store
983+
)
984+
985+
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
986+
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
987+
events: list[Any] = [Task(**MINIMAL_TASK)]
988+
989+
async def streaming_coro():
990+
for event in events:
991+
yield event
992+
993+
with patch(
994+
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
995+
return_value=streaming_coro(),
996+
):
997+
mock_task_store.get.return_value = None
998+
mock_agent_executor.execute.return_value = None
999+
request = SendStreamingMessageRequest(
1000+
id='1',
1001+
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
1002+
)
1003+
response = handler.on_message_send_stream(request)
1004+
assert isinstance(response, AsyncGenerator)
1005+
collected_events: list[Any] = []
1006+
async for event in response:
1007+
collected_events.append(event)
1008+
assert len(collected_events) == 1
1009+
self.assertIsInstance(
1010+
collected_events[0].root, JSONRPCErrorResponse
1011+
)
1012+
self.assertIsInstance(collected_events[0].root.error, InternalError)

0 commit comments

Comments
 (0)