diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index e69de29b..d28188e3 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -0,0 +1,12 @@ +AError +ARequest +AStarlette +adk +genai +inmemory +langgraph +lifecycles +oauthoidc +opensource +socio +sse diff --git a/.github/linters/.jscpd.json b/.github/linters/.jscpd.json index 0d12b3c1..5e86d6d8 100644 --- a/.github/linters/.jscpd.json +++ b/.github/linters/.jscpd.json @@ -1,8 +1,5 @@ { - "ignore": [ - "**/.github/**", - "**/.git/**" - ], + "ignore": ["**/.github/**", "**/.git/**", "**/tests/**"], "threshold": 3, "reporters": ["html", "markdown"] } diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index d4965995..152340bc 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -14,9 +14,6 @@ name: Lint Code Base # Start the job on all push # ############################# on: - push: - branches-ignore: [main] - # Remove the line above to run when pushing to main pull_request: branches: [main] @@ -31,7 +28,7 @@ jobs: runs-on: ubuntu-latest # if on repo to avoid failing runs on forks if: | - github.repository == 'google/A2A' + github.repository == 'google/a2a-python' ################## # Load all steps # diff --git a/.github/workflows/spelling.yaml b/.github/workflows/spelling.yaml index 131d87b1..72516608 100644 --- a/.github/workflows/spelling.yaml +++ b/.github/workflows/spelling.yaml @@ -1,67 +1,6 @@ name: Check Spelling -# Comment management is handled through a secondary job, for details see: -# https://github.com/check-spelling/check-spelling/wiki/Feature%3A-Restricted-Permissions -# -# `jobs.comment-push` runs when a push is made to a repository and the `jobs.spelling` job needs to make a comment -# (in odd cases, it might actually run just to collapse a comment, but that's fairly rare) -# it needs `contents: write` in order to add a comment. -# -# `jobs.comment-pr` runs when a pull_request is made to a repository and the `jobs.spelling` job needs to make a comment -# or collapse a comment (in the case where it had previously made a comment and now no longer needs to show a comment) -# it needs `pull-requests: write` in order to manipulate those comments. - -# Updating pull request branches is managed via comment handling. -# For details, see: https://github.com/check-spelling/check-spelling/wiki/Feature:-Update-expect-list -# -# These elements work together to make it happen: -# -# `on.issue_comment` -# This event listens to comments by users asking to update the metadata. -# -# `jobs.update` -# This job runs in response to an issue_comment and will push a new commit -# to update the spelling metadata. -# -# `with.experimental_apply_changes_via_bot` -# Tells the action to support and generate messages that enable it -# to make a commit to update the spelling metadata. -# -# `with.ssh_key` -# In order to trigger workflows when the commit is made, you can provide a -# secret (typically, a write-enabled github deploy key). -# -# For background, see: https://github.com/check-spelling/check-spelling/wiki/Feature:-Update-with-deploy-key - -# Sarif reporting -# -# Access to Sarif reports is generally restricted (by GitHub) to members of the repository. -# -# Requires enabling `security-events: write` -# and configuring the action with `use_sarif: 1` -# -# For information on the feature, see: https://github.com/check-spelling/check-spelling/wiki/Feature:-Sarif-output - -# Minimal workflow structure: -# -# on: -# push: -# ... -# pull_request_target: -# ... -# jobs: -# # you only want the spelling job, all others should be omitted -# spelling: -# # remove `security-events: write` and `use_sarif: 1` -# # remove `experimental_apply_changes_via_bot: 1` -# ... otherwise adjust the `with:` as you wish - on: - push: - branches: - - "**" - tags-ignore: - - "**" pull_request: branches: - "**" @@ -85,7 +24,7 @@ jobs: runs-on: ubuntu-latest # if on repo to avoid failing runs on forks if: | - github.repository == 'google/A2A' + github.repository == 'google/a2a-python' && (contains(github.event_name, 'pull_request') || github.event_name == 'push') concurrency: group: spelling-${{ github.event.pull_request.number || github.ref }} @@ -141,6 +80,6 @@ jobs: cspell:sql/src/tsql.txt cspell:terraform/dict/terraform.txt cspell:typescript/dict/typescript.txt - check_extra_dictionaries: '' + check_extra_dictionaries: "" only_check_changed_files: true longest_word: "10" diff --git a/.vscode/launch.json b/.vscode/launch.json index d78cdb25..37651238 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,41 +1,31 @@ { - "version": "0.2.0", - "configurations": [ - { - "name": "Debug HelloWorld Agent", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder}/examples/helloworld/__main__.py", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}", - }, - "cwd": "${workspaceFolder}/examples/helloworld", - "args": [ - "--host", - "localhost", - "--port", - "9999" - ] - }, - { - "name": "Debug Currency Agent", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder}/examples/langgraph/__main__.py", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}", - }, - "cwd": "${workspaceFolder}/examples/langgraph", - "args": [ - "--host", - "localhost", - "--port", - "10000" - ] - } - ] -} \ No newline at end of file + "version": "0.2.0", + "configurations": [ + { + "name": "Debug HelloWorld Agent", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/examples/helloworld/__main__.py", + "console": "integratedTerminal", + "justMyCode": false, + "env": { + "PYTHONPATH": "${workspaceFolder}" + }, + "cwd": "${workspaceFolder}/examples/helloworld", + "args": ["--host", "localhost", "--port", "9999"] + }, + { + "name": "Debug Currency Agent", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/examples/langgraph/__main__.py", + "console": "integratedTerminal", + "justMyCode": false, + "env": { + "PYTHONPATH": "${workspaceFolder}" + }, + "cwd": "${workspaceFolder}/examples/langgraph", + "args": ["--host", "localhost", "--port", "10000"] + } + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json index 54eb8c58..3ffee4e7 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,16 +1,14 @@ { - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "editor.formatOnSave": true, + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", "editor.formatOnSave": true, - "[python]": { - "editor.defaultFormatter": "charliermarsh.ruff", - "editor.formatOnSave": true, - "editor.codeActionsOnSave": { - "source.organizeImports": "always" - }, - }, - "ruff.importStrategy": "fromEnvironment" -} \ No newline at end of file + "editor.codeActionsOnSave": { + "source.organizeImports": "always" + } + }, + "ruff.importStrategy": "fromEnvironment" +} diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index d5f2898f..df234019 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,4 +1,3 @@ - # Code of Conduct ## Our Pledge @@ -15,22 +14,22 @@ race, religion, or sexual identity and orientation. Examples of behavior that contributes to creating a positive environment include: -* Using welcoming and inclusive language -* Being respectful of differing viewpoints and experiences -* Gracefully accepting constructive criticism -* Focusing on what is best for the community -* Showing empathy towards other community members +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members Examples of unacceptable behavior by participants include: -* The use of sexualized language or imagery and unwelcome sexual attention or - advances -* Trolling, insulting/derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or electronic - address, without explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting +- The use of sexualized language or imagery and unwelcome sexual attention or + advances +- Trolling, insulting/derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or electronic + address, without explicit permission +- Other conduct which could reasonably be considered inappropriate in a + professional setting ## Our Responsibilities @@ -61,7 +60,7 @@ negative impact on the project or its community. We do not believe that all conflict is bad; healthy debate and disagreement often yield positive results. However, it is never okay to be disrespectful or -to engage in behavior that violates the project’s code of conduct. +to engage in behavior that violates the project's code of conduct. If you see someone violating the code of conduct, you are encouraged to address the behavior directly with those involved. Many issues can be resolved quickly @@ -70,8 +69,8 @@ dispute. If you are unable to resolve the matter for any reason, or if the behavior is threatening or harassing, report it. We are dedicated to providing an environment where participants feel welcome and safe. -Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the -Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward’s duty to +Reports should be directed to _[PROJECT STEWARD NAME(s) AND EMAIL(s)]_, the +Project Steward(s) for _[PROJECT NAME]_. It is the Project Steward's duty to receive and address reported violations of the code of conduct. They will then work with a committee consisting of representatives from the Open Source Programs Office and the Google Open Source Strategy team. If for any reason you diff --git a/examples/google_adk/calendar_agent/README.md b/examples/google_adk/calendar_agent/README.md index 664c2c77..7e2f8e7b 100644 --- a/examples/google_adk/calendar_agent/README.md +++ b/examples/google_adk/calendar_agent/README.md @@ -8,7 +8,7 @@ This example shows how to create an A2A Server that uses an ADK-based Agent that - [UV](https://docs.astral.sh/uv/) - A Gemini API Key - A [Google OAuth Client](https://developers.google.com/identity/openid-connect/openid-connect#getcredentials) - - Configure your OAuth client to handle redirect URLs at `localhost:10007/authenticate` + - Configure your OAuth client to handle redirect URLs at `localhost:10007/authenticate` ## Running the example @@ -24,4 +24,4 @@ echo "GOOGLE_CLIENT_SECRET=your_client_secret_here" >> .env ``` uv run . -``` \ No newline at end of file +``` diff --git a/examples/google_adk/calendar_agent/__main__.py b/examples/google_adk/calendar_agent/__main__.py index a62ba813..823a14aa 100644 --- a/examples/google_adk/calendar_agent/__main__.py +++ b/examples/google_adk/calendar_agent/__main__.py @@ -1,31 +1,31 @@ import logging import os -import sys import click import uvicorn -from adk_agent_executor import ADKAgentExecutor from adk_agent import create_agent +from adk_agent_executor import ADKAgentExecutor from dotenv import load_dotenv +from google.adk.artifacts import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.runners import Runner +from google.adk.sessions import InMemorySessionService +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import PlainTextResponse +from starlette.routing import Route -from a2a.server.tasks import InMemoryTaskStore from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore from a2a.types import ( AgentAuthentication, AgentCapabilities, AgentCard, AgentSkill, ) -from google.adk.artifacts import InMemoryArtifactService -from google.adk.memory.in_memory_memory_service import InMemoryMemoryService -from google.adk.runners import Runner -from google.adk.sessions import InMemorySessionService -from starlette.applications import Starlette -from starlette.routing import Route -from starlette.requests import Request -from starlette.responses import PlainTextResponse + load_dotenv() @@ -36,12 +36,15 @@ @click.option('--host', 'host', default='localhost') @click.option('--port', 'port', default=10007) def main(host: str, port: int): - # Verify an API key is set. Not required if using Vertex AI APIs, since those can use gcloud credentials. - if not os.getenv('GOOGLE_GENAI_USE_VERTEXAI') == 'TRUE': - if not os.getenv('GOOGLE_API_KEY'): - raise Exception( - 'GOOGLE_API_KEY environment variable not set and GOOGLE_GENAI_USE_VERTEXAI is not TRUE.' - ) + # Verify an API key is set. + # Not required if using Vertex AI APIs. + if os.getenv('GOOGLE_GENAI_USE_VERTEXAI') != 'TRUE' and not os.getenv( + 'GOOGLE_API_KEY' + ): + raise ValueError( + 'GOOGLE_API_KEY environment variable not set and ' + 'GOOGLE_GENAI_USE_VERTEXAI is not TRUE.' + ) skill = AgentSkill( id='check_availability', diff --git a/examples/google_adk/calendar_agent/adk_agent_executor.py b/examples/google_adk/calendar_agent/adk_agent_executor.py index 53ff677a..9182227d 100644 --- a/examples/google_adk/calendar_agent/adk_agent_executor.py +++ b/examples/google_adk/calendar_agent/adk_agent_executor.py @@ -1,8 +1,8 @@ import asyncio import logging + from collections import namedtuple from collections.abc import AsyncGenerator -from typing import Dict from urllib.parse import parse_qs, urlparse from google.adk import Runner @@ -27,6 +27,7 @@ from a2a.utils.errors import ServerError from a2a.utils.message import new_agent_text_message + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -42,7 +43,7 @@ class ADKAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK-based Agent.""" - _awaiting_auth: Dict[str, asyncio.Future] + _awaiting_auth: dict[str, asyncio.Future] def __init__(self, runner: Runner, card: AgentCard): self.runner = runner @@ -62,7 +63,7 @@ async def _process_request( new_message: types.Content, session_id: str, task_updater: TaskUpdater, - ): + ) -> None: session_id = self._upsert_session( session_id, ).id @@ -95,13 +96,13 @@ async def _process_request( # Break out of event handling loop -- no more work will be done until the authorization # is received. break - elif event.is_final_response(): + if event.is_final_response(): parts = convert_genai_parts_to_a2a(event.content.parts) logger.debug('Yielding final response: %s', parts) task_updater.add_artifact(parts) task_updater.complete() break - elif not event.get_function_calls(): + if not event.get_function_calls(): logger.debug('Yielding update response') task_updater.update_status( TaskState.working, @@ -159,7 +160,7 @@ async def _complete_auth_processing( session_id: str, auth_details: ADKAuthDetails, task_updater: TaskUpdater, - ): + ) -> None: logger.debug('Waiting for auth event') try: auth_uri = await asyncio.wait_for( @@ -236,36 +237,34 @@ def _upsert_session(self, session_id: str): def convert_a2a_parts_to_genai(parts: list[Part]) -> list[types.Part]: - """Convert a list of A2A Part types into a list of Google GenAI Part types.""" + """Convert a list of A2A Part types into a list of Google Gen AI Part types.""" return [convert_a2a_part_to_genai(part) for part in parts] def convert_a2a_part_to_genai(part: Part) -> types.Part: - """Convert a single A2A Part type into a Google GenAI Part type.""" + """Convert a single A2A Part type into a Google Gen AI Part type.""" part = part.root if isinstance(part, TextPart): return types.Part(text=part.text) - elif isinstance(part, FilePart): + if isinstance(part, FilePart): if isinstance(part.file, FileWithUri): return types.Part( file_data=types.FileData( file_uri=part.file.uri, mime_type=part.file.mime_type ) ) - elif isinstance(part.file, FileWithBytes): + if isinstance(part.file, FileWithBytes): return types.Part( inline_data=types.Blob( data=part.file.bytes, mime_type=part.file.mime_type ) ) - else: - raise ValueError(f'Unsupported file type: {type(part.file)}') - else: - raise ValueError(f'Unsupported part type: {type(part)}') + raise ValueError(f'Unsupported file type: {type(part.file)}') + raise ValueError(f'Unsupported part type: {type(part)}') def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]: - """Convert a list of Google GenAI Part types into a list of A2A Part types.""" + """Convert a list of Google Gen AI Part types into a list of A2A Part types.""" return [ convert_genai_part_to_a2a(part) for part in parts @@ -274,17 +273,17 @@ def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]: def convert_genai_part_to_a2a(part: types.Part) -> Part: - """Convert a single Google GenAI Part type into an A2A Part type.""" + """Convert a single Google Gen AI Part type into an A2A Part type.""" if part.text: return TextPart(text=part.text) - elif part.file_data: + if part.file_data: return FilePart( file=FileWithUri( uri=part.file_data.file_uri, mime_type=part.file_data.mime_type, ) ) - elif part.inline_data: + if part.inline_data: return Part( root=FilePart( file=FileWithBytes( @@ -293,14 +292,13 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part: ) ) ) - else: - raise ValueError(f'Unsupported part type: {part}') + raise ValueError(f'Unsupported part type: {part}') def get_auth_request_function_call(event: Event) -> types.FunctionCall: - """Get the special auth request function call from the event""" + """Get the special auth request function call from the event.""" if not (event.content and event.content.parts): - return + return None for part in event.content.parts: if ( part @@ -310,12 +308,13 @@ def get_auth_request_function_call(event: Event) -> types.FunctionCall: and part.function_call.id in event.long_running_tool_ids ): return part.function_call + return None def get_auth_config( auth_request_function_call: types.FunctionCall, ) -> AuthConfig: - """Extracts the AuthConfig object from the arguments of the auth request function call""" + """Extracts the AuthConfig object from the arguments of the auth request function call.""" if not auth_request_function_call.args or not ( auth_config := auth_request_function_call.args.get('auth_config') ): diff --git a/examples/helloworld/README.md b/examples/helloworld/README.md index 96aece24..1e82c148 100644 --- a/examples/helloworld/README.md +++ b/examples/helloworld/README.md @@ -1,13 +1,17 @@ +# Hello World Example + Hello World example agent that only returns Message events ## Getting started 1. Start the server - ```bash - uv run . - ``` - -4. Run the test client - ```bash - uv run test_client.py - ``` \ No newline at end of file + + ```bash + uv run . + ``` + +2. Run the test client + + ```bash + uv run test_client.py + ``` diff --git a/examples/helloworld/agent_executor.py b/examples/helloworld/agent_executor.py index 92b9b72e..ed13dbc7 100644 --- a/examples/helloworld/agent_executor.py +++ b/examples/helloworld/agent_executor.py @@ -2,9 +2,6 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue -from a2a.types import ( - Task, -) from a2a.utils import new_agent_text_message diff --git a/examples/langgraph/README.md b/examples/langgraph/README.md index 3b46568f..f66e9808 100644 --- a/examples/langgraph/README.md +++ b/examples/langgraph/README.md @@ -1,20 +1,25 @@ -An example langgraph agent that helps with currency conversion. +# LangGraph example + +An example LangGraph agent that helps with currency conversion. ## Getting started 1. Extract the zip file and cd to examples folder 2. Create an environment file with your API key: + ```bash echo "GOOGLE_API_KEY=your_api_key_here" > .env ``` 3. Start the server - ```bash - uv run main.py - ``` + + ```bash + uv run main.py + ``` 4. Run the test client - ```bash - uv run test_client.py - ``` \ No newline at end of file + + ```bash + uv run test_client.py + ``` diff --git a/src/a2a/server/agent_execution/__init__.py b/src/a2a/server/agent_execution/__init__.py index 7638fc02..88660d62 100644 --- a/src/a2a/server/agent_execution/__init__.py +++ b/src/a2a/server/agent_execution/__init__.py @@ -1,4 +1,5 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext + __all__ = ['AgentExecutor', 'RequestContext'] diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 986e5ce3..3c61dc6b 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -5,7 +5,6 @@ Message, MessageSendParams, Task, - TextPart, ) from a2a.utils import get_message_text from a2a.utils.errors import ServerError @@ -20,8 +19,10 @@ def __init__( task_id: str | None = None, context_id: str | None = None, task: Task | None = None, - related_tasks: list[Task] = [], + related_tasks: list[Task] = None, ): + if related_tasks is None: + related_tasks = [] self._params = request self._task_id = task_id self._context_id = context_id diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py index 2f984cc1..76b6e465 100644 --- a/src/a2a/server/apps/__init__.py +++ b/src/a2a/server/apps/__init__.py @@ -2,4 +2,4 @@ from a2a.server.apps.starlette_app import A2AStarletteApplication -__all__ = ['HttpApp', 'A2AStarletteApplication'] +__all__ = ['A2AStarletteApplication', 'HttpApp'] diff --git a/src/a2a/server/apps/starlette_app.py b/src/a2a/server/apps/starlette_app.py index 6f6a6a18..5bbdcced 100644 --- a/src/a2a/server/apps/starlette_app.py +++ b/src/a2a/server/apps/starlette_app.py @@ -1,9 +1,17 @@ -from collections.abc import AsyncGenerator import json import logging import traceback + +from collections.abc import AsyncGenerator from typing import Any +from pydantic import ValidationError +from sse_starlette.sse import EventSourceResponse +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route + from a2a.server.request_handlers.jsonrpc_handler import ( JSONRPCHandler, RequestHandler, @@ -29,12 +37,6 @@ UnsupportedOperationError, ) from a2a.utils.errors import MethodNotImplementedError -from pydantic import ValidationError -from sse_starlette.sse import EventSourceResponse -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import JSONResponse, Response -from starlette.routing import Route logger = logging.getLogger(__name__) @@ -116,7 +118,7 @@ async def _handle_requests(self, request: Request) -> Response: return await self._process_non_streaming_request( request_id, a2a_request ) - except MethodNotImplementedError as e: + except MethodNotImplementedError: traceback.print_exc() return self._generate_error_response( request_id, A2AError(root=UnsupportedOperationError()) diff --git a/src/a2a/server/events/__init__.py b/src/a2a/server/events/__init__.py index 61d87068..d9662d04 100644 --- a/src/a2a/server/events/__init__.py +++ b/src/a2a/server/events/__init__.py @@ -1,14 +1,19 @@ from a2a.server.events.event_consumer import EventConsumer from a2a.server.events.event_queue import Event, EventQueue -from a2a.server.events.queue_manager import QueueManager, TaskQueueExists, NoTaskQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager +from a2a.server.events.queue_manager import ( + NoTaskQueue, + QueueManager, + TaskQueueExists, +) + __all__ = [ 'Event', 'EventConsumer', 'EventQueue', + 'InMemoryQueueManager', + 'NoTaskQueue', 'QueueManager', 'TaskQueueExists', - 'NoTaskQueue', - 'InMemoryQueueManager', ] diff --git a/src/a2a/server/events/in_memory_queue_manager.py b/src/a2a/server/events/in_memory_queue_manager.py index 586f4ea9..a0d95f8e 100644 --- a/src/a2a/server/events/in_memory_queue_manager.py +++ b/src/a2a/server/events/in_memory_queue_manager.py @@ -1,6 +1,12 @@ import asyncio + from a2a.server.events.event_queue import EventQueue -from a2a.server.events.queue_manager import QueueManager, TaskQueueExists, NoTaskQueue +from a2a.server.events.queue_manager import ( + NoTaskQueue, + QueueManager, + TaskQueueExists, +) + class InMemoryQueueManager(QueueManager): """InMemoryQueueManager is used for a single binary management. diff --git a/src/a2a/server/events/queue_manager.py b/src/a2a/server/events/queue_manager.py index 5dae47c2..bef273f3 100644 --- a/src/a2a/server/events/queue_manager.py +++ b/src/a2a/server/events/queue_manager.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod + from a2a.server.events.event_queue import EventQueue @@ -29,5 +30,6 @@ async def create_or_tap(self, task_id: str) -> EventQueue: class TaskQueueExists(Exception): pass + class NoTaskQueue(Exception): pass diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 6b271720..9b3851ce 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -1,8 +1,8 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) -from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler +from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, prepare_response_object, @@ -10,9 +10,9 @@ __all__ = [ - 'RequestHandler', 'DefaultRequestHandler', 'JSONRPCHandler', + 'RequestHandler', 'build_error_response', 'prepare_response_object', ] diff --git a/src/a2a/server/tasks/__init__.py b/src/a2a/server/tasks/__init__.py index abdd4bbb..d61df11f 100644 --- a/src/a2a/server/tasks/__init__.py +++ b/src/a2a/server/tasks/__init__.py @@ -4,4 +4,11 @@ from a2a.server.tasks.task_store import TaskStore from a2a.server.tasks.task_updater import TaskUpdater -__all__ = ['InMemoryTaskStore', 'TaskManager', 'TaskStore', 'ResultAggregator', 'TaskUpdater'] + +__all__ = [ + 'InMemoryTaskStore', + 'ResultAggregator', + 'TaskManager', + 'TaskStore', + 'TaskUpdater', +] diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index a8d4fbd2..a9978bce 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -1,18 +1,18 @@ import logging -from a2a.server.tasks.task_store import TaskStore from a2a.server.events.event_queue import Event -from a2a.utils.errors import ServerError +from a2a.server.tasks.task_store import TaskStore from a2a.types import ( + InvalidParamsError, Message, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus, TaskStatusUpdateEvent, - InvalidParamsError, ) from a2a.utils import append_artifact_to_task +from a2a.utils.errors import ServerError logger = logging.getLogger(__name__) @@ -70,7 +70,7 @@ async def save_task_event( message=f"Task in event doesn't match TaskManager {self.task_id} : {task_id_from_event}" ) ) - elif not self.task_id: + if not self.task_id: self.task_id = task_id_from_event if not self.context_id and self.context_id != event.contextId: self.context_id = event.contextId diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index 5c15d0e4..39751c92 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -1,8 +1,16 @@ import uuid from a2a.server.events import EventQueue -from a2a.types import (Artifact, Message, Part, Role, TaskArtifactUpdateEvent, - TaskState, TaskStatus, TaskStatusUpdateEvent) +from a2a.types import ( + Artifact, + Message, + Part, + Role, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) class TaskUpdater: diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index dfc4d09d..42e5d37e 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -11,6 +11,7 @@ ) from a2a.utils.task import new_task + __all__ = [ 'append_artifact_to_task', 'build_text_artifact', diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index c12f29c7..cd0da7e4 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -1,8 +1,6 @@ import uuid -from a2a.types import (Artifact, Message, Part, Role, Task, - TaskArtifactUpdateEvent, TaskState, TaskStatus, - TaskStatusUpdateEvent) +from a2a.types import Message, Task, TaskState, TaskStatus def new_task(request: Message) -> Task: diff --git a/tests/server/events/test_event_consumer.py b/tests/server/events/test_event_consumer.py index 9ee29f92..56db30a1 100644 --- a/tests/server/events/test_event_consumer.py +++ b/tests/server/events/test_event_consumer.py @@ -1,25 +1,27 @@ import asyncio + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + import pytest -from unittest.mock import AsyncMock, MagicMock, patch + from a2a.server.events.event_consumer import EventConsumer from a2a.server.events.event_queue import EventQueue -from a2a.server.tasks.task_manager import TaskManager -from a2a.utils.errors import ServerError from a2a.types import ( A2AError, Artifact, InternalError, + JSONRPCError, Message, + Part, Task, - TaskStatusUpdateEvent, TaskArtifactUpdateEvent, - TaskStatus, - JSONRPCError, TaskState, - Part, + TaskStatus, + TaskStatusUpdateEvent, TextPart, ) -from typing import Any +from a2a.utils.errors import ServerError MINIMAL_TASK: dict[str, Any] = {