8
8
import httpx
9
9
import pytest
10
10
11
- from a2a .server .agent_execution import AgentExecutor
11
+
12
+ from a2a .server .agent_execution import AgentExecutor , RequestContext
12
13
from a2a .server .agent_execution .request_context_builder import (
13
14
RequestContextBuilder ,
14
15
)
59
60
TaskStatusUpdateEvent ,
60
61
TextPart ,
61
62
UnsupportedOperationError ,
63
+ InternalError ,
62
64
)
63
65
from a2a .utils .errors import ServerError
64
66
@@ -188,7 +190,12 @@ async def test_on_cancel_task_not_found(self) -> None:
188
190
mock_task_store .get .assert_called_once_with ('nonexistent_id' )
189
191
mock_agent_executor .cancel .assert_not_called ()
190
192
191
- async def test_on_message_new_message_success (self ) -> None :
193
+ @patch (
194
+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
195
+ )
196
+ async def test_on_message_new_message_success (
197
+ self , _mock_builder_build : AsyncMock
198
+ ) -> None :
192
199
mock_agent_executor = AsyncMock (spec = AgentExecutor )
193
200
mock_task_store = AsyncMock (spec = TaskStore )
194
201
request_handler = DefaultRequestHandler (
@@ -199,6 +206,14 @@ async def test_on_message_new_message_success(self) -> None:
199
206
mock_task_store .get .return_value = mock_task
200
207
mock_agent_executor .execute .return_value = None
201
208
209
+ _mock_builder_build .return_value = RequestContext (
210
+ request = MagicMock (),
211
+ task_id = 'task_123' ,
212
+ context_id = 'session-xyz' ,
213
+ task = None ,
214
+ related_tasks = None ,
215
+ )
216
+
202
217
async def streaming_coro ():
203
218
yield mock_task
204
219
@@ -284,15 +299,28 @@ async def streaming_coro():
284
299
assert response .root .error == UnsupportedOperationError () # type: ignore
285
300
mock_agent_executor .execute .assert_called_once ()
286
301
287
- async def test_on_message_stream_new_message_success (self ) -> None :
302
+ @patch (
303
+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
304
+ )
305
+ async def test_on_message_stream_new_message_success (
306
+ self , _mock_builder_build : AsyncMock
307
+ ) -> None :
288
308
mock_agent_executor = AsyncMock (spec = AgentExecutor )
289
309
mock_task_store = AsyncMock (spec = TaskStore )
290
310
request_handler = DefaultRequestHandler (
291
311
mock_agent_executor , mock_task_store
292
312
)
293
- self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
294
313
314
+ self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
295
315
handler = JSONRPCHandler (self .mock_agent_card , request_handler )
316
+ _mock_builder_build .return_value = RequestContext (
317
+ request = MagicMock (),
318
+ task_id = 'task_123' ,
319
+ context_id = 'session-xyz' ,
320
+ task = None ,
321
+ related_tasks = None ,
322
+ )
323
+
296
324
events : list [Any ] = [
297
325
Task (** MINIMAL_TASK ),
298
326
TaskArtifactUpdateEvent (
@@ -467,8 +495,11 @@ async def test_get_push_notification_success(self) -> None:
467
495
)
468
496
assert get_response .root .result == task_push_config # type: ignore
469
497
498
+ @patch (
499
+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
500
+ )
470
501
async def test_on_message_stream_new_message_send_push_notification_success (
471
- self ,
502
+ self , _mock_builder_build : AsyncMock
472
503
) -> None :
473
504
mock_agent_executor = AsyncMock (spec = AgentExecutor )
474
505
mock_task_store = AsyncMock (spec = TaskStore )
@@ -480,6 +511,13 @@ async def test_on_message_stream_new_message_send_push_notification_success(
480
511
self .mock_agent_card .capabilities = AgentCapabilities (
481
512
streaming = True , pushNotifications = True
482
513
)
514
+ _mock_builder_build .return_value = RequestContext (
515
+ request = MagicMock (),
516
+ task_id = 'task_123' ,
517
+ context_id = 'session-xyz' ,
518
+ task = None ,
519
+ related_tasks = None ,
520
+ )
483
521
484
522
handler = JSONRPCHandler (self .mock_agent_card , request_handler )
485
523
events : list [Any ] = [
@@ -738,7 +776,8 @@ async def test_on_get_push_notification_no_push_notifier(self) -> None:
738
776
739
777
# Assert
740
778
self .assertIsInstance (response .root , JSONRPCErrorResponse )
741
- self .assertEqual (response .root .error , UnsupportedOperationError ())
779
+ self .assertEqual (response .root .error , UnsupportedOperationError ()) # type: ignore
780
+
742
781
743
782
async def test_on_set_push_notification_no_push_notifier (self ) -> None :
744
783
"""Test set_push_notification with no push notifier configured."""
@@ -771,7 +810,8 @@ async def test_on_set_push_notification_no_push_notifier(self) -> None:
771
810
772
811
# Assert
773
812
self .assertIsInstance (response .root , JSONRPCErrorResponse )
774
- self .assertEqual (response .root .error , UnsupportedOperationError ())
813
+ self .assertEqual (response .root .error , UnsupportedOperationError ()) # type: ignore
814
+
775
815
776
816
async def test_on_message_send_internal_error (self ) -> None :
777
817
"""Test on_message_send with an internal error."""
@@ -800,7 +840,8 @@ async def raise_server_error(*args, **kwargs):
800
840
801
841
# Assert
802
842
self .assertIsInstance (response .root , JSONRPCErrorResponse )
803
- self .assertIsInstance (response .root .error , InternalError )
843
+ self .assertIsInstance (response .root .error , InternalError ) # type: ignore
844
+
804
845
805
846
async def test_on_message_stream_internal_error (self ) -> None :
806
847
"""Test on_message_send_stream with an internal error."""
@@ -906,3 +947,66 @@ async def consume_raises_error(*args, **kwargs):
906
947
# Assert
907
948
self .assertIsInstance (response .root , JSONRPCErrorResponse )
908
949
self .assertEqual (response .root .error , UnsupportedOperationError ())
950
+
951
+ async def test_on_message_send_task_id_mismatch (self ) -> None :
952
+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
953
+ mock_task_store = AsyncMock (spec = TaskStore )
954
+ request_handler = DefaultRequestHandler (
955
+ mock_agent_executor , mock_task_store
956
+ )
957
+ handler = JSONRPCHandler (self .mock_agent_card , request_handler )
958
+ mock_task = Task (** MINIMAL_TASK )
959
+ mock_task_store .get .return_value = mock_task
960
+ mock_agent_executor .execute .return_value = None
961
+
962
+ async def streaming_coro ():
963
+ yield mock_task
964
+
965
+ with patch (
966
+ 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all' ,
967
+ return_value = streaming_coro (),
968
+ ):
969
+ request = SendMessageRequest (
970
+ id = '1' ,
971
+ params = MessageSendParams (message = Message (** MESSAGE_PAYLOAD )),
972
+ )
973
+ response = await handler .on_message_send (request )
974
+ assert mock_agent_executor .execute .call_count == 1
975
+ self .assertIsInstance (response .root , JSONRPCErrorResponse )
976
+ self .assertIsInstance (response .root .error , InternalError ) # type: ignore
977
+
978
+ async def test_on_message_stream_task_id_mismatch (self ) -> None :
979
+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
980
+ mock_task_store = AsyncMock (spec = TaskStore )
981
+ request_handler = DefaultRequestHandler (
982
+ mock_agent_executor , mock_task_store
983
+ )
984
+
985
+ self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
986
+ handler = JSONRPCHandler (self .mock_agent_card , request_handler )
987
+ events : list [Any ] = [Task (** MINIMAL_TASK )]
988
+
989
+ async def streaming_coro ():
990
+ for event in events :
991
+ yield event
992
+
993
+ with patch (
994
+ 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all' ,
995
+ return_value = streaming_coro (),
996
+ ):
997
+ mock_task_store .get .return_value = None
998
+ mock_agent_executor .execute .return_value = None
999
+ request = SendStreamingMessageRequest (
1000
+ id = '1' ,
1001
+ params = MessageSendParams (message = Message (** MESSAGE_PAYLOAD )),
1002
+ )
1003
+ response = handler .on_message_send_stream (request )
1004
+ assert isinstance (response , AsyncGenerator )
1005
+ collected_events : list [Any ] = []
1006
+ async for event in response :
1007
+ collected_events .append (event )
1008
+ assert len (collected_events ) == 1
1009
+ self .assertIsInstance (
1010
+ collected_events [0 ].root , JSONRPCErrorResponse
1011
+ )
1012
+ self .assertIsInstance (collected_events [0 ].root .error , InternalError )
0 commit comments