1414 TaskQueueExists ,
1515)
1616from a2a .server .request_handlers .request_handler import RequestHandler
17- from a2a .server .tasks import ResultAggregator , TaskManager , TaskStore
17+ from a2a .server .tasks import (
18+ PushNotifier ,
19+ ResultAggregator ,
20+ TaskManager ,
21+ TaskStore ,
22+ )
1823from a2a .types import (
1924 InternalError ,
2025 Message ,
26+ MessageSendConfiguration ,
2127 MessageSendParams ,
28+ PushNotificationConfig ,
2229 Task ,
2330 TaskIdParams ,
2431 TaskNotFoundError ,
@@ -44,10 +51,12 @@ def __init__(
4451 agent_executor : AgentExecutor ,
4552 task_store : TaskStore ,
4653 queue_manager : QueueManager | None = None ,
54+ push_notifier : PushNotifier | None = None ,
4755 ) -> None :
4856 self .agent_executor = agent_executor
4957 self .task_store = task_store
5058 self ._queue_manager = queue_manager or InMemoryQueueManager ()
59+ self ._push_notifier = push_notifier
5160 # TODO: Likely want an interface for managing this, like AgentExecutionManager.
5261 self ._running_agents = {}
5362 self ._running_agents_lock = asyncio .Lock ()
@@ -118,6 +127,13 @@ async def on_message_send(
118127 task : Task | None = await task_manager .get_task ()
119128 if task :
120129 task = task_manager .update_with_message (params .message , task )
130+ if self .should_add_push_info (params ):
131+ assert isinstance (self ._push_notifier , PushNotifier ) # For typechecker
132+ assert isinstance (params .configuration , MessageSendConfiguration ) # For typechecker
133+ assert isinstance (params .configuration .pushNotificationConfig , PushNotificationConfig ) # For typechecker
134+ await self ._push_notifier .set_info (
135+ task .id , params .configuration .pushNotificationConfig
136+ )
121137 request_context = RequestContext (
122138 params ,
123139 task .id if task else None ,
@@ -176,6 +192,15 @@ async def on_message_send_stream(
176192 if task :
177193 task = task_manager .update_with_message (params .message , task )
178194
195+ if self .should_add_push_info (params ):
196+ assert isinstance (self ._push_notifier , PushNotifier ) # For typechecker
197+ assert isinstance (params .configuration , MessageSendConfiguration ) # For typechecker
198+ assert isinstance (params .configuration .pushNotificationConfig , PushNotificationConfig ) # For typechecker
199+ await self ._push_notifier .set_info (
200+ task .id , params .configuration .pushNotificationConfig
201+ )
202+ else :
203+ queue = EventQueue ()
179204 result_aggregator = ResultAggregator (task_manager )
180205 request_context = RequestContext (
181206 params ,
@@ -202,12 +227,26 @@ async def on_message_send_stream(
202227 f'Agent generated task_id={ event .id } does not match the RequestContext task_id={ task_id } .'
203228 )
204229 try :
205- await self ._queue_manager .add (event .id , queue )
206- task_id = event .id
230+ created_task : Task = event
231+ await self ._queue_manager .add (created_task .id , queue )
232+ task_id = created_task .id
207233 except TaskQueueExists :
208234 logging .info (
209235 'Multiple Task objects created in event stream.'
210236 )
237+ if (
238+ self ._push_notifier
239+ and params .configuration
240+ and params .configuration .pushNotificationConfig
241+ ):
242+ await self ._push_notifier .set_info (
243+ created_task .id ,
244+ params .configuration .pushNotificationConfig ,
245+ )
246+ if self ._push_notifier and task_id :
247+ latest_task = await result_aggregator .current_result
248+ if isinstance (latest_task , Task ):
249+ await self ._push_notifier .send_notification (latest_task )
211250 yield event
212251 finally :
213252 await self ._cleanup_producer (producer_task , task_id )
@@ -226,13 +265,38 @@ async def on_set_task_push_notification_config(
226265 self , params : TaskPushNotificationConfig
227266 ) -> TaskPushNotificationConfig :
228267 """Default handler for 'tasks/pushNotificationConfig/set'."""
229- raise ServerError (error = UnsupportedOperationError ())
268+ if not self ._push_notifier :
269+ raise ServerError (error = UnsupportedOperationError ())
270+
271+ task : Task | None = await self .task_store .get (params .taskId )
272+ if not task :
273+ raise ServerError (error = TaskNotFoundError ())
274+
275+ await self ._push_notifier .set_info (
276+ params .taskId ,
277+ params .pushNotificationConfig ,
278+ )
279+
280+ return params
230281
231282 async def on_get_task_push_notification_config (
232283 self , params : TaskIdParams
233284 ) -> TaskPushNotificationConfig :
234285 """Default handler for 'tasks/pushNotificationConfig/get'."""
235- raise ServerError (error = UnsupportedOperationError ())
286+ if not self ._push_notifier :
287+ raise ServerError (error = UnsupportedOperationError ())
288+
289+ task : Task | None = await self .task_store .get (params .id )
290+ if not task :
291+ raise ServerError (error = TaskNotFoundError ())
292+
293+ push_notification_config = await self ._push_notifier .get_info (params .id )
294+ if not push_notification_config :
295+ raise ServerError (error = InternalError ())
296+
297+ return TaskPushNotificationConfig (
298+ taskId = params .id , pushNotificationConfig = push_notification_config
299+ )
236300
237301 async def on_resubscribe_to_task (
238302 self , params : TaskIdParams
@@ -258,3 +322,13 @@ async def on_resubscribe_to_task(
258322 consumer = EventConsumer (queue )
259323 async for event in result_aggregator .consume_and_emit (consumer ):
260324 yield event
325+
326+ def should_add_push_info (self , params : MessageSendParams ) -> bool :
327+ if (
328+ self ._push_notifier
329+ and params .configuration
330+ and params .configuration .pushNotificationConfig
331+ ):
332+ return True
333+ else :
334+ return False
0 commit comments