diff --git a/examples/langgraph/__main__.py b/examples/langgraph/__main__.py index af06aeae..6f151ed3 100644 --- a/examples/langgraph/__main__.py +++ b/examples/langgraph/__main__.py @@ -2,6 +2,7 @@ import sys import click +import httpx from agent import CurrencyAgent from agent_executor import CurrencyAgentExecutor @@ -9,7 +10,7 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler -from a2a.server.tasks import InMemoryTaskStore +from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore from a2a.types import ( AgentAuthentication, AgentCapabilities, @@ -29,9 +30,11 @@ def main(host: str, port: int): print('GOOGLE_API_KEY environment variable not set.') sys.exit(1) + client = httpx.AsyncClient() request_handler = DefaultRequestHandler( agent_executor=CurrencyAgentExecutor(), task_store=InMemoryTaskStore(), + push_notifier=InMemoryPushNotifier(client), ) server = A2AStarletteApplication( diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index bf616c0c..f8501526 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -14,11 +14,18 @@ TaskQueueExists, ) from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.tasks import ResultAggregator, TaskManager, TaskStore +from a2a.server.tasks import ( + PushNotifier, + ResultAggregator, + TaskManager, + TaskStore, +) from a2a.types import ( InternalError, Message, + MessageSendConfiguration, MessageSendParams, + PushNotificationConfig, Task, TaskIdParams, TaskNotFoundError, @@ -42,10 +49,12 @@ def __init__( agent_executor: AgentExecutor, task_store: TaskStore, queue_manager: QueueManager | None = None, + push_notifier: PushNotifier | None = None, ) -> None: self.agent_executor = agent_executor self.task_store = task_store self._queue_manager = queue_manager or InMemoryQueueManager() + self._push_notifier = push_notifier # TODO: Likely want an interface for managing this, like AgentExecutionManager. self._running_agents = {} self._running_agents_lock = asyncio.Lock() @@ -116,6 +125,13 @@ async def on_message_send( task: Task | None = await task_manager.get_task() if task: task = task_manager.update_with_message(params.message, task) + if self.should_add_push_info(params): + assert isinstance(self._push_notifier, PushNotifier) # For typechecker + assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker + assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker + await self._push_notifier.set_info( + task.id, params.configuration.pushNotificationConfig + ) request_context = RequestContext( params, task.id if task else None, @@ -174,6 +190,15 @@ async def on_message_send_stream( if task: task = task_manager.update_with_message(params.message, task) + if self.should_add_push_info(params): + assert isinstance(self._push_notifier, PushNotifier) # For typechecker + assert isinstance(params.configuration, MessageSendConfiguration) # For typechecker + assert isinstance(params.configuration.pushNotificationConfig, PushNotificationConfig) # For typechecker + await self._push_notifier.set_info( + task.id, params.configuration.pushNotificationConfig + ) + else: + queue = EventQueue() result_aggregator = ResultAggregator(task_manager) request_context = RequestContext( params, @@ -198,12 +223,26 @@ async def on_message_send_stream( # Now we know we have a Task, register the queue if isinstance(event, Task): try: - await self._queue_manager.add(event.id, queue) - task_id = event.id + created_task: Task = event + await self._queue_manager.add(created_task.id, queue) + task_id = created_task.id except TaskQueueExists: logging.info( 'Multiple Task objects created in event stream.' ) + if ( + self._push_notifier + and params.configuration + and params.configuration.pushNotificationConfig + ): + await self._push_notifier.set_info( + created_task.id, + params.configuration.pushNotificationConfig, + ) + if self._push_notifier and task_id: + latest_task = await result_aggregator.current_result + if isinstance(latest_task, Task): + await self._push_notifier.send_notification(latest_task) yield event finally: await self._cleanup_producer(producer_task, task_id) @@ -222,13 +261,38 @@ async def on_set_task_push_notification_config( self, params: TaskPushNotificationConfig ) -> TaskPushNotificationConfig: """Default handler for 'tasks/pushNotificationConfig/set'.""" - raise ServerError(error=UnsupportedOperationError()) + if not self._push_notifier: + raise ServerError(error=UnsupportedOperationError()) + + task: Task | None = await self.task_store.get(params.taskId) + if not task: + raise ServerError(error=TaskNotFoundError()) + + await self._push_notifier.set_info( + params.taskId, + params.pushNotificationConfig, + ) + + return params async def on_get_task_push_notification_config( self, params: TaskIdParams ) -> TaskPushNotificationConfig: """Default handler for 'tasks/pushNotificationConfig/get'.""" - raise ServerError(error=UnsupportedOperationError()) + if not self._push_notifier: + raise ServerError(error=UnsupportedOperationError()) + + task: Task | None = await self.task_store.get(params.id) + if not task: + raise ServerError(error=TaskNotFoundError()) + + push_notification_config = await self._push_notifier.get_info(params.id) + if not push_notification_config: + raise ServerError(error=InternalError()) + + return TaskPushNotificationConfig( + taskId=params.id, pushNotificationConfig=push_notification_config + ) async def on_resubscribe_to_task( self, params: TaskIdParams @@ -254,3 +318,13 @@ async def on_resubscribe_to_task( consumer = EventConsumer(queue) async for event in result_aggregator.consume_and_emit(consumer): yield event + + def should_add_push_info(self, params: MessageSendParams) -> bool: + if ( + self._push_notifier + and params.configuration + and params.configuration.pushNotificationConfig + ): + return True + else: + return False diff --git a/src/a2a/server/tasks/__init__.py b/src/a2a/server/tasks/__init__.py index d61df11f..4dc94947 100644 --- a/src/a2a/server/tasks/__init__.py +++ b/src/a2a/server/tasks/__init__.py @@ -1,4 +1,6 @@ +from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.tasks.push_notifier import PushNotifier from a2a.server.tasks.result_aggregator import ResultAggregator from a2a.server.tasks.task_manager import TaskManager from a2a.server.tasks.task_store import TaskStore @@ -6,7 +8,9 @@ __all__ = [ + 'InMemoryPushNotifier', 'InMemoryTaskStore', + 'PushNotifier', 'ResultAggregator', 'TaskManager', 'TaskStore', diff --git a/src/a2a/server/tasks/inmemory_push_notifier.py b/src/a2a/server/tasks/inmemory_push_notifier.py new file mode 100644 index 00000000..222dc90f --- /dev/null +++ b/src/a2a/server/tasks/inmemory_push_notifier.py @@ -0,0 +1,49 @@ +import asyncio +import logging + +import httpx + +from a2a.server.tasks.push_notifier import PushNotifier +from a2a.types import PushNotificationConfig, Task + + +logger = logging.getLogger(__name__) + + +class InMemoryPushNotifier(PushNotifier): + """In-memory implementation of PushNotifier interface.""" + + def __init__(self, httpx_client: httpx.AsyncClient) -> None: + self._client = httpx_client + self.lock = asyncio.Lock() + self._push_notification_infos: dict[str, PushNotificationConfig] = {} + + async def set_info( + self, task_id: str, notification_config: PushNotificationConfig + ): + async with self.lock: + self._push_notification_infos[task_id] = notification_config + + async def get_info(self, task_id: str) -> PushNotificationConfig | None: + async with self.lock: + return self._push_notification_infos.get(task_id) + + async def delete_info(self, task_id: str): + async with self.lock: + if task_id in self._push_notification_infos: + del self._push_notification_infos[task_id] + + async def send_notification(self, task: Task): + push_info = await self.get_info(task.id) + if not push_info: + return + url = push_info.url + + try: + response = await self._client.post( + url, json=task.model_dump(mode='json', exclude_none=True) + ) + response.raise_for_status() + logger.info(f'Push-notification sent for URL: {url}') + except Exception as e: + logger.error(f'Error sending push-notification: {e}') diff --git a/src/a2a/server/tasks/push_notifier.py b/src/a2a/server/tasks/push_notifier.py new file mode 100644 index 00000000..10f01f3a --- /dev/null +++ b/src/a2a/server/tasks/push_notifier.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + +from a2a.types import PushNotificationConfig, Task + + +class PushNotifier(ABC): + """PushNotifier interface to store, retrieve push notification for tasks and send push notifications.""" + + @abstractmethod + async def set_info( + self, task_id: str, notification_config: PushNotificationConfig + ): + pass + + @abstractmethod + async def get_info(self, task_id: str) -> PushNotificationConfig | None: + pass + + @abstractmethod + async def delete_info(self, task_id: str): + pass + + @abstractmethod + async def send_notification(self, task: Task): + pass diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 8adb41db..d5b94605 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -3,9 +3,10 @@ from collections.abc import AsyncGenerator from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch, call import pytest +import httpx from a2a.server.agent_execution import AgentExecutor from a2a.server.events import ( @@ -16,28 +17,37 @@ DefaultRequestHandler, JSONRPCHandler, ) -from a2a.server.tasks import TaskStore +from a2a.server.tasks import InMemoryPushNotifier, PushNotifier, TaskStore from a2a.types import ( AgentCapabilities, AgentCard, Artifact, CancelTaskRequest, CancelTaskSuccessResponse, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigSuccessResponse, GetTaskRequest, GetTaskResponse, GetTaskSuccessResponse, JSONRPCErrorResponse, Message, + MessageSendConfiguration, MessageSendParams, Part, SendMessageRequest, SendMessageSuccessResponse, + PushNotificationConfig, SendStreamingMessageRequest, SendStreamingMessageSuccessResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + SetTaskPushNotificationConfigSuccessResponse, Task, TaskArtifactUpdateEvent, TaskIdParams, TaskNotFoundError, + TaskPushNotificationConfig, TaskQueryParams, TaskResubscriptionRequest, TaskState, @@ -48,7 +58,6 @@ ) from a2a.utils.errors import ServerError - MINIMAL_TASK: dict[str, Any] = { 'id': 'task_123', 'contextId': 'session-xyz', @@ -375,13 +384,186 @@ async def streaming_coro(): ) response = handler.on_message_send_stream(request) assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) + collected_events = [item async for item in response] assert len(collected_events) == len(events) mock_agent_executor.execute.assert_called_once() assert mock_task.history is not None and len(mock_task.history) == 1 + async def test_set_push_notification_success(self) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + mock_push_notifier = AsyncMock(spec=PushNotifier) + request_handler = DefaultRequestHandler( + mock_agent_executor, + mock_task_store, + push_notifier=mock_push_notifier, + ) + self.mock_agent_card.capabilities = AgentCapabilities( + streaming=True, pushNotifications=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + task_push_config = TaskPushNotificationConfig( + taskId=mock_task.id, + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + request = SetTaskPushNotificationConfigRequest( + id='1', params=task_push_config + ) + response: SetTaskPushNotificationConfigResponse = ( + await handler.set_push_notification(request) + ) + self.assertIsInstance( + response.root, SetTaskPushNotificationConfigSuccessResponse + ) + assert response.root.result == task_push_config # type: ignore + mock_push_notifier.set_info.assert_called_once_with( + mock_task.id, task_push_config.pushNotificationConfig + ) + + async def test_get_push_notification_success(self) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + push_notifier = InMemoryPushNotifier(httpx_client=mock_httpx_client) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store, push_notifier=push_notifier + ) + self.mock_agent_card.capabilities = AgentCapabilities( + streaming=True, pushNotifications=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + task_push_config = TaskPushNotificationConfig( + taskId=mock_task.id, + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + request = SetTaskPushNotificationConfigRequest( + id='1', params=task_push_config + ) + await handler.set_push_notification(request) + + get_request: GetTaskPushNotificationConfigRequest = ( + GetTaskPushNotificationConfigRequest( + id='1', params=TaskIdParams(id=mock_task.id) + ) + ) + get_response: GetTaskPushNotificationConfigResponse = ( + await handler.get_push_notification(get_request) + ) + self.assertIsInstance( + get_response.root, GetTaskPushNotificationConfigSuccessResponse + ) + assert get_response.root.result == task_push_config # type: ignore + + async def test_on_message_stream_new_message_send_push_notification_success( + self, + ) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + push_notifier = InMemoryPushNotifier(httpx_client=mock_httpx_client) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store, push_notifier=push_notifier + ) + self.mock_agent_card.capabilities = AgentCapabilities( + streaming=True, pushNotifications=True + ) + + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + events: list[Any] = [ + Task(**MINIMAL_TASK), + TaskArtifactUpdateEvent( + taskId='task_123', + contextId='session-xyz', + artifact=Artifact( + artifactId='11', parts=[Part(TextPart(text='text'))] + ), + ), + TaskStatusUpdateEvent( + taskId='task_123', + contextId='session-xyz', + status=TaskStatus(state=TaskState.completed), + final=True, + ), + ] + + async def streaming_coro(): + for event in events: + yield event + + with patch( + 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', + return_value=streaming_coro(), + ): + mock_task_store.get.return_value = None + mock_agent_executor.execute.return_value = None + mock_httpx_client.post.return_value = httpx.Response(200) + request = SendStreamingMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + request.params.configuration = MessageSendConfiguration( + acceptedOutputModes=['text'], + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + response = handler.on_message_send_stream(request) + assert isinstance(response, AsyncGenerator) + + collected_events = [item async for item in response] + assert len(collected_events) == len(events) + + calls = [ + call( + 'http://example.com', + json={ + 'contextId': 'session-xyz', + 'id': 'task_123', + 'status': {'state': 'submitted'}, + 'type': 'task', + }, + ), + call( + 'http://example.com', + json={ + 'artifacts': [ + { + 'artifactId': '11', + 'parts': [{'text': 'text', 'type': 'text'}], + } + ], + 'contextId': 'session-xyz', + 'id': 'task_123', + 'status': {'state': 'submitted'}, + 'type': 'task', + }, + ), + call( + 'http://example.com', + json={ + 'artifacts': [ + { + 'artifactId': '11', + 'parts': [{'text': 'text', 'type': 'text'}], + } + ], + 'contextId': 'session-xyz', + 'id': 'task_123', + 'status': {'state': 'completed'}, + 'type': 'task', + }, + ), + ] + mock_httpx_client.post.assert_has_calls(calls) + async def test_on_resubscribe_existing_task_success( self, ) -> None: