Skip to content

Commit 73820b1

Browse files
authored
fix: Handle propagating agent exceptions (#20)
* fix: Fix mypy errors and enable streaming for hw example * fix:Handle propagating agent exceptions * fix:Handle propagating agent exceptions * Add comment for queue read timeout
1 parent eacdaa1 commit 73820b1

File tree

3 files changed

+23
-49
lines changed

3 files changed

+23
-49
lines changed

src/a2a/server/agent_execution/base_agent_executor.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/a2a/server/events/event_consumer.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class EventConsumer:
2222

2323
def __init__(self, queue: EventQueue):
2424
self.queue = queue
25+
self._timeout = 0.5
26+
self._exception: BaseException | None = None
2527
logger.debug('EventConsumer initialized')
2628

2729
async def consume_one(self) -> Event:
@@ -45,8 +47,15 @@ async def consume_all(self) -> AsyncGenerator[Event]:
4547
"""Consume all the generated streaming events from the agent."""
4648
logger.debug('Starting to consume all events from the queue.')
4749
while True:
50+
if self._exception:
51+
raise self._exception
4852
try:
49-
event = await self.queue.dequeue_event()
53+
# We use a timeout when waiting for an event from the queue.
54+
# This is required because it allows the loop to check if
55+
# `self._exception` has been set by the `agent_task_callback`.
56+
# Without the timeout, loop might hang indefinitely if no events are
57+
# enqueued by the agent and the agent simply threw an exception
58+
event = await asyncio.wait_for(self.queue.dequeue_event(), timeout=self._timeout)
5059
logger.debug(
5160
f'Dequeued event of type: {type(event)} in consume_all.'
5261
)
@@ -74,5 +83,16 @@ async def consume_all(self) -> AsyncGenerator[Event]:
7483
logger.debug('Stopping event consumption in consume_all.')
7584
self.queue.close()
7685
break
86+
except asyncio.TimeoutError:
87+
# continue polling until there is a final event
88+
continue
7789
except asyncio.QueueShutDown:
7890
break
91+
92+
93+
94+
95+
96+
def agent_task_callback(self, agent_task: asyncio.Task[None]):
97+
if agent_task.exception() is not None:
98+
self._exception = agent_task.exception()

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ async def on_message_send(
138138
await self._register_producer(task_id, producer_task)
139139

140140
consumer = EventConsumer(queue)
141+
producer_task.add_done_callback(consumer.agent_task_callback)
141142

142143
interrupted = False
143144
try:
@@ -192,6 +193,7 @@ async def on_message_send_stream(
192193

193194
try:
194195
consumer = EventConsumer(queue)
196+
producer_task.add_done_callback(consumer.agent_task_callback)
195197
async for event in result_aggregator.consume_and_emit(consumer):
196198
# Now we know we have a Task, register the queue
197199
if isinstance(event, Task):

0 commit comments

Comments
 (0)