Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/langgraph/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import sys

import click
import httpx

from agent import CurrencyAgent
from agent_executor import CurrencyAgentExecutor
from dotenv import load_dotenv

from a2a.server.apps import A2AStarletteApplication
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore
from a2a.types import (
AgentAuthentication,
AgentCapabilities,
Expand All @@ -29,9 +30,11 @@ def main(host: str, port: int):
print('GOOGLE_API_KEY environment variable not set.')
sys.exit(1)

client = httpx.AsyncClient()
request_handler = DefaultRequestHandler(
agent_executor=CurrencyAgentExecutor(),
task_store=InMemoryTaskStore(),
push_notifier=InMemoryPushNotifier(client),
)

server = A2AStarletteApplication(
Expand Down
75 changes: 70 additions & 5 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
TaskQueueExists,
)
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.tasks import ResultAggregator, TaskManager, TaskStore
from a2a.server.tasks import (
PushNotifier,
ResultAggregator,
TaskManager,
TaskStore,
)
from a2a.types import (
InternalError,
Message,
Expand Down Expand Up @@ -42,10 +47,12 @@ def __init__(
agent_executor: AgentExecutor,
task_store: TaskStore,
queue_manager: QueueManager | None = None,
push_notifier: PushNotifier | None = None,
) -> None:
self.agent_executor = agent_executor
self.task_store = task_store
self._queue_manager = queue_manager or InMemoryQueueManager()
self._push_notifier = push_notifier
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
self._running_agents = {}
self._running_agents_lock = asyncio.Lock()
Expand Down Expand Up @@ -116,6 +123,15 @@ async def on_message_send(
task: Task | None = await task_manager.get_task()
if task:
task = task_manager.update_with_message(params.message, task)
if (
self._push_notifier
and params.configuration
and params.configuration.pushNotificationConfig
and not params.configuration.blocking
):
await self._push_notifier.set_info(
task.id, params.configuration.pushNotificationConfig
)
request_context = RequestContext(
params,
task.id if task else None,
Expand Down Expand Up @@ -174,6 +190,16 @@ async def on_message_send_stream(
if task:
task = task_manager.update_with_message(params.message, task)

if (
self._push_notifier
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_notifier.set_info(
task.id, params.configuration.pushNotificationConfig
)
else:
queue = EventQueue()
result_aggregator = ResultAggregator(task_manager)
request_context = RequestContext(
params,
Expand All @@ -198,12 +224,26 @@ async def on_message_send_stream(
# Now we know we have a Task, register the queue
if isinstance(event, Task):
try:
await self._queue_manager.add(event.id, queue)
task_id = event.id
created_task: Task = event
await self._queue_manager.add(created_task.id, queue)
task_id = created_task.id
except TaskQueueExists:
logging.info(
'Multiple Task objects created in event stream.'
)
if (
self._push_notifier
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_notifier.set_info(
created_task.id,
params.configuration.pushNotificationConfig,
)
if self._push_notifier and task_id:
latest_task = await result_aggregator.current_result
if isinstance(latest_task, Task):
await self._push_notifier.send_notification(latest_task)
yield event
finally:
await self._cleanup_producer(producer_task, task_id)
Expand All @@ -222,13 +262,38 @@ async def on_set_task_push_notification_config(
self, params: TaskPushNotificationConfig
) -> TaskPushNotificationConfig:
"""Default handler for 'tasks/pushNotificationConfig/set'."""
raise ServerError(error=UnsupportedOperationError())
if not self._push_notifier:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.taskId)
if not task:
raise ServerError(error=TaskNotFoundError())

await self._push_notifier.set_info(
params.taskId,
params.pushNotificationConfig,
)

return params

async def on_get_task_push_notification_config(
self, params: TaskIdParams
) -> TaskPushNotificationConfig:
"""Default handler for 'tasks/pushNotificationConfig/get'."""
raise ServerError(error=UnsupportedOperationError())
if not self._push_notifier:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.id)
if not task:
raise ServerError(error=TaskNotFoundError())

push_notification_config = await self._push_notifier.get_info(params.id)
if not push_notification_config:
raise ServerError(error=InternalError())

return TaskPushNotificationConfig(
taskId=params.id, pushNotificationConfig=push_notification_config
)

async def on_resubscribe_to_task(
self, params: TaskIdParams
Expand Down
4 changes: 4 additions & 0 deletions src/a2a/server/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.server.tasks.push_notifier import PushNotifier
from a2a.server.tasks.result_aggregator import ResultAggregator
from a2a.server.tasks.task_manager import TaskManager
from a2a.server.tasks.task_store import TaskStore
from a2a.server.tasks.task_updater import TaskUpdater


__all__ = [
'InMemoryPushNotifier',
'InMemoryTaskStore',
'PushNotifier',
'ResultAggregator',
'TaskManager',
'TaskStore',
Expand Down
49 changes: 49 additions & 0 deletions src/a2a/server/tasks/inmemory_push_notifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import asyncio
import logging

import httpx

from a2a.server.tasks.push_notifier import PushNotifier
from a2a.types import PushNotificationConfig, Task


logger = logging.getLogger(__name__)


class InMemoryPushNotifier(PushNotifier):
"""In-memory implementation of PushNotifier interface."""

def __init__(self, httpx_client: httpx.AsyncClient) -> None:
self._client = httpx_client
self.lock = asyncio.Lock()
self._push_notification_infos: dict[str, PushNotificationConfig] = {}

async def set_info(
self, task_id: str, notification_config: PushNotificationConfig
):
async with self.lock:
self._push_notification_infos[task_id] = notification_config

async def get_info(self, task_id: str) -> PushNotificationConfig | None:
async with self.lock:
return self._push_notification_infos.get(task_id)

async def delete_info(self, task_id: str):
async with self.lock:
if task_id in self._push_notification_infos:
del self._push_notification_infos[task_id]

async def send_notification(self, task: Task):
push_info = await self.get_info(task.id)
if not push_info:
return
url = push_info.url

try:
response = await self._client.post(
url, json=task.model_dump(mode='json', exclude_none=True)
)
response.raise_for_status()
logger.info(f'Push-notification sent for URL: {url}')
except Exception as e:
logger.error(f'Error sending push-notification: {e}')
25 changes: 25 additions & 0 deletions src/a2a/server/tasks/push_notifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod

from a2a.types import PushNotificationConfig, Task


class PushNotifier(ABC):
"""PushNotifier interface to store, retrieve push notification for tasks and send push notifications."""

@abstractmethod
async def set_info(
self, task_id: str, notification_config: PushNotificationConfig
):
pass

@abstractmethod
async def get_info(self, task_id: str) -> PushNotificationConfig | None:
pass

@abstractmethod
async def delete_info(self, task_id: str):
pass

@abstractmethod
async def send_notification(self, task: Task):
pass
Loading