Skip to content

Commit c70d7c1

Browse files
swapydapymartimfasantos
authored andcommitted
feat: Add push notification support (a2aproject#9)
1 parent 44a6979 commit c70d7c1

File tree

6 files changed

+349
-12
lines changed

6 files changed

+349
-12
lines changed

examples/langgraph/__main__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import sys
33

44
import click
5+
import httpx
56

67
from agent import CurrencyAgent
78
from agent_executor import CurrencyAgentExecutor
89
from dotenv import load_dotenv
910

1011
from a2a.server.apps import A2AStarletteApplication
1112
from a2a.server.request_handlers import DefaultRequestHandler
12-
from a2a.server.tasks import InMemoryTaskStore
13+
from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore
1314
from a2a.types import (
1415
AgentAuthentication,
1516
AgentCapabilities,
@@ -29,9 +30,11 @@ def main(host: str, port: int):
2930
print('GOOGLE_API_KEY environment variable not set.')
3031
sys.exit(1)
3132

33+
client = httpx.AsyncClient()
3234
request_handler = DefaultRequestHandler(
3335
agent_executor=CurrencyAgentExecutor(),
3436
task_store=InMemoryTaskStore(),
37+
push_notifier=InMemoryPushNotifier(client),
3538
)
3639

3740
server = A2AStarletteApplication(

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,18 @@
1414
TaskQueueExists,
1515
)
1616
from a2a.server.request_handlers.request_handler import RequestHandler
17-
from a2a.server.tasks import ResultAggregator, TaskManager, TaskStore
17+
from a2a.server.tasks import (
18+
PushNotifier,
19+
ResultAggregator,
20+
TaskManager,
21+
TaskStore,
22+
)
1823
from a2a.types import (
1924
InternalError,
2025
Message,
26+
MessageSendConfiguration,
2127
MessageSendParams,
28+
PushNotificationConfig,
2229
Task,
2330
TaskIdParams,
2431
TaskNotFoundError,
@@ -44,10 +51,12 @@ def __init__(
4451
agent_executor: AgentExecutor,
4552
task_store: TaskStore,
4653
queue_manager: QueueManager | None = None,
54+
push_notifier: PushNotifier | None = None,
4755
) -> None:
4856
self.agent_executor = agent_executor
4957
self.task_store = task_store
5058
self._queue_manager = queue_manager or InMemoryQueueManager()
59+
self._push_notifier = push_notifier
5160
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
5261
self._running_agents = {}
5362
self._running_agents_lock = asyncio.Lock()
@@ -118,6 +127,13 @@ async def on_message_send(
118127
task: Task | None = await task_manager.get_task()
119128
if task:
120129
task = task_manager.update_with_message(params.message, task)
130+
if self.should_add_push_info(params):
131+
assert isinstance(self._push_notifier, PushNotifier) # For typechecker
132+
assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker
133+
assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker
134+
await self._push_notifier.set_info(
135+
task.id, params.configuration.pushNotificationConfig
136+
)
121137
request_context = RequestContext(
122138
params,
123139
task.id if task else None,
@@ -176,6 +192,15 @@ async def on_message_send_stream(
176192
if task:
177193
task = task_manager.update_with_message(params.message, task)
178194

195+
if self.should_add_push_info(params):
196+
assert isinstance(self._push_notifier, PushNotifier) # For typechecker
197+
assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker
198+
assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker
199+
await self._push_notifier.set_info(
200+
task.id, params.configuration.pushNotificationConfig
201+
)
202+
else:
203+
queue = EventQueue()
179204
result_aggregator = ResultAggregator(task_manager)
180205
request_context = RequestContext(
181206
params,
@@ -202,12 +227,26 @@ async def on_message_send_stream(
202227
f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.'
203228
)
204229
try:
205-
await self._queue_manager.add(event.id, queue)
206-
task_id = event.id
230+
created_task: Task = event
231+
await self._queue_manager.add(created_task.id, queue)
232+
task_id = created_task.id
207233
except TaskQueueExists:
208234
logging.info(
209235
'Multiple Task objects created in event stream.'
210236
)
237+
if (
238+
self._push_notifier
239+
and params.configuration
240+
and params.configuration.pushNotificationConfig
241+
):
242+
await self._push_notifier.set_info(
243+
created_task.id,
244+
params.configuration.pushNotificationConfig,
245+
)
246+
if self._push_notifier and task_id:
247+
latest_task = await result_aggregator.current_result
248+
if isinstance(latest_task, Task):
249+
await self._push_notifier.send_notification(latest_task)
211250
yield event
212251
finally:
213252
await self._cleanup_producer(producer_task, task_id)
@@ -226,13 +265,38 @@ async def on_set_task_push_notification_config(
226265
self, params: TaskPushNotificationConfig
227266
) -> TaskPushNotificationConfig:
228267
"""Default handler for 'tasks/pushNotificationConfig/set'."""
229-
raise ServerError(error=UnsupportedOperationError())
268+
if not self._push_notifier:
269+
raise ServerError(error=UnsupportedOperationError())
270+
271+
task: Task | None = await self.task_store.get(params.taskId)
272+
if not task:
273+
raise ServerError(error=TaskNotFoundError())
274+
275+
await self._push_notifier.set_info(
276+
params.taskId,
277+
params.pushNotificationConfig,
278+
)
279+
280+
return params
230281

231282
async def on_get_task_push_notification_config(
232283
self, params: TaskIdParams
233284
) -> TaskPushNotificationConfig:
234285
"""Default handler for 'tasks/pushNotificationConfig/get'."""
235-
raise ServerError(error=UnsupportedOperationError())
286+
if not self._push_notifier:
287+
raise ServerError(error=UnsupportedOperationError())
288+
289+
task: Task | None = await self.task_store.get(params.id)
290+
if not task:
291+
raise ServerError(error=TaskNotFoundError())
292+
293+
push_notification_config = await self._push_notifier.get_info(params.id)
294+
if not push_notification_config:
295+
raise ServerError(error=InternalError())
296+
297+
return TaskPushNotificationConfig(
298+
taskId=params.id, pushNotificationConfig=push_notification_config
299+
)
236300

237301
async def on_resubscribe_to_task(
238302
self, params: TaskIdParams
@@ -258,3 +322,13 @@ async def on_resubscribe_to_task(
258322
consumer = EventConsumer(queue)
259323
async for event in result_aggregator.consume_and_emit(consumer):
260324
yield event
325+
326+
def should_add_push_info(self, params: MessageSendParams) -> bool:
327+
if (
328+
self._push_notifier
329+
and params.configuration
330+
and params.configuration.pushNotificationConfig
331+
):
332+
return True
333+
else:
334+
return False

src/a2a/server/tasks/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier
12
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
3+
from a2a.server.tasks.push_notifier import PushNotifier
24
from a2a.server.tasks.result_aggregator import ResultAggregator
35
from a2a.server.tasks.task_manager import TaskManager
46
from a2a.server.tasks.task_store import TaskStore
57
from a2a.server.tasks.task_updater import TaskUpdater
68

79

810
__all__ = [
11+
'InMemoryPushNotifier',
912
'InMemoryTaskStore',
13+
'PushNotifier',
1014
'ResultAggregator',
1115
'TaskManager',
1216
'TaskStore',
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import asyncio
2+
import logging
3+
4+
import httpx
5+
6+
from a2a.server.tasks.push_notifier import PushNotifier
7+
from a2a.types import PushNotificationConfig, Task
8+
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class InMemoryPushNotifier(PushNotifier):
14+
"""In-memory implementation of PushNotifier interface."""
15+
16+
def __init__(self, httpx_client: httpx.AsyncClient) -> None:
17+
self._client = httpx_client
18+
self.lock = asyncio.Lock()
19+
self._push_notification_infos: dict[str, PushNotificationConfig] = {}
20+
21+
async def set_info(
22+
self, task_id: str, notification_config: PushNotificationConfig
23+
):
24+
async with self.lock:
25+
self._push_notification_infos[task_id] = notification_config
26+
27+
async def get_info(self, task_id: str) -> PushNotificationConfig | None:
28+
async with self.lock:
29+
return self._push_notification_infos.get(task_id)
30+
31+
async def delete_info(self, task_id: str):
32+
async with self.lock:
33+
if task_id in self._push_notification_infos:
34+
del self._push_notification_infos[task_id]
35+
36+
async def send_notification(self, task: Task):
37+
push_info = await self.get_info(task.id)
38+
if not push_info:
39+
return
40+
url = push_info.url
41+
42+
try:
43+
response = await self._client.post(
44+
url, json=task.model_dump(mode='json', exclude_none=True)
45+
)
46+
response.raise_for_status()
47+
logger.info(f'Push-notification sent for URL: {url}')
48+
except Exception as e:
49+
logger.error(f'Error sending push-notification: {e}')
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from abc import ABC, abstractmethod
2+
3+
from a2a.types import PushNotificationConfig, Task
4+
5+
6+
class PushNotifier(ABC):
7+
"""PushNotifier interface to store, retrieve push notification for tasks and send push notifications."""
8+
9+
@abstractmethod
10+
async def set_info(
11+
self, task_id: str, notification_config: PushNotificationConfig
12+
):
13+
pass
14+
15+
@abstractmethod
16+
async def get_info(self, task_id: str) -> PushNotificationConfig | None:
17+
pass
18+
19+
@abstractmethod
20+
async def delete_info(self, task_id: str):
21+
pass
22+
23+
@abstractmethod
24+
async def send_notification(self, task: Task):
25+
pass

0 commit comments

Comments
 (0)