diff --git a/generators/python/sdk/changes/unreleased/fix-oauth-custom-params.yml b/generators/python/sdk/changes/unreleased/fix-oauth-custom-params.yml new file mode 100644 index 000000000000..e84ba363afdf --- /dev/null +++ b/generators/python/sdk/changes/unreleased/fix-oauth-custom-params.yml @@ -0,0 +1,5 @@ +- summary: | + Fix OAuth token provider to use correct parameter names from request-properties mapping + for custom OAuth endpoints, and pass additional required parameters (scopes, custom properties) + through the client constructor to the token provider. + type: fix diff --git a/generators/python/src/fern_python/generators/sdk/client_generator/oauth_token_provider_generator.py b/generators/python/src/fern_python/generators/sdk/client_generator/oauth_token_provider_generator.py index 66f079228dc5..36bce4b351ae 100644 --- a/generators/python/src/fern_python/generators/sdk/client_generator/oauth_token_provider_generator.py +++ b/generators/python/src/fern_python/generators/sdk/client_generator/oauth_token_provider_generator.py @@ -50,7 +50,7 @@ def _generate_client_credentials_classes( def _create_client_credentials_class_declaration( self, client_credentials: ir_types.OAuthClientCredentials, *, is_async: bool ) -> AST.ClassDeclaration: - constructor_parameters = self._get_constructor_parameters(is_async=is_async) + constructor_parameters = self._get_constructor_parameters(client_credentials, is_async=is_async) named_parameters = [ AST.NamedFunctionParameter( @@ -98,7 +98,9 @@ def _create_client_credentials_class_declaration( return class_declaration - def _get_constructor_parameters(self, *, is_async: bool) -> List[ConstructorParameter]: + def _get_constructor_parameters( + self, client_credentials: ir_types.OAuthClientCredentials, *, is_async: bool + ) -> List[ConstructorParameter]: parameters: List[ConstructorParameter] = [] parameters.append( @@ -117,6 +119,15 @@ def _get_constructor_parameters(self, *, is_async: bool) -> List[ConstructorPara ) ) + for param_name, member_name in self._get_additional_oauth_params(client_credentials): + parameters.append( + ConstructorParameter( + constructor_parameter_name=param_name, + private_member_name=member_name, + type_hint=AST.TypeHint.str_(), + ) + ) + parameters.append( ConstructorParameter( constructor_parameter_name=self._get_client_wrapper_constructor_parameter_name(), @@ -529,6 +540,49 @@ def _get_response_property_path(self, property_path: Optional[List[Union[str, ir return "" return ".".join([resolve_name(name).snake_case.safe_name for name in property_path]) + "." + def _is_literal_type(self, type_reference: ir_types.TypeReference) -> bool: + type_union = type_reference.get_as_union() + if type_union.type == "container": + container_union = type_union.container.get_as_union() + return container_union.type == "literal" + return False + + def _is_optional_type(self, type_reference: ir_types.TypeReference) -> bool: + type_union = type_reference.get_as_union() + if type_union.type == "container": + container_union = type_union.container.get_as_union() + return container_union.type == "optional" or container_union.type == "nullable" + return False + + def _get_additional_oauth_params( + self, client_credentials: ir_types.OAuthClientCredentials + ) -> List[Tuple[str, str]]: + """ + Returns (param_name, member_name) pairs for additional token endpoint parameters + beyond client_id/client_secret: the scopes property and any required non-literal + custom properties. + """ + result: List[Tuple[str, str]] = [] + token_request_properties = client_credentials.token_endpoint.request_properties + + if token_request_properties.scopes is not None: + param_name = self._get_request_property_parameter_name(token_request_properties.scopes) + result.append((param_name, f"_{param_name}")) + + if token_request_properties.custom_properties is not None: + for custom_prop in token_request_properties.custom_properties: + prop_value = custom_prop.property + prop_type = prop_value.visit( + query=lambda q: q.value_type, + body=lambda b: b.value_type, + ) + if self._is_literal_type(prop_type) or self._is_optional_type(prop_type): + continue + param_name = self._get_request_property_parameter_name(custom_prop) + result.append((param_name, f"_{param_name}")) + + return result + def _get_refresh_function_invocation( self, client_credentials: ir_types.OAuthClientCredentials ) -> AST.FunctionInvocation: @@ -547,6 +601,15 @@ def _get_refresh_function_invocation( AST.Expression(f"self.{self._get_client_secret_member_name()}"), ), ] + + for param_name, member_name in self._get_additional_oauth_params(client_credentials): + kwargs.append( + ( + param_name, + AST.Expression(f"self.{member_name}"), + ) + ) + if client_credentials.refresh_endpoint is None: token_endpoint: ir_types.HttpEndpoint = self._get_endpoint_for_id( client_credentials.token_endpoint.endpoint_reference.endpoint_id diff --git a/generators/python/src/fern_python/generators/sdk/client_generator/root_client_generator.py b/generators/python/src/fern_python/generators/sdk/client_generator/root_client_generator.py index ba7ccbc0777a..f6422f08d2c5 100644 --- a/generators/python/src/fern_python/generators/sdk/client_generator/root_client_generator.py +++ b/generators/python/src/fern_python/generators/sdk/client_generator/root_client_generator.py @@ -111,6 +111,53 @@ def _get_wrapper_bearer_token_kwarg_name(self, *, client_wrapper_generator: Clie return "token" return names.get_token_constructor_parameter_name(bearer_auth_scheme) + @staticmethod + def _is_literal_type_reference(type_reference: ir_types.TypeReference) -> bool: + type_union = type_reference.get_as_union() + if type_union.type == "container": + container_union = type_union.container.get_as_union() + return container_union.type == "literal" + return False + + @staticmethod + def _is_optional_type_reference(type_reference: ir_types.TypeReference) -> bool: + type_union = type_reference.get_as_union() + if type_union.type == "container": + container_union = type_union.container.get_as_union() + return container_union.type == "optional" or container_union.type == "nullable" + return False + + @staticmethod + def _get_request_property_param_name(request_property: ir_types.RequestProperty) -> str: + return request_property.property.visit( + query=lambda q: resolve_name(get_name_from_wire_value(q.name)).snake_case.safe_name, + body=lambda b: resolve_name(get_name_from_wire_value(b.name)).snake_case.safe_name, + ) + + def _get_additional_oauth_param_names(self, client_credentials: ir_types.OAuthClientCredentials) -> List[str]: + """ + Returns parameter names for additional token endpoint parameters beyond + client_id/client_secret: the scopes property and any required non-literal + custom properties. + """ + result: List[str] = [] + token_request_properties = client_credentials.token_endpoint.request_properties + + if token_request_properties.scopes is not None: + result.append(self._get_request_property_param_name(token_request_properties.scopes)) + + if token_request_properties.custom_properties is not None: + for custom_prop in token_request_properties.custom_properties: + prop_type = custom_prop.property.visit( + query=lambda q: q.value_type, + body=lambda b: b.value_type, + ) + if self._is_literal_type_reference(prop_type) or self._is_optional_type_reference(prop_type): + continue + result.append(self._get_request_property_param_name(custom_prop)) + + return result + def __init__( self, *, @@ -247,6 +294,14 @@ def _write_root_class_docstring(self, writer: AST.NodeWriter, *, is_async: bool) param.constructor_parameter_name: param for param in constructor_parameters } + # Additional oauth params (scopes, custom properties) + extra_oauth_param_names: list[str] = ( + self._get_additional_oauth_param_names(oauth_union) + if oauth_union is not None and oauth_union.type == "clientCredentials" + else [] + ) + str_type_params = {"client_id", "client_secret", *extra_oauth_param_names} + def write_param_block(param_names: list[str]) -> None: is_first = True for name in param_names: @@ -259,7 +314,7 @@ def write_param_block(param_names: list[str]) -> None: writer.write(f"{param.constructor_parameter_name} : ") if param.type_hint is not None: - if param.constructor_parameter_name in {"client_id", "client_secret"}: + if param.constructor_parameter_name in str_type_params: writer.write_node(AST.TypeHint.str_()) elif param.constructor_parameter_name == RootClientGenerator.TOKEN_PARAMETER_NAME: writer.write_node(AST.TypeHint.callable(parameters=[], return_type=AST.TypeHint.str_())) @@ -269,11 +324,11 @@ def write_param_block(param_names: list[str]) -> None: with writer.indent(): writer.write_line(param.docs) - # Overload 1: client_id + client_secret overload_1_param_names: list[str] = [ RootClientGenerator.BASE_URL_CONSTRUCTOR_PARAMETER_NAME, "client_id", "client_secret", + *extra_oauth_param_names, self._timeout_constructor_parameter_name, self._max_retries_constructor_parameter_name, RootClientGenerator.FOLLOW_REDIRECTS_CONSTRUCTOR_PARAMETER_NAME, @@ -687,6 +742,15 @@ def write_default_environment(writer: AST.NodeWriter) -> None: docs=RootClientGenerator.CLIENT_SECRET_CONSTRUCTOR_PARAMETER_DOCS, ), ) + # Add additional OAuth token endpoint params (scopes, custom properties) + for extra_param_name in self._get_additional_oauth_param_names(oauth): + parameters.append( + RootClientConstructorParameter( + constructor_parameter_name=extra_param_name, + type_hint=AST.TypeHint.optional(AST.TypeHint.str_()), + initializer=AST.Expression("None"), + ), + ) # Add the token parameter for direct token authentication parameters.append( RootClientConstructorParameter( @@ -928,7 +992,14 @@ def _get_constructor_overloads(self, *, is_async: bool) -> Optional[List[AST.Fun type_hint=AST.TypeHint.str_(), ) - oauth_params = base_params + [client_id_param, client_secret_param] + extra_oauth_params = [ + AST.NamedFunctionParameter( + name=extra_name, + type_hint=AST.TypeHint.str_(), + ) + for extra_name in self._get_additional_oauth_param_names(oauth) + ] + oauth_params = base_params + [client_id_param, client_secret_param] + extra_oauth_params oauth_signature = AST.FunctionSignature(named_parameters=oauth_params) # Overload 2: Direct token (token required) @@ -943,8 +1014,13 @@ def _get_constructor_overloads(self, *, is_async: bool) -> Optional[List[AST.Fun return [oauth_signature, token_signature] def _get_non_oauth_constructor_parameters(self, *, is_async: bool) -> List[AST.NamedFunctionParameter]: - """Get constructor parameters excluding OAuth-specific ones (client_id, client_secret, token).""" + """Get constructor parameters excluding OAuth-specific ones (client_id, client_secret, token, extra oauth params).""" oauth_param_names = {"client_id", "client_secret", self.TOKEN_PARAMETER_NAME, self.TOKEN_GETTER_PARAM_NAME} + if self._oauth_scheme is not None: + oauth_config = self._oauth_scheme.configuration.get_as_union() + if oauth_config.type == "clientCredentials": + for extra_name in self._get_additional_oauth_param_names(oauth_config): + oauth_param_names.add(extra_name) all_params = self._get_constructor_parameters(is_async=is_async) return [ AST.NamedFunctionParameter( @@ -1044,30 +1120,36 @@ def _write_constructor_body(writer: AST.NodeWriter) -> None: if is_async else self._context.core_utilities.get_oauth_token_provider() ) + oauth_tp_kwargs = [ + ( + "client_id", + AST.Expression("client_id"), + ), + ( + "client_secret", + AST.Expression("client_secret"), + ), + ] + if is_oauth_client_credentials and oauth_union is not None: + for extra_param_name in self._get_additional_oauth_param_names(oauth_union): + oauth_tp_kwargs.append( + (extra_param_name, AST.Expression(extra_param_name)), + ) + oauth_tp_kwargs.append( + ( + "client_wrapper", + AST.Expression( + AST.ClassInstantiation( + class_=self._context.core_utilities.get_reference_to_client_wrapper(is_async=is_async), + kwargs=client_wrapper_constructor_kwargs, + ), + ), + ), + ) writer.write_node( AST.ClassInstantiation( class_=oauth_token_provider_class, - kwargs=[ - ( - "client_id", - AST.Expression("client_id"), - ), - ( - "client_secret", - AST.Expression("client_secret"), - ), - ( - "client_wrapper", - AST.Expression( - AST.ClassInstantiation( - class_=self._context.core_utilities.get_reference_to_client_wrapper( - is_async=is_async - ), - kwargs=client_wrapper_constructor_kwargs, - ), - ), - ), - ], + kwargs=oauth_tp_kwargs, ) ) writer.write_newline_if_last_line_not() @@ -1285,6 +1367,12 @@ def _write_oauth_token_override_constructor_body( # elif client_id is not None and client_secret is not None: writer.write_line("elif client_id is not None and client_secret is not None:") with writer.indent(): + # Narrow additional OAuth params from Optional[str] to str + if self._oauth_scheme is not None: + override_oauth_config = self._oauth_scheme.configuration.get_as_union() + if override_oauth_config.type == "clientCredentials": + for extra_name in self._get_additional_oauth_param_names(override_oauth_config): + writer.write_line(f"assert {extra_name} is not None") # OAuth client credentials mode oauth_client_wrapper_kwargs = self._get_client_wrapper_kwargs( client_wrapper_generator=client_wrapper_generator, @@ -1302,30 +1390,38 @@ def _write_oauth_token_override_constructor_body( if is_async else self._context.core_utilities.get_oauth_token_provider() ) + override_oauth_tp_kwargs = [ + ( + "client_id", + AST.Expression("client_id"), + ), + ( + "client_secret", + AST.Expression("client_secret"), + ), + ] + if self._oauth_scheme is not None: + override_oauth = self._oauth_scheme.configuration.get_as_union() + if override_oauth.type == "clientCredentials": + for extra_param_name in self._get_additional_oauth_param_names(override_oauth): + override_oauth_tp_kwargs.append( + (extra_param_name, AST.Expression(extra_param_name)), + ) + override_oauth_tp_kwargs.append( + ( + "client_wrapper", + AST.Expression( + AST.ClassInstantiation( + class_=self._context.core_utilities.get_reference_to_client_wrapper(is_async=is_async), + kwargs=oauth_client_wrapper_kwargs, + ), + ), + ), + ) writer.write_node( AST.ClassInstantiation( class_=oauth_token_provider_class, - kwargs=[ - ( - "client_id", - AST.Expression("client_id"), - ), - ( - "client_secret", - AST.Expression("client_secret"), - ), - ( - "client_wrapper", - AST.Expression( - AST.ClassInstantiation( - class_=self._context.core_utilities.get_reference_to_client_wrapper( - is_async=is_async - ), - kwargs=oauth_client_wrapper_kwargs, - ), - ), - ), - ], + kwargs=override_oauth_tp_kwargs, ) ) writer.write_newline_if_last_line_not() diff --git a/seed/python-sdk/oauth-client-credentials-custom/.fern/metadata.json b/seed/python-sdk/oauth-client-credentials-custom/.fern/metadata.json index 0ee76b66eb30..d018d66316a6 100644 --- a/seed/python-sdk/oauth-client-credentials-custom/.fern/metadata.json +++ b/seed/python-sdk/oauth-client-credentials-custom/.fern/metadata.json @@ -3,8 +3,7 @@ "generatorName": "fernapi/fern-python-sdk", "generatorVersion": "local", "originGitCommit": "DUMMY", - "invokedBy": "ci", + "invokedBy": "manual", "requestedVersion": "0.0.1", - "ciProvider": "github", "sdkVersion": "0.0.1" } \ No newline at end of file diff --git a/seed/python-sdk/oauth-client-credentials-custom/src/seed/client.py b/seed/python-sdk/oauth-client-credentials-custom/src/seed/client.py index 14165e93bb5b..3501cc3d57ed 100644 --- a/seed/python-sdk/oauth-client-credentials-custom/src/seed/client.py +++ b/seed/python-sdk/oauth-client-credentials-custom/src/seed/client.py @@ -33,6 +33,8 @@ class SeedOauthClientCredentials: client_secret : str The client secret used for authentication. + scp : str + entity_id : str timeout : typing.Optional[float] The timeout to be used, in seconds, for requests. By default the timeout is 60 seconds, unless a custom httpx client is used, in which case this default is not enforced. @@ -98,6 +100,8 @@ def __init__( logging: typing.Optional[typing.Union[LogConfig, Logger]] = None, client_id: str, client_secret: str, + scp: str, + entity_id: str, ): ... @typing.overload def __init__( @@ -119,6 +123,8 @@ def __init__( headers: typing.Optional[typing.Dict[str, str]] = None, client_id: typing.Optional[str] = None, client_secret: typing.Optional[str] = None, + scp: typing.Optional[str] = None, + entity_id: typing.Optional[str] = None, token: typing.Optional[typing.Callable[[], str]] = None, _token_getter_override: typing.Optional[typing.Callable[[], str]] = None, timeout: typing.Optional[float] = None, @@ -146,9 +152,13 @@ def __init__( token=_token_getter_override if _token_getter_override is not None else token, ) elif client_id is not None and client_secret is not None: + assert scp is not None + assert entity_id is not None oauth_token_provider = OAuthTokenProvider( client_id=client_id, client_secret=client_secret, + scp=scp, + entity_id=entity_id, client_wrapper=SyncClientWrapper( base_url=base_url, headers=headers, @@ -251,6 +261,8 @@ class AsyncSeedOauthClientCredentials: client_secret : str The client secret used for authentication. + scp : str + entity_id : str timeout : typing.Optional[float] The timeout to be used, in seconds, for requests. By default the timeout is 60 seconds, unless a custom httpx client is used, in which case this default is not enforced. @@ -316,6 +328,8 @@ def __init__( logging: typing.Optional[typing.Union[LogConfig, Logger]] = None, client_id: str, client_secret: str, + scp: str, + entity_id: str, ): ... @typing.overload def __init__( @@ -337,6 +351,8 @@ def __init__( headers: typing.Optional[typing.Dict[str, str]] = None, client_id: typing.Optional[str] = None, client_secret: typing.Optional[str] = None, + scp: typing.Optional[str] = None, + entity_id: typing.Optional[str] = None, token: typing.Optional[typing.Callable[[], str]] = None, _token_getter_override: typing.Optional[typing.Callable[[], str]] = None, timeout: typing.Optional[float] = None, @@ -362,9 +378,13 @@ def __init__( token=_token_getter_override if _token_getter_override is not None else token, ) elif client_id is not None and client_secret is not None: + assert scp is not None + assert entity_id is not None oauth_token_provider = AsyncOAuthTokenProvider( client_id=client_id, client_secret=client_secret, + scp=scp, + entity_id=entity_id, client_wrapper=AsyncClientWrapper( base_url=base_url, headers=headers, diff --git a/seed/python-sdk/oauth-client-credentials-custom/src/seed/core/oauth_token_provider.py b/seed/python-sdk/oauth-client-credentials-custom/src/seed/core/oauth_token_provider.py index cbdee3fa15c1..11ee82973fd4 100644 --- a/seed/python-sdk/oauth-client-credentials-custom/src/seed/core/oauth_token_provider.py +++ b/seed/python-sdk/oauth-client-credentials-custom/src/seed/core/oauth_token_provider.py @@ -14,9 +14,13 @@ class OAuthTokenProvider: BUFFER_IN_MINUTES = 2 - def __init__(self, *, client_id: str, client_secret: str, client_wrapper: SyncClientWrapper): + def __init__( + self, *, client_id: str, client_secret: str, scp: str, entity_id: str, client_wrapper: SyncClientWrapper + ): self._client_id = client_id self._client_secret = client_secret + self._scp = scp + self._entity_id = entity_id self._access_token: typing.Optional[str] = None self._expires_at: dt.datetime = dt.datetime.now() self._auth_client = AuthClient(client_wrapper=client_wrapper) @@ -32,7 +36,7 @@ def get_token(self) -> str: def _refresh(self) -> str: token_response = self._auth_client.get_token_with_client_credentials( - cid=self._client_id, csr=self._client_secret + cid=self._client_id, csr=self._client_secret, scp=self._scp, entity_id=self._entity_id ) self._access_token = token_response.access_token self._expires_at = self._get_expires_at( @@ -47,9 +51,13 @@ def _get_expires_at(self, *, expires_in_seconds: int, buffer_in_minutes: int): class AsyncOAuthTokenProvider: BUFFER_IN_MINUTES = 2 - def __init__(self, *, client_id: str, client_secret: str, client_wrapper: AsyncClientWrapper): + def __init__( + self, *, client_id: str, client_secret: str, scp: str, entity_id: str, client_wrapper: AsyncClientWrapper + ): self._client_id = client_id self._client_secret = client_secret + self._scp = scp + self._entity_id = entity_id self._access_token: typing.Optional[str] = None self._expires_at: dt.datetime = dt.datetime.now() self._auth_client = AsyncAuthClient(client_wrapper=client_wrapper) @@ -65,7 +73,7 @@ async def get_token(self) -> str: async def _refresh(self) -> str: token_response = await self._auth_client.get_token_with_client_credentials( - cid=self._client_id, csr=self._client_secret + cid=self._client_id, csr=self._client_secret, scp=self._scp, entity_id=self._entity_id ) self._access_token = token_response.access_token self._expires_at = self._get_expires_at( diff --git a/seed/python-sdk/seed.yml b/seed/python-sdk/seed.yml index 0f2beb77b38c..729ab1947cdd 100644 --- a/seed/python-sdk/seed.yml +++ b/seed/python-sdk/seed.yml @@ -500,7 +500,6 @@ scripts: allowedFailures: - examples:legacy-wire-tests - exhaustive:deps_with_min_python_version - - oauth-client-credentials-custom - streaming-parameter - trace - unions:union-naming-v1-wire-tests