1414 TaskQueueExists ,
1515)
1616from a2a .server .request_handlers .request_handler import RequestHandler
17- from a2a .server .tasks import ResultAggregator , TaskManager , TaskStore
17+ from a2a .server .tasks import (
18+ PushNotifier ,
19+ ResultAggregator ,
20+ TaskManager ,
21+ TaskStore ,
22+ )
1823from a2a .types import (
1924 InternalError ,
2025 Message ,
26+ MessageSendConfiguration ,
2127 MessageSendParams ,
28+ PushNotificationConfig ,
2229 Task ,
2330 TaskIdParams ,
2431 TaskNotFoundError ,
@@ -44,10 +51,12 @@ def __init__(
4451 agent_executor : AgentExecutor ,
4552 task_store : TaskStore ,
4653 queue_manager : QueueManager | None = None ,
54+ push_notifier : PushNotifier | None = None ,
4755 ) -> None :
4856 self .agent_executor = agent_executor
4957 self .task_store = task_store
5058 self ._queue_manager = queue_manager or InMemoryQueueManager ()
59+ self ._push_notifier = push_notifier
5160 # TODO: Likely want an interface for managing this, like AgentExecutionManager.
5261 self ._running_agents = {}
5362 self ._running_agents_lock = asyncio .Lock ()
@@ -118,6 +127,18 @@ async def on_message_send(
118127 task : Task | None = await task_manager .get_task ()
119128 if task :
120129 task = task_manager .update_with_message (params .message , task )
130+ if self .should_add_push_info (params ):
131+ assert isinstance (self ._push_notifier , PushNotifier )
132+ assert isinstance (
133+ params .configuration , MessageSendConfiguration
134+ )
135+ assert isinstance (
136+ params .configuration .pushNotificationConfig ,
137+ PushNotificationConfig ,
138+ )
139+ await self ._push_notifier .set_info (
140+ task .id , params .configuration .pushNotificationConfig
141+ )
121142 request_context = RequestContext (
122143 params ,
123144 task .id if task else None ,
@@ -176,6 +197,20 @@ async def on_message_send_stream(
176197 if task :
177198 task = task_manager .update_with_message (params .message , task )
178199
200+ if self .should_add_push_info (params ):
201+ assert isinstance (self ._push_notifier , PushNotifier )
202+ assert isinstance (
203+ params .configuration , MessageSendConfiguration
204+ )
205+ assert isinstance (
206+ params .configuration .pushNotificationConfig ,
207+ PushNotificationConfig ,
208+ )
209+ await self ._push_notifier .set_info (
210+ task .id , params .configuration .pushNotificationConfig
211+ )
212+ else :
213+ queue = EventQueue ()
179214 result_aggregator = ResultAggregator (task_manager )
180215 request_context = RequestContext (
181216 params ,
@@ -202,12 +237,26 @@ async def on_message_send_stream(
202237 f'Agent generated task_id={ event .id } does not match the RequestContext task_id={ task_id } .'
203238 )
204239 try :
205- await self ._queue_manager .add (event .id , queue )
206- task_id = event .id
240+ created_task : Task = event
241+ await self ._queue_manager .add (created_task .id , queue )
242+ task_id = created_task .id
207243 except TaskQueueExists :
208244 logging .info (
209245 'Multiple Task objects created in event stream.'
210246 )
247+ if (
248+ self ._push_notifier
249+ and params .configuration
250+ and params .configuration .pushNotificationConfig
251+ ):
252+ await self ._push_notifier .set_info (
253+ created_task .id ,
254+ params .configuration .pushNotificationConfig ,
255+ )
256+ if self ._push_notifier and task_id :
257+ latest_task = await result_aggregator .current_result
258+ if isinstance (latest_task , Task ):
259+ await self ._push_notifier .send_notification (latest_task )
211260 yield event
212261 finally :
213262 await self ._cleanup_producer (producer_task , task_id )
@@ -226,13 +275,38 @@ async def on_set_task_push_notification_config(
226275 self , params : TaskPushNotificationConfig
227276 ) -> TaskPushNotificationConfig :
228277 """Default handler for 'tasks/pushNotificationConfig/set'."""
229- raise ServerError (error = UnsupportedOperationError ())
278+ if not self ._push_notifier :
279+ raise ServerError (error = UnsupportedOperationError ())
280+
281+ task : Task | None = await self .task_store .get (params .taskId )
282+ if not task :
283+ raise ServerError (error = TaskNotFoundError ())
284+
285+ await self ._push_notifier .set_info (
286+ params .taskId ,
287+ params .pushNotificationConfig ,
288+ )
289+
290+ return params
230291
231292 async def on_get_task_push_notification_config (
232293 self , params : TaskIdParams
233294 ) -> TaskPushNotificationConfig :
234295 """Default handler for 'tasks/pushNotificationConfig/get'."""
235- raise ServerError (error = UnsupportedOperationError ())
296+ if not self ._push_notifier :
297+ raise ServerError (error = UnsupportedOperationError ())
298+
299+ task : Task | None = await self .task_store .get (params .id )
300+ if not task :
301+ raise ServerError (error = TaskNotFoundError ())
302+
303+ push_notification_config = await self ._push_notifier .get_info (params .id )
304+ if not push_notification_config :
305+ raise ServerError (error = InternalError ())
306+
307+ return TaskPushNotificationConfig (
308+ taskId = params .id , pushNotificationConfig = push_notification_config
309+ )
236310
237311 async def on_resubscribe_to_task (
238312 self , params : TaskIdParams
@@ -258,3 +332,10 @@ async def on_resubscribe_to_task(
258332 consumer = EventConsumer (queue )
259333 async for event in result_aggregator .consume_and_emit (consumer ):
260334 yield event
335+
336+ def should_add_push_info (self , params : MessageSendParams ) -> bool :
337+ return bool (
338+ self ._push_notifier
339+ and params .configuration
340+ and params .configuration .pushNotificationConfig
341+ )
0 commit comments