-
Notifications
You must be signed in to change notification settings - Fork 303
refactor!: upgrade SDK to A2A 1.0 specs #572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: 1.0-a2a_proto_refactor
Are you sure you want to change the base?
Changes from 2 commits
b348735
74c5a19
2d698df
424dd7e
7405dc7
6462801
42c72f2
ac1050d
601ef0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,18 +3,34 @@ | |
|
|
||
| from a2a.client.auth.credentials import CredentialService | ||
| from a2a.client.middleware import ClientCallContext, ClientCallInterceptor | ||
| from a2a.types import ( | ||
| from a2a.types.a2a_pb2 import ( | ||
| AgentCard, | ||
| APIKeySecurityScheme, | ||
| HTTPAuthSecurityScheme, | ||
| In, | ||
| OAuth2SecurityScheme, | ||
| OpenIdConnectSecurityScheme, | ||
| SecurityScheme, | ||
| ) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def _get_security_scheme_value(scheme: SecurityScheme): | ||
| """Extract the actual security scheme from the oneof union.""" | ||
| which = scheme.WhichOneof('scheme') | ||
| if which == 'api_key_security_scheme': | ||
| return scheme.api_key_security_scheme | ||
| elif which == 'http_auth_security_scheme': | ||
| return scheme.http_auth_security_scheme | ||
| elif which == 'oauth2_security_scheme': | ||
| return scheme.oauth2_security_scheme | ||
| elif which == 'open_id_connect_security_scheme': | ||
| return scheme.open_id_connect_security_scheme | ||
| elif which == 'mtls_security_scheme': | ||
| return scheme.mtls_security_scheme | ||
| return None | ||
|
||
|
|
||
|
|
||
| class AuthInterceptor(ClientCallInterceptor): | ||
| """An interceptor that automatically adds authentication details to requests. | ||
|
|
||
|
|
@@ -35,13 +51,13 @@ async def intercept( | |
| """Applies authentication headers to the request if credentials are available.""" | ||
| if ( | ||
| agent_card is None | ||
| or agent_card.security is None | ||
| or agent_card.security_schemes is None | ||
| or not agent_card.security | ||
| or not agent_card.security_schemes | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Double checking the protobuf docs https://protobuf.dev/reference/python/python-generated/#embedded_message it sounds like we may need to use |
||
| ): | ||
| return request_payload, http_kwargs | ||
|
|
||
| for requirement in agent_card.security: | ||
| for scheme_name in requirement: | ||
| for scheme_name in requirement.schemes: | ||
| credential = await self._credential_service.get_credentials( | ||
| scheme_name, context | ||
| ) | ||
|
|
@@ -51,7 +67,9 @@ async def intercept( | |
| ) | ||
| if not scheme_def_union: | ||
| continue | ||
| scheme_def = scheme_def_union.root | ||
| scheme_def = _get_security_scheme_value(scheme_def_union) | ||
| if not scheme_def: | ||
| continue | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Personal preference would be to use the functions available from proto to do the matching below instead of ending up with a non-statically typed function here. E.g. match scheme.WhichOneof('scheme'):
case 'http_auth_security_scheme' if scheme.http_auth_security_scheme.lower() == 'bearer':
...tbh I think it would read better as a set of scheme = agent_card.security_schemes.get(scheme_name)
if scheme.HasField('http_auth_security_scheme') and scheme.http_auth_security_scheme.lower() == 'bearer':
...
if scheme.HasField('oauth2_security_scheme') or scheme.HasField('open_id_connect_security_scheme'):
...
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure how's HasField improves things here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As explained our call, its mostly about readability, but also prevent us using reflection to check the python types, and instead we can just rely on the boolean check of "does the request have this field set" then it can use it.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check now @Tehsmash There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't seem to be able to resolve my own threads on the a2a repos, but I'll leave a LGTM if I think it should be resolved. |
||
|
|
||
| headers = http_kwargs.get('headers', {}) | ||
|
|
||
|
|
@@ -62,9 +80,8 @@ async def intercept( | |
| ): | ||
| headers['Authorization'] = f'Bearer {credential}' | ||
| logger.debug( | ||
| "Added Bearer token for scheme '%s' (type: %s).", | ||
| "Added Bearer token for scheme '%s'.", | ||
| scheme_name, | ||
| scheme_def.type, | ||
| ) | ||
| http_kwargs['headers'] = headers | ||
| return request_payload, http_kwargs | ||
|
|
@@ -76,15 +93,14 @@ async def intercept( | |
| ): | ||
| headers['Authorization'] = f'Bearer {credential}' | ||
| logger.debug( | ||
| "Added Bearer token for scheme '%s' (type: %s).", | ||
| "Added Bearer token for scheme '%s'.", | ||
| scheme_name, | ||
| scheme_def.type, | ||
| ) | ||
| http_kwargs['headers'] = headers | ||
| return request_payload, http_kwargs | ||
|
|
||
| # Case 2: API Key in Header | ||
| case APIKeySecurityScheme(in_=In.header): | ||
| case APIKeySecurityScheme() if scheme_def.location.lower() == 'header': | ||
| headers[scheme_def.name] = credential | ||
| logger.debug( | ||
| "Added API Key Header for scheme '%s'.", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to change this back to "main" to pick up the latest fixes from the a2a.proto and before we merge we need to switch to the tag for v1.0 (when its cut)