-
Notifications
You must be signed in to change notification settings - Fork 303
refactor!: upgrade SDK to A2A 1.0 specs #572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: 1.0-a2a_proto_refactor
Are you sure you want to change the base?
Changes from 8 commits
b348735
74c5a19
2d698df
424dd7e
7405dc7
6462801
42c72f2
ac1050d
601ef0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,14 +3,7 @@ | |
|
|
||
| from a2a.client.auth.credentials import CredentialService | ||
| from a2a.client.middleware import ClientCallContext, ClientCallInterceptor | ||
| from a2a.types import ( | ||
| AgentCard, | ||
| APIKeySecurityScheme, | ||
| HTTPAuthSecurityScheme, | ||
| In, | ||
| OAuth2SecurityScheme, | ||
| OpenIdConnectSecurityScheme, | ||
| ) | ||
| from a2a.types.a2a_pb2 import AgentCard | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -35,63 +28,64 @@ async def intercept( | |
| """Applies authentication headers to the request if credentials are available.""" | ||
| if ( | ||
| agent_card is None | ||
| or agent_card.security is None | ||
| or agent_card.security_schemes is None | ||
| or not agent_card.security | ||
| or not agent_card.security_schemes | ||
| ): | ||
| return request_payload, http_kwargs | ||
|
|
||
| for requirement in agent_card.security: | ||
| for scheme_name in requirement: | ||
| for scheme_name in requirement.schemes: | ||
| credential = await self._credential_service.get_credentials( | ||
| scheme_name, context | ||
| ) | ||
| if credential and scheme_name in agent_card.security_schemes: | ||
| scheme_def_union = agent_card.security_schemes.get( | ||
| scheme_name | ||
| ) | ||
| if not scheme_def_union: | ||
| scheme = agent_card.security_schemes.get(scheme_name) | ||
| if not scheme: | ||
| continue | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Personal preference would be to use the functions available from proto to do the matching below instead of ending up with a non-statically typed function here. E.g. match scheme.WhichOneof('scheme'):
case 'http_auth_security_scheme' if scheme.http_auth_security_scheme.lower() == 'bearer':
...tbh I think it would read better as a set of scheme = agent_card.security_schemes.get(scheme_name)
if scheme.HasField('http_auth_security_scheme') and scheme.http_auth_security_scheme.lower() == 'bearer':
...
if scheme.HasField('oauth2_security_scheme') or scheme.HasField('open_id_connect_security_scheme'):
...
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure how's HasField improves things here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As explained our call, its mostly about readability, but also prevent us using reflection to check the python types, and instead we can just rely on the boolean check of "does the request have this field set" then it can use it.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check now @Tehsmash There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't seem to be able to resolve my own threads on the a2a repos, but I'll leave a LGTM if I think it should be resolved. |
||
| scheme_def = scheme_def_union.root | ||
|
|
||
| headers = http_kwargs.get('headers', {}) | ||
|
|
||
| match scheme_def: | ||
| # Case 1a: HTTP Bearer scheme with an if guard | ||
| case HTTPAuthSecurityScheme() if ( | ||
| scheme_def.scheme.lower() == 'bearer' | ||
| ): | ||
| headers['Authorization'] = f'Bearer {credential}' | ||
| logger.debug( | ||
| "Added Bearer token for scheme '%s' (type: %s).", | ||
| scheme_name, | ||
| scheme_def.type, | ||
| ) | ||
| http_kwargs['headers'] = headers | ||
| return request_payload, http_kwargs | ||
| # HTTP Bearer authentication | ||
| if ( | ||
| scheme.HasField('http_auth_security_scheme') | ||
| and scheme.http_auth_security_scheme.scheme.lower() | ||
| == 'bearer' | ||
| ): | ||
| headers['Authorization'] = f'Bearer {credential}' | ||
| logger.debug( | ||
| "Added Bearer token for scheme '%s'.", | ||
| scheme_name, | ||
| ) | ||
| http_kwargs['headers'] = headers | ||
| return request_payload, http_kwargs | ||
|
|
||
| # Case 1b: OAuth2 and OIDC schemes, which are implicitly Bearer | ||
| case ( | ||
| OAuth2SecurityScheme() | ||
| | OpenIdConnectSecurityScheme() | ||
| ): | ||
| headers['Authorization'] = f'Bearer {credential}' | ||
| logger.debug( | ||
| "Added Bearer token for scheme '%s' (type: %s).", | ||
| scheme_name, | ||
| scheme_def.type, | ||
| ) | ||
| http_kwargs['headers'] = headers | ||
| return request_payload, http_kwargs | ||
| # OAuth2 and OIDC schemes are implicitly Bearer | ||
| if scheme.HasField( | ||
| 'oauth2_security_scheme' | ||
| ) or scheme.HasField('open_id_connect_security_scheme'): | ||
| headers['Authorization'] = f'Bearer {credential}' | ||
| logger.debug( | ||
| "Added Bearer token for scheme '%s'.", | ||
| scheme_name, | ||
| ) | ||
| http_kwargs['headers'] = headers | ||
| return request_payload, http_kwargs | ||
|
|
||
| # Case 2: API Key in Header | ||
| case APIKeySecurityScheme(in_=In.header): | ||
| headers[scheme_def.name] = credential | ||
| logger.debug( | ||
| "Added API Key Header for scheme '%s'.", | ||
| scheme_name, | ||
| ) | ||
| http_kwargs['headers'] = headers | ||
| return request_payload, http_kwargs | ||
| # API Key in Header | ||
| if ( | ||
| scheme.HasField('api_key_security_scheme') | ||
| and scheme.api_key_security_scheme.location.lower() | ||
| == 'header' | ||
| ): | ||
| headers[scheme.api_key_security_scheme.name] = ( | ||
| credential | ||
| ) | ||
| logger.debug( | ||
| "Added API Key Header for scheme '%s'.", | ||
| scheme_name, | ||
| ) | ||
| http_kwargs['headers'] = headers | ||
| return request_payload, http_kwargs | ||
|
|
||
| # Note: Other cases like API keys in query/cookie are not handled and will be skipped. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Double checking the protobuf docs https://protobuf.dev/reference/python/python-generated/#embedded_message it sounds like we may need to use
agent_card.HasField("...")this separates the difference between agent_card.security set, not set and set but empty from what I can tell.