Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ plugins:
# Generate python protobuf related code
# Generates *_pb2.py files, one for each .proto
- remote: buf.build/protocolbuffers/python:v29.3
out: src/a2a/grpc
out: src/a2a/types
# Generate python service code.
# Generates *_pb2_grpc.py
- remote: buf.build/grpc/python
out: src/a2a/grpc
out: src/a2a/types
# Generates *_pb2.pyi files.
- remote: buf.build/protocolbuffers/pyi
out: src/a2a/grpc
out: src/a2a/types
28 changes: 23 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ dependencies = [
"pydantic>=2.11.3",
"protobuf>=5.29.5",
"google-api-core>=1.26.0",
"json-rpc>=1.15.0",
"googleapis-common-protos>=1.70.0",
]

classifiers = [
Expand Down Expand Up @@ -74,6 +76,16 @@ addopts = "-ra --strict-markers"
markers = [
"asyncio: mark a test as a coroutine that should be run by pytest-asyncio",
]
filterwarnings = [
# SQLAlchemy warning about duplicate class registration - this is a known limitation
# of the dynamic model creation pattern used in models.py for custom table names
"ignore:This declarative base already contains a class with the same class name:sqlalchemy.exc.SAWarning",
# ResourceWarnings from asyncio event loop/socket cleanup during garbage collection
# These appear intermittently between tests due to pytest-asyncio and sse-starlette timing
"ignore:unclosed event loop:ResourceWarning",
"ignore:unclosed transport:ResourceWarning",
"ignore:unclosed <socket.socket:ResourceWarning",
]

[tool.pytest-asyncio]
mode = "strict"
Expand Down Expand Up @@ -114,7 +126,7 @@ explicit = true

[tool.mypy]
plugins = ["pydantic.mypy"]
exclude = ["src/a2a/grpc/"]
exclude = ["src/a2a/types/a2a_pb2\\.py", "src/a2a/types/a2a_pb2_grpc\\.py"]
disable_error_code = [
"import-not-found",
"annotation-unchecked",
Expand All @@ -134,7 +146,8 @@ exclude = [
"**/node_modules",
"**/venv",
"**/.venv",
"src/a2a/grpc/",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2_grpc.py",
]
reportMissingImports = "none"
reportMissingModuleSource = "none"
Expand All @@ -145,7 +158,8 @@ omit = [
"*/tests/*",
"*/site-packages/*",
"*/__init__.py",
"src/a2a/grpc/*",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2_grpc.py",
]

[tool.coverage.report]
Expand Down Expand Up @@ -257,7 +271,9 @@ exclude = [
"node_modules",
"venv",
"*/migrations/*",
"src/a2a/grpc/**",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2.pyi",
"src/a2a/types/a2a_pb2_grpc.py",
"tests/**",
]

Expand Down Expand Up @@ -311,7 +327,9 @@ inline-quotes = "single"

[tool.ruff.format]
exclude = [
"src/a2a/grpc/**",
"src/a2a/types/a2a_pb2.py",
"src/a2a/types/a2a_pb2.pyi",
"src/a2a/types/a2a_pb2_grpc.py",
]
docstring-code-format = true
docstring-code-line-length = "dynamic"
Expand Down
21 changes: 0 additions & 21 deletions src/a2a/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,18 @@
A2AClientTimeoutError,
)
from a2a.client.helpers import create_text_message_object
from a2a.client.legacy import A2AClient
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor


logger = logging.getLogger(__name__)

try:
from a2a.client.legacy_grpc import A2AGrpcClient # type: ignore
except ImportError as e:
_original_error = e
logger.debug(
'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s',
_original_error,
)

class A2AGrpcClient: # type: ignore
"""Placeholder for A2AGrpcClient when dependencies are not installed."""

def __init__(self, *args, **kwargs):
raise ImportError(
'To use A2AGrpcClient, its dependencies must be installed. '
'You can install them with \'pip install "a2a-sdk[grpc]"\''
) from _original_error


__all__ = [
'A2ACardResolver',
'A2AClient',
'A2AClientError',
'A2AClientHTTPError',
'A2AClientJSONError',
'A2AClientTimeoutError',
'A2AGrpcClient',
'AuthInterceptor',
'BaseClient',
'Client',
Expand Down
96 changes: 45 additions & 51 deletions src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@

from a2a.client.auth.credentials import CredentialService
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.types import (
AgentCard,
APIKeySecurityScheme,
HTTPAuthSecurityScheme,
In,
OAuth2SecurityScheme,
OpenIdConnectSecurityScheme,
)
from a2a.types.a2a_pb2 import AgentCard

logger = logging.getLogger(__name__)

Expand All @@ -35,63 +28,64 @@ async def intercept(
"""Applies authentication headers to the request if credentials are available."""
if (
agent_card is None
or agent_card.security is None
or agent_card.security_schemes is None
or not agent_card.security
or not agent_card.security_schemes
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double checking the protobuf docs https://protobuf.dev/reference/python/python-generated/#embedded_message it sounds like we may need to use agent_card.HasField("...") this separates the difference between agent_card.security set, not set and set but empty from what I can tell.

):
return request_payload, http_kwargs

for requirement in agent_card.security:
for scheme_name in requirement:
for scheme_name in requirement.schemes:
credential = await self._credential_service.get_credentials(
scheme_name, context
)
if credential and scheme_name in agent_card.security_schemes:
scheme_def_union = agent_card.security_schemes.get(
scheme_name
)
if not scheme_def_union:
scheme = agent_card.security_schemes.get(scheme_name)
if not scheme:
continue
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personal preference would be to use the functions available from proto to do the matching below instead of ending up with a non-statically typed function here. E.g.

match scheme.WhichOneof('scheme'):
    case 'http_auth_security_scheme' if scheme.http_auth_security_scheme.lower() == 'bearer':
    	...

tbh I think it would read better as a set of if blocks:

scheme = agent_card.security_schemes.get(scheme_name)

if scheme.HasField('http_auth_security_scheme') and scheme.http_auth_security_scheme.lower() == 'bearer':
   ... 

if scheme.HasField('oauth2_security_scheme') or scheme.HasField('open_id_connect_security_scheme'):
   ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how's HasField improves things here.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As explained our call, its mostly about readability, but also prevent us using reflection to check the python types, and instead we can just rely on the boolean check of "does the request have this field set" then it can use it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check now @Tehsmash

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't seem to be able to resolve my own threads on the a2a repos, but I'll leave a LGTM if I think it should be resolved.

scheme_def = scheme_def_union.root

headers = http_kwargs.get('headers', {})

match scheme_def:
# Case 1a: HTTP Bearer scheme with an if guard
case HTTPAuthSecurityScheme() if (
scheme_def.scheme.lower() == 'bearer'
):
headers['Authorization'] = f'Bearer {credential}'
logger.debug(
"Added Bearer token for scheme '%s' (type: %s).",
scheme_name,
scheme_def.type,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
# HTTP Bearer authentication
if (
scheme.HasField('http_auth_security_scheme')
and scheme.http_auth_security_scheme.scheme.lower()
== 'bearer'
):
headers['Authorization'] = f'Bearer {credential}'
logger.debug(
"Added Bearer token for scheme '%s'.",
scheme_name,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs

# Case 1b: OAuth2 and OIDC schemes, which are implicitly Bearer
case (
OAuth2SecurityScheme()
| OpenIdConnectSecurityScheme()
):
headers['Authorization'] = f'Bearer {credential}'
logger.debug(
"Added Bearer token for scheme '%s' (type: %s).",
scheme_name,
scheme_def.type,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
# OAuth2 and OIDC schemes are implicitly Bearer
if scheme.HasField(
'oauth2_security_scheme'
) or scheme.HasField('open_id_connect_security_scheme'):
headers['Authorization'] = f'Bearer {credential}'
logger.debug(
"Added Bearer token for scheme '%s'.",
scheme_name,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs

# Case 2: API Key in Header
case APIKeySecurityScheme(in_=In.header):
headers[scheme_def.name] = credential
logger.debug(
"Added API Key Header for scheme '%s'.",
scheme_name,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs
# API Key in Header
if (
scheme.HasField('api_key_security_scheme')
and scheme.api_key_security_scheme.location.lower()
== 'header'
):
headers[scheme.api_key_security_scheme.name] = (
credential
)
logger.debug(
"Added API Key Header for scheme '%s'.",
scheme_name,
)
http_kwargs['headers'] = headers
return request_payload, http_kwargs

# Note: Other cases like API keys in query/cookie are not handled and will be skipped.

Expand Down
Loading