diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index d73260da..20747808 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -1,16 +1,35 @@ +ACard +AClient AError ARequest +ARun AServer AStarlette +EUR +GBP +INR +JPY +JSONRPCt +Llm +aconnect adk +autouse +cla +cls +coc codegen +coro datamodel +dunders genai +gle inmemory langgraph lifecycles +linting oauthoidc opensource socio sse tagwords +vulnz diff --git a/.github/actions/spelling/excludes.txt b/.github/actions/spelling/excludes.txt index 0d6f82d4..7b4de3ec 100644 --- a/.github/actions/spelling/excludes.txt +++ b/.github/actions/spelling/excludes.txt @@ -85,3 +85,5 @@ ^\Q.github/workflows/linter.yaml\E$ \.gitignore\E$ \.vscode/ +noxfile.py +\.ruff.toml$ diff --git a/.github/conventional-commit-lint.yaml b/.github/conventional-commit-lint.yaml new file mode 100644 index 00000000..c967ffa6 --- /dev/null +++ b/.github/conventional-commit-lint.yaml @@ -0,0 +1,2 @@ +enabled: true +always_check_pr_title: true diff --git a/.github/linters/.jscpd.json b/.github/linters/.jscpd.json index 5e86d6d8..fb0f3b60 100644 --- a/.github/linters/.jscpd.json +++ b/.github/linters/.jscpd.json @@ -1,5 +1,5 @@ { - "ignore": ["**/.github/**", "**/.git/**", "**/tests/**"], + "ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "**/examples/**"], "threshold": 3, "reporters": ["html", "markdown"] } diff --git a/.github/linters/.markdownlint.json b/.github/linters/.markdownlint.json new file mode 100644 index 00000000..5c6dcb9d --- /dev/null +++ b/.github/linters/.markdownlint.json @@ -0,0 +1,4 @@ +{ + "MD034": false, + "MD013": false +} diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index 152340bc..2fcf37ae 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -61,3 +61,5 @@ jobs: VALIDATE_CHECKOV: false VALIDATE_JAVASCRIPT_STANDARD: false VALIDATE_TYPESCRIPT_STANDARD: false + VALIDATE_GIT_COMMITLINT: false + MARKDOWN_CONFIG_FILE: .markdownlint.json diff --git a/.gitignore b/.gitignore index 67bf01f1..4da52568 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,4 @@ __pycache__ .ruff_cache .venv coverage.xml -spec.json \ No newline at end of file +.nox diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 9feb80f7..257e8a0c 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -47,7 +47,7 @@ offensive, or harmful. This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of -representing a project or community include using an official project e-mail +representing a project or community include using an official project email address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. @@ -93,4 +93,4 @@ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html Note: A version of this file is also available in the -[New Project repo](https://github.com/google/new-project/blob/master/docs/code-of-conduct.md). +[New Project repository](https://github.com/google/new-project/blob/master/docs/code-of-conduct.md). diff --git a/examples/google_adk/__init__.py b/examples/google_adk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/google_adk/birthday_planner/README.md b/examples/google_adk/birthday_planner/README.md index 8567f795..52edb373 100644 --- a/examples/google_adk/birthday_planner/README.md +++ b/examples/google_adk/birthday_planner/README.md @@ -12,16 +12,16 @@ This agent helps plan birthday parties. It has access to a Calendar Agent that i ## Running the example -1. Create the .env file with your API Key +1. Create the `.env` file with your API Key -```bash -echo "GOOGLE_API_KEY=your_api_key_here" > .env -``` + ```bash + echo "GOOGLE_API_KEY=your_api_key_here" > .env + ``` 2. Run the Calendar Agent. See examples/google_adk/calendar_agent. 3. Run the example -``` -uv run . -``` \ No newline at end of file + ```sh + uv run . + ``` diff --git a/examples/google_adk/birthday_planner/__init__.py b/examples/google_adk/birthday_planner/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/google_adk/birthday_planner/__main__.py b/examples/google_adk/birthday_planner/__main__.py index d4a0343d..16609b3e 100644 --- a/examples/google_adk/birthday_planner/__main__.py +++ b/examples/google_adk/birthday_planner/__main__.py @@ -5,6 +5,7 @@ import click import uvicorn + from adk_agent_executor import ADKAgentExecutor from dotenv import load_dotenv @@ -18,6 +19,7 @@ AgentSkill, ) + load_dotenv() logging.basicConfig() @@ -39,11 +41,12 @@ def wrapper(*args, **kwargs): ) def main(host: str, port: int, calendar_agent: str): # Verify an API key is set. Not required if using Vertex AI APIs, since those can use gcloud credentials. - if not os.getenv('GOOGLE_GENAI_USE_VERTEXAI') == 'TRUE': - if not os.getenv('GOOGLE_API_KEY'): - raise Exception( - 'GOOGLE_API_KEY environment variable not set and GOOGLE_GENAI_USE_VERTEXAI is not TRUE.' - ) + if os.getenv('GOOGLE_GENAI_USE_VERTEXAI') != 'TRUE' and not os.getenv( + 'GOOGLE_API_KEY' + ): + raise ValueError( + 'GOOGLE_API_KEY environment variable not set and GOOGLE_GENAI_USE_VERTEXAI is not TRUE.' + ) skill = AgentSkill( id='plan_parties', diff --git a/examples/google_adk/birthday_planner/adk_agent_executor.py b/examples/google_adk/birthday_planner/adk_agent_executor.py index 49f13e96..4cf6a569 100644 --- a/examples/google_adk/birthday_planner/adk_agent_executor.py +++ b/examples/google_adk/birthday_planner/adk_agent_executor.py @@ -1,10 +1,12 @@ import asyncio import logging -from collections.abc import AsyncGenerator -from typing import Any, AsyncIterable + +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any from uuid import uuid4 import httpx + from google.adk import Runner from google.adk.agents import LlmAgent, RunConfig from google.adk.artifacts import InMemoryArtifactService @@ -42,6 +44,7 @@ from a2a.utils import get_text_parts from a2a.utils.errors import ServerError + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -49,7 +52,7 @@ class A2ARunConfig(RunConfig): - """Custom override of ADK RunConfig to smuggle extra data through the event loop""" + """Custom override of ADK RunConfig to smuggle extra data through the event loop.""" model_config = ConfigDict( arbitrary_types_allowed=True, @@ -66,7 +69,7 @@ def __init__(self, calendar_agent_url): name='birthday_planner_agent', description='An agent that helps manage birthday parties.', after_tool_callback=self._handle_auth_required_task, - instruction=f""" + instruction=""" You are an agent that helps plan birthday parties. Your job as a party planner is to act as a sounding board and idea generator for @@ -96,7 +99,7 @@ def _run_agent( session_id, new_message: types.Content, task_updater: TaskUpdater, - ) -> AsyncGenerator[Event, None]: + ) -> AsyncGenerator[Event]: return self.runner.run_async( session_id=session_id, user_id='self', @@ -111,7 +114,7 @@ async def _handle_auth_required_task( tool_context: ToolContext, tool_response: dict, ) -> dict | None: - """Handle requests that return auth-required""" + """Handle requests that return auth-required.""" if tool.name != 'message_calendar_agent': return None if not tool_context.state.get('task_suspended'): @@ -165,7 +168,7 @@ async def _process_request( task_updater.add_artifact(response) task_updater.complete() break - elif calls := event.get_function_calls(): + if calls := event.get_function_calls(): for call in calls: # Provide an update on what we're doing. if call.name == 'message_calendar_agent': @@ -305,36 +308,34 @@ async def _get_agent_task(self, task_id) -> Task: def convert_a2a_parts_to_genai(parts: list[Part]) -> list[types.Part]: - """Convert a list of A2A Part types into a list of Google GenAI Part types.""" + """Convert a list of A2A Part types into a list of Google Gen AI Part types.""" return [convert_a2a_part_to_genai(part) for part in parts] def convert_a2a_part_to_genai(part: Part) -> types.Part: - """Convert a single A2A Part type into a Google GenAI Part type.""" + """Convert a single A2A Part type into a Google Gen AI Part type.""" part = part.root if isinstance(part, TextPart): return types.Part(text=part.text) - elif isinstance(part, FilePart): + if isinstance(part, FilePart): if isinstance(part.file, FileWithUri): return types.Part( file_data=types.FileData( file_uri=part.file.uri, mime_type=part.file.mime_type ) ) - elif isinstance(part.file, FileWithBytes): + if isinstance(part.file, FileWithBytes): return types.Part( inline_data=types.Blob( data=part.file.bytes, mime_type=part.file.mime_type ) ) - else: - raise ValueError(f'Unsupported file type: {type(part.file)}') - else: - raise ValueError(f'Unsupported part type: {type(part)}') + raise ValueError(f'Unsupported file type: {type(part.file)}') + raise ValueError(f'Unsupported part type: {type(part)}') def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]: - """Convert a list of Google GenAI Part types into a list of A2A Part types.""" + """Convert a list of Google Gen AI Part types into a list of A2A Part types.""" return [ convert_genai_part_to_a2a(part) for part in parts @@ -343,17 +344,17 @@ def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]: def convert_genai_part_to_a2a(part: types.Part) -> Part: - """Convert a single Google GenAI Part type into an A2A Part type.""" + """Convert a single Google Gen AI Part type into an A2A Part type.""" if part.text: return TextPart(text=part.text) - elif part.file_data: + if part.file_data: return FilePart( file=FileWithUri( uri=part.file_data.file_uri, mime_type=part.file_data.mime_type, ) ) - elif part.inline_data: + if part.inline_data: return Part( root=FilePart( file=FileWithBytes( @@ -362,5 +363,4 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part: ) ) ) - else: - raise ValueError(f'Unsupported part type: {part}') + raise ValueError(f'Unsupported part type: {part}') diff --git a/examples/google_adk/calendar_agent/README.md b/examples/google_adk/calendar_agent/README.md index 7e2f8e7b..25b71513 100644 --- a/examples/google_adk/calendar_agent/README.md +++ b/examples/google_adk/calendar_agent/README.md @@ -14,14 +14,14 @@ This example shows how to create an A2A Server that uses an ADK-based Agent that 1. Create the .env file with your API Key and OAuth2.0 Client details -```bash -echo "GOOGLE_API_KEY=your_api_key_here" > .env -echo "GOOGLE_CLIENT_ID=your_client_id_here" >> .env -echo "GOOGLE_CLIENT_SECRET=your_client_secret_here" >> .env -``` + ```bash + echo "GOOGLE_API_KEY=your_api_key_here" > .env + echo "GOOGLE_CLIENT_ID=your_client_id_here" >> .env + echo "GOOGLE_CLIENT_SECRET=your_client_secret_here" >> .env + ``` 2. Run the example -``` -uv run . -``` + ```bash + uv run . + ``` diff --git a/examples/google_adk/calendar_agent/__init__.py b/examples/google_adk/calendar_agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/langgraph/README.md b/examples/langgraph/README.md index e9851e3e..a6cdeb68 100644 --- a/examples/langgraph/README.md +++ b/examples/langgraph/README.md @@ -4,21 +4,19 @@ An example LangGraph agent that helps with currency conversion. ## Getting started -1. Extract the zip file and cd to examples folder - -2. Create an environment file with your API key: +1. Create an environment file with your API key: ```bash echo "GOOGLE_API_KEY=your_api_key_here" > .env ``` -3. Start the server +2. Start the server ```bash uv run . ``` -4. Run the test client +3. Run the test client ```bash uv run test_client.py diff --git a/examples/langgraph/__init__.py b/examples/langgraph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/langgraph/agent.py b/examples/langgraph/agent.py index 409995cc..41323427 100644 --- a/examples/langgraph/agent.py +++ b/examples/langgraph/agent.py @@ -11,9 +11,10 @@ ) from langchain_core.tools import tool # type: ignore from langchain_google_genai import ChatGoogleGenerativeAI +from pydantic import BaseModel + from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent # type: ignore -from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/noxfile.py b/noxfile.py index 4d8b3165..fd9569fb 100644 --- a/noxfile.py +++ b/noxfile.py @@ -121,7 +121,7 @@ def format(session): session.run( 'pyupgrade', '--exit-zero-even-if-changed', - '--py311-plus', + '--py313-plus', *lint_paths_py, ) session.run( diff --git a/src/a2a/server/events/event_consumer.py b/src/a2a/server/events/event_consumer.py index e1109402..73cc9c75 100644 --- a/src/a2a/server/events/event_consumer.py +++ b/src/a2a/server/events/event_consumer.py @@ -51,11 +51,13 @@ async def consume_all(self) -> AsyncGenerator[Event]: raise self._exception try: # We use a timeout when waiting for an event from the queue. - # This is required because it allows the loop to check if - # `self._exception` has been set by the `agent_task_callback`. + # This is required because it allows the loop to check if + # `self._exception` has been set by the `agent_task_callback`. # Without the timeout, loop might hang indefinitely if no events are # enqueued by the agent and the agent simply threw an exception - event = await asyncio.wait_for(self.queue.dequeue_event(), timeout=self._timeout) + event = await asyncio.wait_for( + self.queue.dequeue_event(), timeout=self._timeout + ) logger.debug( f'Dequeued event of type: {type(event)} in consume_all.' ) @@ -83,16 +85,12 @@ async def consume_all(self) -> AsyncGenerator[Event]: logger.debug('Stopping event consumption in consume_all.') self.queue.close() break - except asyncio.TimeoutError: + except TimeoutError: # continue polling until there is a final event continue except asyncio.QueueShutDown: break - - - - def agent_task_callback(self, agent_task: asyncio.Task[None]): - if agent_task.exception() is not None: - self._exception = agent_task.exception() \ No newline at end of file + if agent_task.exception() is not None: + self._exception = agent_task.exception() diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 4107cdf2..bf616c0c 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,6 +1,6 @@ import asyncio -import contextlib import logging + from collections.abc import AsyncGenerator from typing import cast @@ -10,7 +10,6 @@ EventConsumer, EventQueue, InMemoryQueueManager, - NoTaskQueue, QueueManager, TaskQueueExists, ) @@ -29,6 +28,7 @@ ) from a2a.utils.errors import ServerError + logger = logging.getLogger(__name__) @@ -161,7 +161,7 @@ async def on_message_send( async def on_message_send_stream( self, params: MessageSendParams - ) -> AsyncGenerator[Event, None]: + ) -> AsyncGenerator[Event]: """Default handler for 'message/stream'.""" task_manager = TaskManager( task_id=params.message.taskId, @@ -208,11 +208,11 @@ async def on_message_send_stream( finally: await self._cleanup_producer(producer_task, task_id) - async def _register_producer(self, task_id, producer_task): + async def _register_producer(self, task_id, producer_task) -> None: async with self._running_agents_lock: self._running_agents[task_id] = producer_task - async def _cleanup_producer(self, producer_task, task_id): + async def _cleanup_producer(self, producer_task, task_id) -> None: await producer_task await self._queue_manager.close(task_id) async with self._running_agents_lock: @@ -232,7 +232,7 @@ async def on_get_task_push_notification_config( async def on_resubscribe_to_task( self, params: TaskIdParams - ) -> AsyncGenerator[Event, None]: + ) -> AsyncGenerator[Event]: """Default handler for 'tasks/resubscribe'.""" task: Task | None = await self.task_store.get(params.id) if not task: diff --git a/src/a2a/server/tasks/result_aggregator.py b/src/a2a/server/tasks/result_aggregator.py index 9220fc86..94f27403 100644 --- a/src/a2a/server/tasks/result_aggregator.py +++ b/src/a2a/server/tasks/result_aggregator.py @@ -1,12 +1,13 @@ import asyncio import logging + from collections.abc import AsyncGenerator, AsyncIterator -from typing import Tuple from a2a.server.events import Event, EventConsumer from a2a.server.tasks.task_manager import TaskManager from a2a.types import Message, Task, TaskState, TaskStatusUpdateEvent + logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ async def current_result(self) -> Task | Message | None: async def consume_and_emit( self, consumer: EventConsumer - ) -> AsyncGenerator[Event, None]: + ) -> AsyncGenerator[Event]: """Processes the event stream and emits the same event stream out.""" async for event in consumer.consume_all(): await self.task_manager.process(event) @@ -54,7 +55,7 @@ async def consume_all( async def consume_and_break_on_interrupt( self, consumer: EventConsumer - ) -> Tuple[Task | Message | None, bool]: + ) -> tuple[Task | Message | None, bool]: """Process the event stream until completion or an interruptable state is encountered.""" event_stream = consumer.consume_all() interrupted = False @@ -64,7 +65,7 @@ async def consume_and_break_on_interrupt( return event, False await self.task_manager.process(event) if ( - isinstance(event, (Task, TaskStatusUpdateEvent)) + isinstance(event, Task | TaskStatusUpdateEvent) and event.status.state == TaskState.auth_required ): # auth-required is a special state: the message should be @@ -82,16 +83,8 @@ async def consume_and_break_on_interrupt( break return await self.task_manager.get_task(), interrupted - async def _continue_consuming(self, event_stream: AsyncIterator[Event]): + async def _continue_consuming( + self, event_stream: AsyncIterator[Event] + ) -> None: async for event in event_stream: await self.task_manager.process(event) - - # async def consume_and_emit_task( - # self, consumer: EventConsumer - # ) -> AsyncGenerator[Event, None]: - # """Processes the event stream and emits the current state of the task.""" - # async for event in consumer.consume_all(): - # if isinstance(event, Message): - # self._current_task_or_message = event - # break - # yield await self.task_manager.process(event) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 2b5d2fe5..6d838d98 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -91,11 +91,10 @@ def validate(expression, error_message=None): def decorator(function): def wrapper(self, *args, **kwargs): if not expression(self): - if not error_message: - message = str(expression) - logger.error(f'Unsupported Operation: {error_message}') + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') raise ServerError( - UnsupportedOperationError(message=error_message) + UnsupportedOperationError(message=final_message) ) return function(self, *args, **kwargs) diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index c482fcf7..8adb41db 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -1,49 +1,53 @@ import unittest import unittest.async_case -from unittest.mock import AsyncMock, patch, MagicMock + +from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from a2a.server.events.event_queue import EventQueue + from a2a.server.agent_execution import AgentExecutor -from a2a.utils.errors import ServerError +from a2a.server.events import ( + QueueManager, +) +from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import ( DefaultRequestHandler, JSONRPCHandler, ) -from a2a.server.events import ( - QueueManager, -) from a2a.server.tasks import TaskStore from a2a.types import ( - AgentCard, AgentCapabilities, + AgentCard, + Artifact, + CancelTaskRequest, + CancelTaskSuccessResponse, GetTaskRequest, GetTaskResponse, GetTaskSuccessResponse, - Task, - TaskQueryParams, JSONRPCErrorResponse, - TaskNotFoundError, - TaskIdParams, - CancelTaskRequest, - CancelTaskSuccessResponse, - UnsupportedOperationError, - SendMessageRequest, Message, MessageSendParams, + Part, + SendMessageRequest, SendMessageSuccessResponse, SendStreamingMessageRequest, SendStreamingMessageSuccessResponse, + Task, TaskArtifactUpdateEvent, + TaskIdParams, + TaskNotFoundError, + TaskQueryParams, + TaskResubscriptionRequest, + TaskState, + TaskStatus, TaskStatusUpdateEvent, - Artifact, - Part, TextPart, - TaskStatus, - TaskState, - TaskResubscriptionRequest, + UnsupportedOperationError, ) -from collections.abc import AsyncGenerator -from typing import Any +from a2a.utils.errors import ServerError + MINIMAL_TASK: dict[str, Any] = { 'id': 'task_123', @@ -316,7 +320,7 @@ async def streaming_coro(): assert isinstance( event.root, SendStreamingMessageSuccessResponse ) - assert collected_events[i].root.result == events[i] + assert event.root.result == events[i] mock_agent_executor.execute.assert_called_once() async def test_on_message_stream_new_message_existing_task_success( @@ -387,7 +391,7 @@ async def test_on_resubscribe_existing_task_success( request_handler = DefaultRequestHandler( mock_agent_executor, mock_task_store, mock_queue_manager ) - mock_agent_card = MagicMock(spec=AgentCard) + self.mock_agent_card = MagicMock(spec=AgentCard) handler = JSONRPCHandler(self.mock_agent_card, request_handler) mock_task = Task(**MINIMAL_TASK, history=[]) events: list[Any] = [