diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt new file mode 100644 index 00000000..9e923c35 --- /dev/null +++ b/.github/actions/spelling/expect.txt @@ -0,0 +1,4 @@ +excinfo +GVsb +notif +otherurl diff --git a/examples/google_adk/birthday_planner/__main__.py b/examples/google_adk/birthday_planner/__main__.py index 02005e92..32a9cf12 100644 --- a/examples/google_adk/birthday_planner/__main__.py +++ b/examples/google_adk/birthday_planner/__main__.py @@ -6,18 +6,13 @@ import click import uvicorn -from adk_agent_executor import ADKAgentExecutor # type: ignore[import-untyped] +from adk_agent_executor import ADKAgentExecutor # type: ignore[import-untyped] from dotenv import load_dotenv from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore -from a2a.types import ( - AgentAuthentication, - AgentCapabilities, - AgentCard, - AgentSkill, -) +from a2a.types import AgentCapabilities, AgentCard, AgentSkill load_dotenv() @@ -69,7 +64,6 @@ def main(host: str, port: int, calendar_agent: str): defaultOutputModes=['text'], capabilities=AgentCapabilities(streaming=True), skills=[skill], - authentication=AgentAuthentication(schemes=['public']), ) request_handler = DefaultRequestHandler( agent_executor=agent_executor, task_store=InMemoryTaskStore() diff --git a/examples/google_adk/calendar_agent/__main__.py b/examples/google_adk/calendar_agent/__main__.py index 82bc2044..448e7ba5 100644 --- a/examples/google_adk/calendar_agent/__main__.py +++ b/examples/google_adk/calendar_agent/__main__.py @@ -4,13 +4,19 @@ import click import uvicorn -from adk_agent import create_agent # type: ignore[import-not-found] -from adk_agent_executor import ADKAgentExecutor # type: ignore[import-untyped] +from adk_agent import create_agent # type: ignore[import-not-found] +from adk_agent_executor import ADKAgentExecutor # type: ignore[import-untyped] from dotenv import load_dotenv -from google.adk.artifacts import InMemoryArtifactService # type: ignore[import-untyped] -from google.adk.memory.in_memory_memory_service import InMemoryMemoryService # type: ignore[import-untyped] -from google.adk.runners import Runner # type: ignore[import-untyped] -from google.adk.sessions import InMemorySessionService # type: ignore[import-untyped] +from google.adk.artifacts import ( + InMemoryArtifactService, # type: ignore[import-untyped] +) +from google.adk.memory.in_memory_memory_service import ( + InMemoryMemoryService, # type: ignore[import-untyped] +) +from google.adk.runners import Runner # type: ignore[import-untyped] +from google.adk.sessions import ( + InMemorySessionService, # type: ignore[import-untyped] +) from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import PlainTextResponse @@ -19,12 +25,7 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore -from a2a.types import ( - AgentAuthentication, - AgentCapabilities, - AgentCard, - AgentSkill, -) +from a2a.types import AgentCapabilities, AgentCard, AgentSkill load_dotenv() @@ -63,7 +64,6 @@ def main(host: str, port: int): defaultOutputModes=['text'], capabilities=AgentCapabilities(streaming=True), skills=[skill], - authentication=AgentAuthentication(schemes=['public']), ) adk_agent = create_agent( diff --git a/examples/helloworld/__main__.py b/examples/helloworld/__main__.py index cc386ebd..dfd9818f 100644 --- a/examples/helloworld/__main__.py +++ b/examples/helloworld/__main__.py @@ -1,14 +1,11 @@ -from agent_executor import HelloWorldAgentExecutor # type: ignore[import-untyped] +from agent_executor import ( + HelloWorldAgentExecutor, # type: ignore[import-untyped] +) from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore -from a2a.types import ( - AgentAuthentication, - AgentCapabilities, - AgentCard, - AgentSkill, -) +from a2a.types import AgentCapabilities, AgentCard, AgentSkill if __name__ == '__main__': @@ -29,7 +26,6 @@ defaultOutputModes=['text'], capabilities=AgentCapabilities(streaming=True), skills=[skill], - authentication=AgentAuthentication(schemes=['public']), ) request_handler = DefaultRequestHandler( diff --git a/examples/langgraph/__main__.py b/examples/langgraph/__main__.py index 783bd8e3..52e49fdd 100644 --- a/examples/langgraph/__main__.py +++ b/examples/langgraph/__main__.py @@ -4,19 +4,14 @@ import click import httpx -from agent import CurrencyAgent # type: ignore[import-untyped] -from agent_executor import CurrencyAgentExecutor # type: ignore[import-untyped] +from agent import CurrencyAgent # type: ignore[import-untyped] +from agent_executor import CurrencyAgentExecutor # type: ignore[import-untyped] from dotenv import load_dotenv from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore -from a2a.types import ( - AgentAuthentication, - AgentCapabilities, - AgentCard, - AgentSkill, -) +from a2a.types import AgentCapabilities, AgentCard, AgentSkill load_dotenv() @@ -64,7 +59,6 @@ def get_agent_card(host: str, port: int): defaultOutputModes=CurrencyAgent.SUPPORTED_CONTENT_TYPES, capabilities=capabilities, skills=[skill], - authentication=AgentAuthentication(schemes=['public']), ) diff --git a/src/a2a/types.py b/src/a2a/types.py index 3587d441..a171eabc 100644 --- a/src/a2a/types.py +++ b/src/a2a/types.py @@ -6,27 +6,41 @@ from enum import Enum from typing import Any, Literal -from pydantic import BaseModel, RootModel +from pydantic import BaseModel, Field, RootModel class A2A(RootModel[Any]): root: Any -class AgentAuthentication(BaseModel): +class In(Enum): """ - Defines authentication requirements for an agent. - Intended to match OpenAPI authentication structure. + The location of the API key. Valid values are "query", "header", or "cookie". """ - credentials: str | None = None + cookie = 'cookie' + header = 'header' + query = 'query' + + +class APIKeySecurityScheme(BaseModel): """ - credentials a client should use for private cards + API Key security scheme. """ - schemes: list[str] + + description: str | None = None + """ + description of this security scheme """ - e.g. Basic, Bearer + in_: In = Field(..., alias='in') """ + The location of the API key. Valid values are "query", "header", or "cookie". + """ + name: str + """ + The name of the header, query or cookie parameter to be used. + """ + type: Literal['apiKey'] = 'apiKey' class AgentCapabilities(BaseModel): @@ -193,6 +207,30 @@ class FileWithUri(BaseModel): uri: str +class HTTPAuthSecurityScheme(BaseModel): + """ + HTTP Authentication security scheme. + """ + + bearerFormat: str | None = None + """ + A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually + generated by an authorization server, so this information is primarily for documentation + purposes. + """ + description: str | None = None + """ + description of this security scheme + """ + scheme: str + """ + The name of the HTTP Authentication scheme to be used in the Authorization header as defined + in RFC7235. The values used SHOULD be registered in the IANA Authentication Scheme registry. + The value is case-insensitive, as defined in RFC7235. + """ + type: Literal['http'] = 'http' + + class InternalError(BaseModel): """ JSON-RPC error indicating an internal JSON-RPC error on the server. @@ -403,6 +441,72 @@ class MethodNotFoundError(BaseModel): """ +class OAuthFlow(BaseModel): + """ + Configuration details for a supported OAuth Flow + """ + + authorizationUrl: str + """ + The authorization URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 + standard requires the use of TLS + """ + refreshUrl: str + """ + The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 + standard requires the use of TLS. + """ + scopes: dict[str, str] + """ + The available scopes for the OAuth2 security scheme. A map between the scope name and a short + description for it. The map MAY be empty. + """ + tokenUrl: str + """ + The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard + requires the use of TLS. + """ + + +class OAuthFlows(BaseModel): + """ + Allows configuration of the supported OAuth Flows + """ + + authorizationCode: OAuthFlow | None = None + """ + Configuration for the OAuth Authorization Code flow. Previously called accessCode in OpenAPI 2.0. + """ + clientCredentials: OAuthFlow | None = None + """ + Configuration for the OAuth Client Credentials flow. Previously called application in OpenAPI 2.0 + """ + implicit: OAuthFlow | None = None + """ + Configuration for the OAuth Implicit flow + """ + password: OAuthFlow | None = None + """ + Configuration for the OAuth Resource Owner Password flow + """ + + +class OpenIdConnectSecurityScheme(BaseModel): + """ + OpenID Connect security scheme configuration. + """ + + description: str | None = None + """ + description of this security scheme + """ + openIdConnectUrl: str + """ + Well-known URL to discover the [[OpenID-Connect-Discovery]] provider metadata. + """ + type: Literal['openIdConnect'] = 'openIdConnect' + + class PartBase(BaseModel): """ Base properties common to all message parts. @@ -465,6 +569,17 @@ class PushNotificationNotSupportedError(BaseModel): """ +class SecuritySchemeBase(BaseModel): + """ + Base properties shared by all security schemes. + """ + + description: str | None = None + """ + description of this security scheme + """ + + class TaskIdParams(BaseModel): """ Parameters containing only a task ID, used for simple task operations. @@ -654,64 +769,6 @@ class A2AError( ) -class AgentCard(BaseModel): - """ - An AgentCard conveys key information: - - Overall details (version, name, description, uses) - - Skills: A set of capabilities the agent can perform - - Default modalities/content types supported by the agent. - - Authentication requirements - """ - - authentication: AgentAuthentication - """ - Authentication requirements for the agent. - """ - capabilities: AgentCapabilities - """ - Optional capabilities supported by the agent. - """ - defaultInputModes: list[str] - """ - The set of interaction modes that the agent - supports across all skills. This can be overridden per-skill. - Supported mime types for input. - """ - defaultOutputModes: list[str] - """ - Supported mime types for output. - """ - description: str - """ - A human-readable description of the agent. Used to assist users and - other agents in understanding what the agent can do. - """ - documentationUrl: str | None = None - """ - A URL to documentation for the agent. - """ - name: str - """ - Human readable name of the agent. - """ - provider: AgentProvider | None = None - """ - The service provider of the agent - """ - skills: list[AgentSkill] - """ - Skills are a unit of capability that an agent can perform. - """ - url: str - """ - A URL to the address the agent is hosted at. - """ - version: str - """ - The version of the agent - format is up to the provider. - """ - - class CancelTaskRequest(BaseModel): """ JSON-RPC request model for the 'tasks/cancel' method. @@ -878,6 +935,22 @@ class MessageSendConfiguration(BaseModel): """ +class OAuth2SecurityScheme(BaseModel): + """ + OAuth2.0 security scheme configuration. + """ + + description: str | None = None + """ + description of this security scheme + """ + flows: OAuthFlows + """ + An object containing configuration information for the flow types supported. + """ + type: Literal['oauth2'] = 'oauth2' + + class Part(RootModel[TextPart | FilePart | DataPart]): root: TextPart | FilePart | DataPart """ @@ -885,6 +958,26 @@ class Part(RootModel[TextPart | FilePart | DataPart]): """ +class SecurityScheme( + RootModel[ + APIKeySecurityScheme + | HTTPAuthSecurityScheme + | OAuth2SecurityScheme + | OpenIdConnectSecurityScheme + ] +): + root: ( + APIKeySecurityScheme + | HTTPAuthSecurityScheme + | OAuth2SecurityScheme + | OpenIdConnectSecurityScheme + ) + """ + Mirrors the OpenAPI Security Scheme Object + (https://swagger.io/specification/#security-scheme-object) + """ + + class SetTaskPushNotificationConfigRequest(BaseModel): """ JSON-RPC request model for the 'tasks/pushNotificationConfig/set' method. @@ -931,6 +1024,68 @@ class SetTaskPushNotificationConfigSuccessResponse(BaseModel): """ +class AgentCard(BaseModel): + """ + An AgentCard conveys key information: + - Overall details (version, name, description, uses) + - Skills: A set of capabilities the agent can perform + - Default modalities/content types supported by the agent. + - Authentication requirements + """ + + capabilities: AgentCapabilities + """ + Optional capabilities supported by the agent. + """ + defaultInputModes: list[str] + """ + The set of interaction modes that the agent + supports across all skills. This can be overridden per-skill. + Supported mime types for input. + """ + defaultOutputModes: list[str] + """ + Supported mime types for output. + """ + description: str + """ + A human-readable description of the agent. Used to assist users and + other agents in understanding what the agent can do. + """ + documentationUrl: str | None = None + """ + A URL to documentation for the agent. + """ + name: str + """ + Human readable name of the agent. + """ + provider: AgentProvider | None = None + """ + The service provider of the agent + """ + security: list[dict[str, list[str]]] | None = None + """ + Security requirements for contacting the agent. + """ + securitySchemes: dict[str, SecurityScheme] | None = None + """ + Security scheme details used for authenticating with this agent. + """ + skills: list[AgentSkill] + """ + Skills are a unit of capability that an agent can perform. + """ + url: str + """ + A URL to the address the agent is hosted at. + """ + version: str + """ + The version of the agent - format is up to the provider. + """ + + class Artifact(BaseModel): """ Represents an artifact generated for a task task. @@ -959,7 +1114,9 @@ class Artifact(BaseModel): class GetTaskPushNotificationConfigResponse( - RootModel[JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse] + RootModel[ + JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse + ] ): root: JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse """ @@ -1070,7 +1227,9 @@ class SendStreamingMessageRequest(BaseModel): class SetTaskPushNotificationConfigResponse( - RootModel[JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse] + RootModel[ + JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse + ] ): root: JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse """ @@ -1293,7 +1452,9 @@ class SendStreamingMessageSuccessResponse(BaseModel): """ -class CancelTaskResponse(RootModel[JSONRPCErrorResponse | CancelTaskSuccessResponse]): +class CancelTaskResponse( + RootModel[JSONRPCErrorResponse | CancelTaskSuccessResponse] +): root: JSONRPCErrorResponse | CancelTaskSuccessResponse """ JSON-RPC response for the 'tasks/cancel' method. @@ -1332,7 +1493,9 @@ class JSONRPCResponse( """ -class SendMessageResponse(RootModel[JSONRPCErrorResponse | SendMessageSuccessResponse]): +class SendMessageResponse( + RootModel[JSONRPCErrorResponse | SendMessageSuccessResponse] +): root: JSONRPCErrorResponse | SendMessageSuccessResponse """ JSON-RPC response model for the 'message/send' method. diff --git a/tests/client/test_client.py b/tests/client/test_client.py index efb6ca12..7c7926ce 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,9 +1,14 @@ import json -from unittest.mock import AsyncMock, MagicMock, patch + from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + import httpx import pytest +from httpx_sse import EventSource, ServerSentEvent + from a2a.client import ( A2ACardResolver, A2AClient, @@ -12,31 +17,28 @@ create_text_message_object, ) from a2a.types import ( + A2ARequest, + AgentCapabilities, AgentCard, AgentSkill, - AgentCapabilities, - AgentAuthentication, - A2ARequest, - Role, - TaskQueryParams, - TaskIdParams, + CancelTaskRequest, + CancelTaskResponse, + CancelTaskSuccessResponse, GetTaskRequest, GetTaskResponse, - SendMessageRequest, + InvalidParamsError, + JSONRPCErrorResponse, MessageSendParams, + Role, + SendMessageRequest, SendMessageResponse, SendMessageSuccessResponse, - JSONRPCErrorResponse, - InvalidParamsError, - CancelTaskRequest, - CancelTaskResponse, - CancelTaskSuccessResponse, SendStreamingMessageRequest, SendStreamingMessageResponse, + TaskIdParams, TaskNotCancelableError, + TaskQueryParams, ) -from typing import Any -from httpx_sse import ServerSentEvent, EventSource AGENT_CARD = AgentCard( @@ -56,7 +58,6 @@ examples=['hi', 'hello world'], ) ], - authentication=AgentAuthentication(schemes=['public']), ) MINIMAL_TASK: dict[str, Any] = { @@ -86,7 +87,7 @@ def mock_agent_card() -> MagicMock: async def async_iterable_from_list( items: list[ServerSentEvent], -) -> AsyncGenerator[ServerSentEvent, None]: +) -> AsyncGenerator[ServerSentEvent]: """Helper to create an async iterable from a list.""" for item in items: yield item diff --git a/tests/test_types.py b/tests/test_types.py index ef658a19..f34a14d5 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -7,7 +7,7 @@ from a2a.types import ( A2AError, A2ARequest, - AgentAuthentication, + APIKeySecurityScheme, AgentCapabilities, AgentCard, AgentProvider, @@ -15,6 +15,7 @@ Artifact, CancelTaskRequest, CancelTaskResponse, + CancelTaskSuccessResponse, ContentTypeNotSupportedError, DataPart, FileBase, @@ -23,8 +24,11 @@ FileWithUri, GetTaskPushNotificationConfigRequest, GetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigSuccessResponse, GetTaskRequest, GetTaskResponse, + GetTaskSuccessResponse, + In, InternalError, InvalidParamsError, InvalidRequestError, @@ -35,18 +39,25 @@ JSONRPCRequest, JSONRPCResponse, Message, + MessageSendParams, MethodNotFoundError, + OAuth2SecurityScheme, Part, PartBase, PushNotificationAuthenticationInfo, PushNotificationConfig, PushNotificationNotSupportedError, + Role, + SecurityScheme, SendMessageRequest, SendMessageResponse, + SendMessageSuccessResponse, SendStreamingMessageRequest, SendStreamingMessageResponse, + SendStreamingMessageSuccessResponse, SetTaskPushNotificationConfigRequest, SetTaskPushNotificationConfigResponse, + SetTaskPushNotificationConfigSuccessResponse, Task, TaskArtifactUpdateEvent, TaskIdParams, @@ -55,27 +66,20 @@ TaskPushNotificationConfig, TaskQueryParams, TaskResubscriptionRequest, - MessageSendParams, TaskState, TaskStatus, TaskStatusUpdateEvent, TextPart, UnsupportedOperationError, - GetTaskSuccessResponse, - SendStreamingMessageSuccessResponse, - SendMessageSuccessResponse, - CancelTaskSuccessResponse, - Role, - SetTaskPushNotificationConfigSuccessResponse, - GetTaskPushNotificationConfigSuccessResponse, ) + # --- Helper Data --- -MINIMAL_AGENT_AUTH: dict[str, Any] = {'schemes': ['Bearer']} -FULL_AGENT_AUTH: dict[str, Any] = { - 'schemes': ['Bearer', 'Basic'], - 'credentials': 'user:pass', +MINIMAL_AGENT_SECURITY_SCHEME: dict[str, Any] = { + 'type': 'apiKey', + 'in': 'header', + 'name': 'X-API-KEY', } MINIMAL_AGENT_SKILL: dict[str, Any] = { @@ -95,7 +99,6 @@ } MINIMAL_AGENT_CARD: dict[str, Any] = { - 'authentication': MINIMAL_AGENT_AUTH, 'capabilities': {}, # AgentCapabilities is required but can be empty 'defaultInputModes': ['text/plain'], 'defaultOutputModes': ['application/json'], @@ -175,26 +178,23 @@ # --- Test Functions --- -def test_agent_authentication_valid(): - auth = AgentAuthentication(**MINIMAL_AGENT_AUTH) - assert auth.schemes == ['Bearer'] - assert auth.credentials is None - - auth_full = AgentAuthentication(**FULL_AGENT_AUTH) - assert auth_full.schemes == ['Bearer', 'Basic'] - assert auth_full.credentials == 'user:pass' +def test_security_scheme_valid(): + scheme = SecurityScheme.model_validate(MINIMAL_AGENT_SECURITY_SCHEME) + assert isinstance(scheme.root, APIKeySecurityScheme) + assert scheme.root.type == 'apiKey' + assert scheme.root.in_ == In.header + assert scheme.root.name == 'X-API-KEY' -def test_agent_authentication_invalid(): +def test_security_scheme_invalid(): with pytest.raises(ValidationError): - AgentAuthentication( - credentials='only_creds' - ) # Missing schemes # type: ignore + APIKeySecurityScheme( + name='my_api_key', + ) # Missing "in" # type: ignore - AgentAuthentication( - schemes=['Bearer'], - extra_field='extra', # type: ignore - ) # Extra field + OAuth2SecurityScheme( + description='OAuth2 scheme missing flows', + ) # Missing "flows" def test_agent_capabilities(): @@ -251,7 +251,6 @@ def test_agent_card_valid(): card = AgentCard(**MINIMAL_AGENT_CARD) assert card.name == 'TestAgent' assert card.version == '1.0' - assert card.authentication.schemes == ['Bearer'] assert len(card.skills) == 1 assert card.skills[0].id == 'skill-123' assert card.provider is None # Optional