Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,19 @@ async def push_notification_callback() -> None:
(
result,
interrupted_or_non_blocking,
bg_consume_task,
) = await result_aggregator.consume_and_break_on_interrupt(
consumer,
blocking=blocking,
event_callback=push_notification_callback,
)

if bg_consume_task is not None:
bg_consume_task.set_name(
f'continue_consuming:{task_id}'
)
self._track_background_task(bg_consume_task)

except Exception:
logger.exception('Agent execution failed')
producer_task.cancel()
Expand Down
17 changes: 12 additions & 5 deletions src/a2a/server/tasks/result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def consume_and_break_on_interrupt(
consumer: EventConsumer,
blocking: bool = True,
event_callback: Callable[[], Awaitable[None]] | None = None,
) -> tuple[Task | Message | None, bool]:
) -> tuple[Task | Message | None, bool, asyncio.Task | None]:
"""Processes the event stream until completion or an interruptable state is encountered.

If `blocking` is False, it returns after the first event that creates a Task or Message.
Expand All @@ -119,16 +119,23 @@ async def consume_and_break_on_interrupt(
A tuple containing:
- The current aggregated result (`Task` or `Message`) at the point of completion or interruption.
- A boolean indicating whether the consumption was interrupted (`True`) or completed naturally (`False`).
- The background ``asyncio.Task`` that continues consuming events
after an interruption, or ``None`` when no background work was
spawned. **Callers must hold a strong reference** to this task
(e.g. in a ``set``) to prevent the garbage collector from
collecting it before it finishes — the event loop only keeps
weak references to tasks.

Raises:
BaseException: If the `EventConsumer` raises an exception during consumption.
"""
event_stream = consumer.consume_all()
interrupted = False
bg_task: asyncio.Task | None = None
async for event in event_stream:
if isinstance(event, Message):
self._message = event
return event, False
return event, False, None
await self.task_manager.process(event)

should_interrupt = False
Expand Down Expand Up @@ -158,13 +165,13 @@ async def consume_and_break_on_interrupt(

if should_interrupt:
# Continue consuming the rest of the events in the background.
# TODO: We should track all outstanding tasks to ensure they eventually complete.
asyncio.create_task( # noqa: RUF006
# The caller is responsible for tracking this task to prevent GC.
bg_task = asyncio.create_task(
self._continue_consuming(event_stream, event_callback)
)
interrupted = True
break
return await self.task_manager.get_task(), interrupted
return await self.task_manager.get_task(), interrupted, bg_task

async def _continue_consuming(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ async def test_on_message_send_with_push_notification():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
final_task_result,
False,
None,
)

# Mock the current_result property to return the final task result
Expand Down Expand Up @@ -520,6 +521,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
initial_task,
True, # interrupted = True for non-blocking
MagicMock(spec=asyncio.Task), # background task
)

# Mock the current_result property to return the final task
Expand All @@ -540,7 +542,7 @@ async def mock_consume_and_break_on_interrupt(
nonlocal event_callback_passed, event_callback_received
event_callback_passed = event_callback is not None
event_callback_received = event_callback
return initial_task, True # interrupted = True for non-blocking
return initial_task, True, MagicMock(spec=asyncio.Task) # interrupted = True for non-blocking

mock_result_aggregator_instance.consume_and_break_on_interrupt = (
mock_consume_and_break_on_interrupt
Expand Down Expand Up @@ -631,6 +633,7 @@ async def test_on_message_send_with_push_notification_no_existing_Task():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
final_task_result,
False,
None,
)

# Mock the current_result property to return the final task result
Expand Down Expand Up @@ -689,6 +692,7 @@ async def test_on_message_send_no_result_from_aggregator():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
None,
False,
None,
)

from a2a.utils.errors import ServerError # Local import
Expand Down Expand Up @@ -740,6 +744,7 @@ async def test_on_message_send_task_id_mismatch():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
mismatched_task,
False,
None,
)

from a2a.utils.errors import ServerError # Local import
Expand Down Expand Up @@ -950,6 +955,7 @@ async def test_on_message_send_interrupted_flow():
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
interrupt_task_result,
True,
MagicMock(spec=asyncio.Task), # background task
) # Interrupted = True

# Patch asyncio.create_task to verify _cleanup_producer is scheduled
Expand Down
14 changes: 13 additions & 1 deletion tests/server/tasks/test_result_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, sample_message)
self.assertFalse(interrupted)
self.assertIsNone(bg_task)
self.mock_task_manager.process.assert_not_called() # Process is not called for the Message if returned directly
# _continue_consuming should not be called if it's a message interrupt
# and no auth_required state.
Expand All @@ -260,17 +262,21 @@ async def mock_consume_generator():

# Mock _continue_consuming to check if it's called by create_task
self.aggregator._continue_consuming = AsyncMock()
sentinel_task = asyncio.ensure_future(asyncio.sleep(0))
mock_create_task.return_value = sentinel_task
mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro)

(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, auth_task)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(auth_task)
mock_create_task.assert_called_once() # Check that create_task was called
# self.aggregator._continue_consuming is an AsyncMock.
Expand Down Expand Up @@ -317,12 +323,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

self.assertEqual(result, current_task_state_after_update)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(
auth_status_update
)
Expand Down Expand Up @@ -353,13 +361,15 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

# If the first event is a Message, it's returned directly.
self.assertEqual(result, event1)
self.assertFalse(interrupted)
self.assertIsNone(bg_task)
# process() is NOT called for the Message if it's the one causing the return
self.mock_task_manager.process.assert_not_called()
self.mock_task_manager.get_task.assert_not_called()
Expand Down Expand Up @@ -415,12 +425,14 @@ async def mock_consume_generator():
(
result,
interrupted,
bg_task,
) = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer, blocking=False
)

self.assertEqual(result, first_event)
self.assertTrue(interrupted)
self.assertIsNotNone(bg_task)
self.mock_task_manager.process.assert_called_once_with(first_event)
mock_create_task.assert_called_once()
# The background task should be created with the remaining stream
Expand Down Expand Up @@ -468,7 +480,7 @@ async def initial_consume_generator():
mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro)

# Call the main method that triggers _continue_consuming via create_task
_, _ = await self.aggregator.consume_and_break_on_interrupt(
_, _, _ = await self.aggregator.consume_and_break_on_interrupt(
self.mock_event_consumer
)

Expand Down
Loading