Skip to content

Commit 08bb1a1

Browse files
committed
Add push notification support
1 parent 224d4f5 commit 08bb1a1

File tree

11 files changed

+389
-53
lines changed

11 files changed

+389
-53
lines changed

examples/google_adk/birthday_planner/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import click
77
import uvicorn
8+
89
from adk_agent_executor import ADKAgentExecutor
910
from dotenv import load_dotenv
1011

@@ -18,6 +19,7 @@
1819
AgentSkill,
1920
)
2021

22+
2123
load_dotenv()
2224

2325
logging.basicConfig()

examples/google_adk/birthday_planner/adk_agent_executor.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import asyncio
22
import logging
3-
from collections.abc import AsyncGenerator
4-
from typing import Any, AsyncIterable
3+
4+
from collections.abc import AsyncGenerator, AsyncIterable
5+
from typing import Any
56
from uuid import uuid4
67

78
import httpx
9+
810
from google.adk import Runner
911
from google.adk.agents import LlmAgent, RunConfig
1012
from google.adk.artifacts import InMemoryArtifactService
@@ -42,6 +44,7 @@
4244
from a2a.utils import get_text_parts
4345
from a2a.utils.errors import ServerError
4446

47+
4548
logger = logging.getLogger(__name__)
4649
logger.setLevel(logging.DEBUG)
4750

@@ -66,7 +69,7 @@ def __init__(self, calendar_agent_url):
6669
name='birthday_planner_agent',
6770
description='An agent that helps manage birthday parties.',
6871
after_tool_callback=self._handle_auth_required_task,
69-
instruction=f"""
72+
instruction="""
7073
You are an agent that helps plan birthday parties.
7174
7275
Your job as a party planner is to act as a sounding board and idea generator for
@@ -165,7 +168,7 @@ async def _process_request(
165168
task_updater.add_artifact(response)
166169
task_updater.complete()
167170
break
168-
elif calls := event.get_function_calls():
171+
if calls := event.get_function_calls():
169172
for call in calls:
170173
# Provide an update on what we're doing.
171174
if call.name == 'message_calendar_agent':
@@ -314,23 +317,21 @@ def convert_a2a_part_to_genai(part: Part) -> types.Part:
314317
part = part.root
315318
if isinstance(part, TextPart):
316319
return types.Part(text=part.text)
317-
elif isinstance(part, FilePart):
320+
if isinstance(part, FilePart):
318321
if isinstance(part.file, FileWithUri):
319322
return types.Part(
320323
file_data=types.FileData(
321324
file_uri=part.file.uri, mime_type=part.file.mime_type
322325
)
323326
)
324-
elif isinstance(part.file, FileWithBytes):
327+
if isinstance(part.file, FileWithBytes):
325328
return types.Part(
326329
inline_data=types.Blob(
327330
data=part.file.bytes, mime_type=part.file.mime_type
328331
)
329332
)
330-
else:
331-
raise ValueError(f'Unsupported file type: {type(part.file)}')
332-
else:
333-
raise ValueError(f'Unsupported part type: {type(part)}')
333+
raise ValueError(f'Unsupported file type: {type(part.file)}')
334+
raise ValueError(f'Unsupported part type: {type(part)}')
334335

335336

336337
def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]:
@@ -346,14 +347,14 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part:
346347
"""Convert a single Google GenAI Part type into an A2A Part type."""
347348
if part.text:
348349
return TextPart(text=part.text)
349-
elif part.file_data:
350+
if part.file_data:
350351
return FilePart(
351352
file=FileWithUri(
352353
uri=part.file_data.file_uri,
353354
mime_type=part.file_data.mime_type,
354355
)
355356
)
356-
elif part.inline_data:
357+
if part.inline_data:
357358
return Part(
358359
root=FilePart(
359360
file=FileWithBytes(
@@ -362,5 +363,4 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part:
362363
)
363364
)
364365
)
365-
else:
366-
raise ValueError(f'Unsupported part type: {part}')
366+
raise ValueError(f'Unsupported part type: {part}')

examples/langgraph/__main__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import sys
33

44
import click
5+
import httpx
56

67
from agent import CurrencyAgent
78
from agent_executor import CurrencyAgentExecutor
89
from dotenv import load_dotenv
910

