11import asyncio
2- import contextlib
32import logging
3+
44from collections .abc import AsyncGenerator
55from typing import cast
66
1010 EventConsumer ,
1111 EventQueue ,
1212 InMemoryQueueManager ,
13- NoTaskQueue ,
1413 QueueManager ,
1514 TaskQueueExists ,
1615)
1716from a2a .server .request_handlers .request_handler import RequestHandler
18- from a2a .server .tasks import ResultAggregator , TaskManager , TaskStore
17+ from a2a .server .tasks import (
18+ PushNotifier ,
19+ ResultAggregator ,
20+ TaskManager ,
21+ TaskStore ,
22+ )
1923from a2a .types import (
2024 InternalError ,
2125 Message ,
2933)
3034from a2a .utils .errors import ServerError
3135
36+
3237logger = logging .getLogger (__name__ )
3338
3439
@@ -42,10 +47,12 @@ def __init__(
4247 agent_executor : AgentExecutor ,
4348 task_store : TaskStore ,
4449 queue_manager : QueueManager | None = None ,
50+ push_notifier : PushNotifier | None = None ,
4551 ) -> None :
4652 self .agent_executor = agent_executor
4753 self .task_store = task_store
4854 self ._queue_manager = queue_manager or InMemoryQueueManager ()
55+ self ._push_notifier = push_notifier
4956 # TODO: Likely want an interface for managing this, like AgentExecutionManager.
5057 self ._running_agents = {}
5158 self ._running_agents_lock = asyncio .Lock ()
@@ -116,6 +123,15 @@ async def on_message_send(
116123 task : Task | None = await task_manager .get_task ()
117124 if task :
118125 task = task_manager .update_with_message (params .message , task )
126+ if (
127+ self ._push_notifier
128+ and params .configuration
129+ and params .configuration .pushNotificationConfig
130+ and not params .configuration .blocking
131+ ):
132+ await self ._push_notifier .set_info (
133+ task .id , params .configuration .pushNotificationConfig
134+ )
119135 request_context = RequestContext (
120136 params ,
121137 task .id if task else None ,
@@ -173,6 +189,16 @@ async def on_message_send_stream(
173189 if task :
174190 task = task_manager .update_with_message (params .message , task )
175191
192+ if (
193+ self ._push_notifier
194+ and params .configuration
195+ and params .configuration .pushNotificationConfig
196+ ):
197+ await self ._push_notifier .set_info (
198+ task .id , params .configuration .pushNotificationConfig
199+ )
200+ else :
201+ queue = EventQueue ()
176202 result_aggregator = ResultAggregator (task_manager )
177203 request_context = RequestContext (
178204 params ,
@@ -196,12 +222,26 @@ async def on_message_send_stream(
196222 # Now we know we have a Task, register the queue
197223 if isinstance (event , Task ):
198224 try :
199- await self ._queue_manager .add (event .id , queue )
200- task_id = event .id
225+ created_task : Task = event
226+ await self ._queue_manager .add (created_task .id , queue )
227+ task_id = created_task .id
201228 except TaskQueueExists :
202229 logging .info (
203230 'Multiple Task objects created in event stream.'
204231 )
232+ if (
233+ self ._push_notifier
234+ and params .configuration
235+ and params .configuration .pushNotificationConfig
236+ ):
237+ await self ._push_notifier .set_info (
238+ created_task .id ,
239+ params .configuration .pushNotificationConfig ,
240+ )
241+ if self ._push_notifier and task_id :
242+ latest_task = await result_aggregator .current_result
243+ if isinstance (latest_task , Task ):
244+ await self ._push_notifier .send_notification (latest_task )
205245 yield event
206246 finally :
207247 await self ._cleanup_producer (producer_task , task_id )
@@ -220,13 +260,38 @@ async def on_set_task_push_notification_config(
220260 self , params : TaskPushNotificationConfig
221261 ) -> TaskPushNotificationConfig :
222262 """Default handler for 'tasks/pushNotificationConfig/set'."""
223- raise ServerError (error = UnsupportedOperationError ())
263+ if not self ._push_notifier :
264+ raise ServerError (error = UnsupportedOperationError ())
265+
266+ task : Task | None = await self .task_store .get (params .taskId )
267+ if not task :
268+ raise ServerError (error = TaskNotFoundError ())
269+
270+ await self ._push_notifier .set_info (
271+ params .taskId ,
272+ params .pushNotificationConfig ,
273+ )
274+
275+ return params
224276
225277 async def on_get_task_push_notification_config (
226278 self , params : TaskIdParams
227279 ) -> TaskPushNotificationConfig :
228280 """Default handler for 'tasks/pushNotificationConfig/get'."""
229- raise ServerError (error = UnsupportedOperationError ())
281+ if not self ._push_notifier :
282+ raise ServerError (error = UnsupportedOperationError ())
283+
284+ task : Task | None = await self .task_store .get (params .id )
285+ if not task :
286+ raise ServerError (error = TaskNotFoundError ())
287+
288+ push_notification_config = await self ._push_notifier .get_info (params .id )
289+ if not push_notification_config :
290+ raise ServerError (error = TaskNotFoundError ())
291+
292+ return TaskPushNotificationConfig (
293+ taskId = params .id , pushNotificationConfig = push_notification_config
294+ )
230295
231296 async def on_resubscribe_to_task (
232297 self , params : TaskIdParams
0 commit comments