diff --git a/src/a2a/server/agent_execution/base_agent_executor.py b/src/a2a/server/agent_execution/base_agent_executor.py deleted file mode 100644 index 13f228b1..00000000 --- a/src/a2a/server/agent_execution/base_agent_executor.py +++ /dev/null @@ -1,48 +0,0 @@ -from a2a.server.agent_execution.agent_executor import AgentExecutor -from a2a.server.events.event_queue import EventQueue -from a2a.types import ( - A2AError, - CancelTaskRequest, - SendMessageRequest, - SendStreamingMessageRequest, - Task, - TaskResubscriptionRequest, - UnsupportedOperationError, -) - - -class BaseAgentExecutor(AgentExecutor): - """Base AgentExecutor which returns unsupported operation error.""" - - async def on_message_send( - self, - request: SendMessageRequest, - event_queue: EventQueue, - task: Task | None, - ) -> None: - """Handler for 'message/send' requests.""" - event_queue.enqueue_event(A2AError(UnsupportedOperationError())) - - async def on_message_stream( - self, - request: SendStreamingMessageRequest, - event_queue: EventQueue, - task: Task | None, - ) -> None: - """Handler for 'message/stream' requests.""" - event_queue.enqueue_event(A2AError(UnsupportedOperationError())) - - async def on_cancel( - self, request: CancelTaskRequest, event_queue: EventQueue, task: Task - ) -> None: - """Handler for 'tasks/cancel' requests.""" - event_queue.enqueue_event(A2AError(UnsupportedOperationError())) - - async def on_resubscribe( - self, - request: TaskResubscriptionRequest, - event_queue: EventQueue, - task: Task, - ) -> None: - """Handler for 'tasks/resubscribe' requests.""" - event_queue.enqueue_event(A2AError(UnsupportedOperationError())) diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index dcfa9d98..e1109402 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -22,6 +22,8 @@ class EventConsumer: def __init__(self, queue: EventQueue): self.queue = queue + self._timeout = 0.5 + self._exception: BaseException | None = None logger.debug('EventConsumer initialized') async def consume_one(self) -> Event: @@ -45,8 +47,15 @@ async def consume_all(self) -> AsyncGenerator[Event]: """Consume all the generated streaming events from the agent.""" logger.debug('Starting to consume all events from the queue.') while True: + if self._exception: + raise self._exception try: - event = await self.queue.dequeue_event() + # We use a timeout when waiting for an event from the queue. + # This is required because it allows the loop to check if + # `self._exception` has been set by the `agent_task_callback`. + # Without the timeout, loop might hang indefinitely if no events are + # enqueued by the agent and the agent simply threw an exception + event = await asyncio.wait_for(self.queue.dequeue_event(), timeout=self._timeout) logger.debug( f'Dequeued event of type: {type(event)} in consume_all.' ) @@ -74,5 +83,16 @@ async def consume_all(self) -> AsyncGenerator[Event]: logger.debug('Stopping event consumption in consume_all.') self.queue.close() break + except asyncio.TimeoutError: + # continue polling until there is a final event + continue except asyncio.QueueShutDown: break + + + + + + def agent_task_callback(self, agent_task: asyncio.Task[None]): + if agent_task.exception() is not None: + self._exception = agent_task.exception() \ No newline at end of file diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index f656e414..4107cdf2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -138,6 +138,7 @@ async def on_message_send( await self._register_producer(task_id, producer_task) consumer = EventConsumer(queue) + producer_task.add_done_callback(consumer.agent_task_callback) interrupted = False try: @@ -192,6 +193,7 @@ async def on_message_send_stream( try: consumer = EventConsumer(queue) + producer_task.add_done_callback(consumer.agent_task_callback) async for event in result_aggregator.consume_and_emit(consumer): # Now we know we have a Task, register the queue if isinstance(event, Task):