@@ -405,6 +405,134 @@ async def get_current_result():
405405 mock_agent_executor .execute .assert_awaited_once ()
406406
407407
408+ @pytest .mark .asyncio
409+ async def test_on_message_send_with_push_notification_in_non_blocking_request ():
410+ """Test that push notification callback is called during background event processing for non-blocking requests."""
411+ mock_task_store = AsyncMock (spec = TaskStore )
412+ mock_push_notification_store = AsyncMock (spec = PushNotificationConfigStore )
413+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
414+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
415+ mock_push_sender = AsyncMock ()
416+
417+ task_id = 'non_blocking_task_1'
418+ context_id = 'non_blocking_ctx_1'
419+
420+ # Create a task that will be returned after the first event
421+ initial_task = create_sample_task (
422+ task_id = task_id , context_id = context_id , status_state = TaskState .working
423+ )
424+
425+ # Create a final task that will be available during background processing
426+ final_task = create_sample_task (
427+ task_id = task_id , context_id = context_id , status_state = TaskState .completed
428+ )
429+
430+ mock_task_store .get .return_value = None
431+
432+ # Mock request context
433+ mock_request_context = MagicMock (spec = RequestContext )
434+ mock_request_context .task_id = task_id
435+ mock_request_context .context_id = context_id
436+ mock_request_context_builder .build .return_value = mock_request_context
437+
438+ request_handler = DefaultRequestHandler (
439+ agent_executor = mock_agent_executor ,
440+ task_store = mock_task_store ,
441+ push_config_store = mock_push_notification_store ,
442+ request_context_builder = mock_request_context_builder ,
443+ push_sender = mock_push_sender ,
444+ )
445+
446+ # Configure push notification
447+ push_config = PushNotificationConfig (url = 'http://callback.com/push' )
448+ message_config = MessageSendConfiguration (
449+ push_notification_config = push_config ,
450+ accepted_output_modes = ['text/plain' ],
451+ blocking = False , # Non-blocking request
452+ )
453+ params = MessageSendParams (
454+ message = Message (
455+ role = Role .user ,
456+ message_id = 'msg_non_blocking' ,
457+ parts = [],
458+ task_id = task_id ,
459+ context_id = context_id ,
460+ ),
461+ configuration = message_config ,
462+ )
463+
464+ # Mock ResultAggregator with custom behavior
465+ mock_result_aggregator_instance = AsyncMock (spec = ResultAggregator )
466+
467+ # First call returns the initial task and indicates interruption (non-blocking)
468+ mock_result_aggregator_instance .consume_and_break_on_interrupt .return_value = (
469+ initial_task ,
470+ True , # interrupted = True for non-blocking
471+ )
472+
473+ # Mock the current_result property to return the final task
474+ async def get_current_result ():
475+ return final_task
476+
477+ type(mock_result_aggregator_instance ).current_result = PropertyMock (
478+ return_value = get_current_result ()
479+ )
480+
481+ # Track if the event_callback was passed to consume_and_break_on_interrupt
482+ event_callback_passed = False
483+ event_callback_received = None
484+
485+ async def mock_consume_and_break_on_interrupt (
486+ consumer , blocking = True , event_callback = None
487+ ):
488+ nonlocal event_callback_passed , event_callback_received
489+ event_callback_passed = event_callback is not None
490+ event_callback_received = event_callback
491+ return initial_task , True # interrupted = True for non-blocking
492+
493+ mock_result_aggregator_instance .consume_and_break_on_interrupt = (
494+ mock_consume_and_break_on_interrupt
495+ )
496+
497+ with (
498+ patch (
499+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
500+ return_value = mock_result_aggregator_instance ,
501+ ),
502+ patch (
503+ 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task' ,
504+ return_value = initial_task ,
505+ ),
506+ patch (
507+ 'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message' ,
508+ return_value = initial_task ,
509+ ),
510+ ):
511+ # Execute the non-blocking request
512+ result = await request_handler .on_message_send (
513+ params , create_server_call_context ()
514+ )
515+
516+ # Verify the result is the initial task (non-blocking behavior)
517+ assert result == initial_task
518+
519+ # Verify that the event_callback was passed to consume_and_break_on_interrupt
520+ assert event_callback_passed , (
521+ 'event_callback should have been passed to consume_and_break_on_interrupt'
522+ )
523+ assert event_callback_received is not None , (
524+ 'event_callback should not be None'
525+ )
526+
527+ # Verify that the push notification was sent with the final task
528+ mock_push_sender .send_notification .assert_called_with (final_task )
529+
530+ # Verify that the push notification config was stored
531+ mock_push_notification_store .set_info .assert_awaited_once_with (
532+ task_id , push_config
533+ )
534+
535+
408536@pytest .mark .asyncio
409537async def test_on_message_send_with_push_notification_no_existing_Task ():
410538 """Test on_message_send for new task sets push notification info if provided."""
0 commit comments