Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
version: v2
inputs:
- git_repo: https://github.com/a2aproject/A2A.git
ref: main
ref: transports
subdir: specification/grpc
managed:
enabled: true
Expand All @@ -21,11 +21,11 @@ plugins:
# Generate python protobuf related code
# Generates *_pb2.py files, one for each .proto
- remote: buf.build/protocolbuffers/python:v29.3
out: src/a2a/grpc
out: src/a2a/types
# Generate python service code.
# Generates *_pb2_grpc.py
- remote: buf.build/grpc/python
out: src/a2a/grpc
out: src/a2a/types
# Generates *_pb2.pyi files.
- remote: buf.build/protocolbuffers/pyi
out: src/a2a/grpc
out: src/a2a/types
17 changes: 12 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"pydantic>=2.11.3",
"protobuf>=5.29.5",
"google-api-core>=1.26.0",
"json-rpc>=1.15.0",
]

classifiers = [
Expand Down Expand Up @@ -114,7 +115,7 @@ explicit = true

[tool.mypy]
plugins = ["pydantic.mypy"]
exclude = ["src/a2a/grpc/"]
exclude = ["src/a2a/types/a2a_pb2\\.py", "src/a2a/types/a2a_pb2_grpc\\.py"]
disable_error_code = [
"import-not-found",
"annotation-unchecked",
Expand All @@ -134,7 +135,8 @@ exclude = [
"**/node_modules",
"**/venv",
"**/.venv",
"src/a2a/grpc/",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2_grpc.py",
]
reportMissingImports = "none"
reportMissingModuleSource = "none"
Expand All @@ -145,7 +147,8 @@ omit = [
"*/tests/*",
"*/site-packages/*",
"*/__init__.py",
"src/a2a/grpc/*",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2_grpc.py",
]

[tool.coverage.report]
Expand Down Expand Up @@ -257,7 +260,9 @@ exclude = [
"node_modules",
"venv",
"*/migrations/*",
"src/a2a/grpc/**",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2.pyi",
"src/a2a/types/a2a_pb2_grpc.py",
"tests/**",
]

Expand Down Expand Up @@ -311,7 +316,9 @@ inline-quotes = "single"

[tool.ruff.format]
exclude = [
"src/a2a/grpc/**",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2.pyi",
"src/a2a/types/a2a_pb2_grpc.py",
]
docstring-code-format = true
docstring-code-line-length = "dynamic"
Expand Down
21 changes: 0 additions & 21 deletions src/a2a/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,18 @@
A2AClientTimeoutError,
)
from a2a.client.helpers import create_text_message_object
from a2a.client.legacy import A2AClient
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor


logger = logging.getLogger(__name__)

try:
from a2a.client.legacy_grpc import A2AGrpcClient # type: ignore
except ImportError as e:
_original_error = e
logger.debug(
'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s',
_original_error,
)

class A2AGrpcClient: # type: ignore
"""Placeholder for A2AGrpcClient when dependencies are not installed."""

def __init__(self, *args, **kwargs):
raise ImportError(
'To use A2AGrpcClient, its dependencies must be installed. '
'You can install them with \'pip install "a2a-sdk[grpc]"\''
) from _original_error


__all__ = [
'A2ACardResolver',
'A2AClient',
'A2AClientError',
'A2AClientHTTPError',
'A2AClientJSONError',
'A2AClientTimeoutError',
'A2AGrpcClient',
'AuthInterceptor',
'BaseClient',
'Client',
Expand Down
38 changes: 27 additions & 11 deletions src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,34 @@

from a2a.client.auth.credentials import CredentialService
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.types import (
from a2a.types.a2a_pb2 import (
AgentCard,
APIKeySecurityScheme,
HTTPAuthSecurityScheme,
In,
OAuth2SecurityScheme,
OpenIdConnectSecurityScheme,
SecurityScheme,
)

logger = logging.getLogger(__name__)


def _get_security_scheme_value(scheme: SecurityScheme):
"""Extract the actual security scheme from the oneof union."""
which = scheme.WhichOneof('scheme')
if which == 'api_key_security_scheme':
return scheme.api_key_security_scheme
elif which == 'http_auth_security_scheme':
return scheme.http_auth_security_scheme
elif which == 'oauth2_security_scheme':
return scheme.oauth2_security_scheme
elif which == 'open_id_connect_security_scheme':
return scheme.open_id_connect_security_scheme
elif which == 'mtls_security_scheme':
return scheme.mtls_security_scheme
return None


class AuthInterceptor(ClientCallInterceptor):
"""An interceptor that automatically adds authentication details to requests.

Expand All @@ -35,13 +51,13 @@ async def intercept(
"""Applies authentication headers to the request if credentials are available."""
if (
agent_card is None
or agent_card.security is None
or agent_card.security_schemes is None
or not agent_card.security
or not agent_card.security_schemes
):
return request_payload, http_kwargs

for requirement in agent_card.security:
for scheme_name in requirement:
for scheme_name in requirement.schemes:
credential = await self._credential_service.get_credentials(
scheme_name, context
)
Expand All @@ -51,7 +67,9 @@ async def intercept(
)
if not scheme_def_union:
continue
scheme_def = scheme_def_union.root
scheme_def = _get_security_scheme_value(scheme_def_union)
if not scheme_def:
continue

headers = http_kwargs.get('headers', {})

Expand All @@ -62,9 +80,8 @@ async def intercept(
):
headers['Authorization'] = f'Bearer {credential}'
logger.debug(
"Added Bearer token for scheme '%s' (type: %s).",
"Added Bearer token for scheme '%s'.",
scheme_name,
scheme_def.type,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
Expand All @@ -76,15 +93,14 @@ async def intercept(
):
headers['Authorization'] = f'Bearer {credential}'
logger.debug(
"Added Bearer token for scheme '%s' (type: %s).",
"Added Bearer token for scheme '%s'.",
scheme_name,
scheme_def.type,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs

# Case 2: API Key in Header
case APIKeySecurityScheme(in_=In.header):
case APIKeySecurityScheme() if scheme_def.location.lower() == 'header':
headers[scheme_def.name] = credential
logger.debug(
"Added API Key Header for scheme '%s'.",
Expand Down
Loading