1011
from a2a.server.apps import A2AStarletteApplication
1112
from a2a.server.request_handlers import DefaultRequestHandler
12-
from a2a.server.tasks import InMemoryTaskStore
13+
from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore
1314
from a2a.types import (
1415
AgentAuthentication,
1516
AgentCapabilities,
@@ -29,9 +30,11 @@ def main(host: str, port: int):
2930
print('GOOGLE_API_KEY environment variable not set.')
3031
sys.exit(1)
3132

33+
client = httpx.AsyncClient()
3234
request_handler = DefaultRequestHandler(
3335
agent_executor=CurrencyAgentExecutor(),
3436
task_store=InMemoryTaskStore(),
37+
push_notifier=InMemoryPushNotifier(client),
3538
)
3639

3740
server = A2AStarletteApplication(

src/a2a/server/agent_execution/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(
1919
task_id: str | None = None,
2020
context_id: str | None = None,
2121
task: Task | None = None,
22-
related_tasks: list[Task] = None,
22+
related_tasks: list[Task] | None = None,
2323
):
2424
if related_tasks is None:
2525
related_tasks = []

src/a2a/server/events/in_memory_queue_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class InMemoryQueueManager(QueueManager):
1818
true scalable deployment.
1919
"""
2020

21-
def __init__(self):
21+
def __init__(self) -> None:
2222
self._task_queue: dict[str, EventQueue] = {}
2323
self._lock = asyncio.Lock()
2424

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
2-
import contextlib
32
import logging
3+
44
from collections.abc import AsyncGenerator
55
from typing import cast
66

@@ -10,12 +10,16 @@
1010
EventConsumer,
1111
EventQueue,
1212
InMemoryQueueManager,
13-
NoTaskQueue,
1413
QueueManager,
1514
TaskQueueExists,
1615
)
1716
from a2a.server.request_handlers.request_handler import RequestHandler
18-
from a2a.server.tasks import ResultAggregator, TaskManager, TaskStore
17+
from a2a.server.tasks import (
18+
PushNotifier,
19+
ResultAggregator,
20+
TaskManager,
21+
TaskStore,
22+
)
1923
from a2a.types import (
2024
InternalError,
2125
Message,
@@ -29,6 +33,7 @@
2933
)
3034
from a2a.utils.errors import ServerError
3135

36+
3237
logger = logging.getLogger(__name__)
3338

3439

@@ -42,10 +47,12 @@ def __init__(
4247
agent_executor: AgentExecutor,
4348
task_store: TaskStore,
4449
queue_manager: QueueManager | None = None,
50+
push_notifier: PushNotifier | None = None,
4551
) -> None:
4652
self.agent_executor = agent_executor
4753
self.task_store = task_store
4854
self._queue_manager = queue_manager or InMemoryQueueManager()
55+
self._push_notifier = push_notifier
4956
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
5057
self._running_agents = {}
5158
self._running_agents_lock = asyncio.Lock()
@@ -116,6 +123,15 @@ async def on_message_send(
116123
task: Task | None = await task_manager.get_task()
117124
if task:
118125
task = task_manager.update_with_message(params.message, task)
126+
if (
127+
self._push_notifier
128+
and params.configuration
129+
and params.configuration.pushNotificationConfig
130+
and not params.configuration.blocking
131+
):
132+
await self._push_notifier.set_info(
133+
task.id, params.configuration.pushNotificationConfig
134+
)
119135
request_context = RequestContext(
120136
params,
121137
task.id if task else None,
@@ -173,6 +189,16 @@ async def on_message_send_stream(
173189
if task:
174190
task = task_manager.update_with_message(params.message, task)
175191

192+
if (
193+
self._push_notifier
194+
and params.configuration
195+
and params.configuration.pushNotificationConfig
196+
):
197+
await self._push_notifier.set_info(
198+
task.id, params.configuration.pushNotificationConfig
199+
)
200+
else:
201+
queue = EventQueue()
176202
result_aggregator = ResultAggregator(task_manager)
177203
request_context = RequestContext(
178204
params,
@@ -196,12 +222,26 @@ async def on_message_send_stream(
196222
# Now we know we have a Task, register the queue
197223
if isinstance(event, Task):
198224
try:
199-
await self._queue_manager.add(event.id, queue)
200-
task_id = event.id
225+
created_task: Task = event
226+
await self._queue_manager.add(created_task.id, queue)
227+
task_id = created_task.id
201228
except TaskQueueExists:
202229
logging.info(
203230
'Multiple Task objects created in event stream.'
204231
)
232+
if (
233+
self._push_notifier
234+
and params.configuration
235+
and params.configuration.pushNotificationConfig
236+
):
237+
await self._push_notifier.set_info(
238+
created_task.id,
239+
params.configuration.pushNotificationConfig,
240+
)
241+
if self._push_notifier and task_id:
242+
latest_task = await result_aggregator.current_result
243+
if isinstance(latest_task, Task):
244+
await self._push_notifier.send_notification(latest_task)
205245
yield event
206246
finally:
207247
await self._cleanup_producer(producer_task, task_id)
@@ -220,13 +260,38 @@ async def on_set_task_push_notification_config(
220260
self, params: TaskPushNotificationConfig
221261
) -> TaskPushNotificationConfig:
222262
"""Default handler for 'tasks/pushNotificationConfig/set'."""
223-
raise ServerError(error=UnsupportedOperationError())
263+
if not self._push_notifier:
264+
raise ServerError(error=UnsupportedOperationError())
265+
266+
task: Task | None = await self.task_store.get(params.taskId)
267+
if not task:
268+
raise ServerError(error=TaskNotFoundError())
269+
270+
await self._push_notifier.set_info(
271+
params.taskId,
272+
params.pushNotificationConfig,
273+
)
274+
275+
return params
224276

225277
async def on_get_task_push_notification_config(
226278
self, params: TaskIdParams
227279
) -> TaskPushNotificationConfig:
228280
"""Default handler for 'tasks/pushNotificationConfig/get'."""
229-
raise ServerError(error=UnsupportedOperationError())
281+
if not self._push_notifier:
282+
raise ServerError(error=UnsupportedOperationError())
283+
284+
task: Task | None = await self.task_store.get(params.id)
285+
if not task:
286+
raise ServerError(error=TaskNotFoundError())
287+
288+
push_notification_config = await self._push_notifier.get_info(params.id)
289+
if not push_notification_config:
290+
raise ServerError(error=TaskNotFoundError())
291+
292+
return TaskPushNotificationConfig(
293+
taskId=params.id, pushNotificationConfig=push_notification_config
294+
)
230295

231296
async def on_resubscribe_to_task(
232297
self, params: TaskIdParams

src/a2a/server/tasks/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier
12
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
3+
from a2a.server.tasks.push_notifier import PushNotifier
24
from a2a.server.tasks.result_aggregator import ResultAggregator
35
from a2a.server.tasks.task_manager import TaskManager
46
from a2a.server.tasks.task_store import TaskStore
57
from a2a.server.tasks.task_updater import TaskUpdater
68

79

810
__all__ = [
11+
'InMemoryPushNotifier',
912
'InMemoryTaskStore',
13+
'PushNotifier',
1014
'ResultAggregator',
1115
'TaskManager',
1216
'TaskStore',
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import asyncio
2+
import logging
3+
4+
import httpx
5+
6+
from a2a.server.tasks.push_notifier import PushNotifier
7+
from a2a.types import PushNotificationConfig, Task
8+
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class InMemoryPushNotifier(PushNotifier):
14+
"""In-memory implementation of PushNotifier interface."""
15+
16+
def __init__(self, httpx_client: httpx.AsyncClient) -> None:
17+
self._client = httpx_client
18+
self.lock = asyncio.Lock()
19+
self._push_notification_infos: dict[str, PushNotificationConfig] = {}
20+
21+
async def set_info(
22+
self, task_id: str, notification_config: PushNotificationConfig
23+
):
24+
async with self.lock:
25+
self._push_notification_infos[task_id] = notification_config
26+
27+
async def get_info(self, task_id: str) -> PushNotificationConfig | None:
28+
async with self.lock:
29+
return self._push_notification_infos.get(task_id)
30+
31+
async def delete_info(self, task_id: str):
32+
async with self.lock:
33+
if task_id in self._push_notification_infos:
34+
del self._push_notification_infos[task_id]
35+
36+
async def send_notification(self, task: Task):
37+
push_info = await self.get_info(task.id)
38+
if not push_info:
39+
return
40+
url = push_info.url
41+
42+
try:
43+
response = await self._client.post(
44+
url, json=task.model_dump(exclude_none=True)
45+
)
46+
response.raise_for_status()
47+
logger.info(f'Push-notification sent for URL: {url}')
48+
except Exception as e:
49+
logger.error(f'Error sending push-notification: {e}')
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from abc import ABC, abstractmethod
2+
3+
from a2a.types import PushNotificationConfig, Task
4+
5+
6+
class PushNotifier(ABC):
7+
"""PushNotifier interface to store, retrieve push notification for tasks and send push notifications."""
8+
9+
@abstractmethod
10+
async def set_info(
11+
self, task_id: str, notification_config: PushNotificationConfig
12+
):
13+
pass
14+
15+
@abstractmethod
16+
async def get_info(self, task_id: str) -> PushNotificationConfig | None:
17+
pass
18+
19+
@abstractmethod
20+
async def delete_info(self, task_id: str):
21+
pass
22+
23+
@abstractmethod
24+
async def send_notification(self, task: Task):
25+
pass

0 commit comments

Comments
 (0)