Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- summary: |
Fix OAuth token provider to use correct parameter names from request-properties mapping
for custom OAuth endpoints, and pass additional required parameters (scopes, custom properties)
through the client constructor to the token provider.
type: fix
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _generate_client_credentials_classes(
def _create_client_credentials_class_declaration(
self, client_credentials: ir_types.OAuthClientCredentials, *, is_async: bool
) -> AST.ClassDeclaration:
constructor_parameters = self._get_constructor_parameters(is_async=is_async)
constructor_parameters = self._get_constructor_parameters(client_credentials, is_async=is_async)

named_parameters = [
AST.NamedFunctionParameter(
Expand Down Expand Up @@ -98,7 +98,9 @@ def _create_client_credentials_class_declaration(

return class_declaration

def _get_constructor_parameters(self, *, is_async: bool) -> List[ConstructorParameter]:
def _get_constructor_parameters(
self, client_credentials: ir_types.OAuthClientCredentials, *, is_async: bool
) -> List[ConstructorParameter]:
parameters: List[ConstructorParameter] = []

parameters.append(
Expand All @@ -117,6 +119,15 @@ def _get_constructor_parameters(self, *, is_async: bool) -> List[ConstructorPara
)
)

for param_name, member_name in self._get_additional_oauth_params(client_credentials):
parameters.append(
ConstructorParameter(
constructor_parameter_name=param_name,
private_member_name=member_name,
type_hint=AST.TypeHint.str_(),
)
)

parameters.append(
ConstructorParameter(
constructor_parameter_name=self._get_client_wrapper_constructor_parameter_name(),
Expand Down Expand Up @@ -529,6 +540,49 @@ def _get_response_property_path(self, property_path: Optional[List[Union[str, ir
return ""
return ".".join([resolve_name(name).snake_case.safe_name for name in property_path]) + "."

def _is_literal_type(self, type_reference: ir_types.TypeReference) -> bool:
type_union = type_reference.get_as_union()
if type_union.type == "container":
container_union = type_union.container.get_as_union()
return container_union.type == "literal"
return False

def _is_optional_type(self, type_reference: ir_types.TypeReference) -> bool:
type_union = type_reference.get_as_union()
if type_union.type == "container":
container_union = type_union.container.get_as_union()
return container_union.type == "optional" or container_union.type == "nullable"
return False

def _get_additional_oauth_params(
self, client_credentials: ir_types.OAuthClientCredentials
) -> List[Tuple[str, str]]:
"""
Returns (param_name, member_name) pairs for additional token endpoint parameters
beyond client_id/client_secret: the scopes property and any required non-literal
custom properties.
"""
result: List[Tuple[str, str]] = []
token_request_properties = client_credentials.token_endpoint.request_properties

if token_request_properties.scopes is not None:
param_name = self._get_request_property_parameter_name(token_request_properties.scopes)
result.append((param_name, f"_{param_name}"))

if token_request_properties.custom_properties is not None:
for custom_prop in token_request_properties.custom_properties:
prop_value = custom_prop.property
prop_type = prop_value.visit(
query=lambda q: q.value_type,
body=lambda b: b.value_type,
)
if self._is_literal_type(prop_type) or self._is_optional_type(prop_type):
continue
param_name = self._get_request_property_parameter_name(custom_prop)
result.append((param_name, f"_{param_name}"))

return result

def _get_refresh_function_invocation(
self, client_credentials: ir_types.OAuthClientCredentials
) -> AST.FunctionInvocation:
Expand All @@ -547,6 +601,15 @@ def _get_refresh_function_invocation(
AST.Expression(f"self.{self._get_client_secret_member_name()}"),
),
]

for param_name, member_name in self._get_additional_oauth_params(client_credentials):
kwargs.append(
(
param_name,
AST.Expression(f"self.{member_name}"),
)
)

if client_credentials.refresh_endpoint is None:
token_endpoint: ir_types.HttpEndpoint = self._get_endpoint_for_id(
client_credentials.token_endpoint.endpoint_reference.endpoint_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,53 @@ def _get_wrapper_bearer_token_kwarg_name(self, *, client_wrapper_generator: Clie
return "token"
return names.get_token_constructor_parameter_name(bearer_auth_scheme)

@staticmethod
def _is_literal_type_reference(type_reference: ir_types.TypeReference) -> bool:
type_union = type_reference.get_as_union()
if type_union.type == "container":
container_union = type_union.container.get_as_union()
return container_union.type == "literal"
return False

@staticmethod
def _is_optional_type_reference(type_reference: ir_types.TypeReference) -> bool:
type_union = type_reference.get_as_union()
if type_union.type == "container":
container_union = type_union.container.get_as_union()
return container_union.type == "optional" or container_union.type == "nullable"
return False

@staticmethod
def _get_request_property_param_name(request_property: ir_types.RequestProperty) -> str:
return request_property.property.visit(
query=lambda q: resolve_name(get_name_from_wire_value(q.name)).snake_case.safe_name,
body=lambda b: resolve_name(get_name_from_wire_value(b.name)).snake_case.safe_name,
)

def _get_additional_oauth_param_names(self, client_credentials: ir_types.OAuthClientCredentials) -> List[str]:
"""
Returns parameter names for additional token endpoint parameters beyond
client_id/client_secret: the scopes property and any required non-literal
custom properties.
"""
result: List[str] = []
token_request_properties = client_credentials.token_endpoint.request_properties

if token_request_properties.scopes is not None:
result.append(self._get_request_property_param_name(token_request_properties.scopes))

if token_request_properties.custom_properties is not None:
for custom_prop in token_request_properties.custom_properties:
prop_type = custom_prop.property.visit(
query=lambda q: q.value_type,
body=lambda b: b.value_type,
)
if self._is_literal_type_reference(prop_type) or self._is_optional_type_reference(prop_type):
continue
result.append(self._get_request_property_param_name(custom_prop))

return result

def __init__(
self,
*,
Expand Down Expand Up @@ -247,6 +294,14 @@ def _write_root_class_docstring(self, writer: AST.NodeWriter, *, is_async: bool)
param.constructor_parameter_name: param for param in constructor_parameters
}

# Additional oauth params (scopes, custom properties)
extra_oauth_param_names: list[str] = (
self._get_additional_oauth_param_names(oauth_union)
if oauth_union is not None and oauth_union.type == "clientCredentials"
else []
)
str_type_params = {"client_id", "client_secret", *extra_oauth_param_names}

def write_param_block(param_names: list[str]) -> None:
is_first = True
for name in param_names:
Expand All @@ -259,7 +314,7 @@ def write_param_block(param_names: list[str]) -> None:

writer.write(f"{param.constructor_parameter_name} : ")
if param.type_hint is not None:
if param.constructor_parameter_name in {"client_id", "client_secret"}:
if param.constructor_parameter_name in str_type_params:
writer.write_node(AST.TypeHint.str_())
elif param.constructor_parameter_name == RootClientGenerator.TOKEN_PARAMETER_NAME:
writer.write_node(AST.TypeHint.callable(parameters=[], return_type=AST.TypeHint.str_()))
Expand All @@ -269,11 +324,11 @@ def write_param_block(param_names: list[str]) -> None:
with writer.indent():
writer.write_line(param.docs)

# Overload 1: client_id + client_secret
overload_1_param_names: list[str] = [
RootClientGenerator.BASE_URL_CONSTRUCTOR_PARAMETER_NAME,
"client_id",
"client_secret",
*extra_oauth_param_names,
self._timeout_constructor_parameter_name,
self._max_retries_constructor_parameter_name,
RootClientGenerator.FOLLOW_REDIRECTS_CONSTRUCTOR_PARAMETER_NAME,
Expand Down Expand Up @@ -687,6 +742,15 @@ def write_default_environment(writer: AST.NodeWriter) -> None:
docs=RootClientGenerator.CLIENT_SECRET_CONSTRUCTOR_PARAMETER_DOCS,
),
)
# Add additional OAuth token endpoint params (scopes, custom properties)
for extra_param_name in self._get_additional_oauth_param_names(oauth):
parameters.append(
RootClientConstructorParameter(
constructor_parameter_name=extra_param_name,
type_hint=AST.TypeHint.optional(AST.TypeHint.str_()),
initializer=AST.Expression("None"),
),
)
# Add the token parameter for direct token authentication
parameters.append(
RootClientConstructorParameter(
Expand Down Expand Up @@ -928,7 +992,14 @@ def _get_constructor_overloads(self, *, is_async: bool) -> Optional[List[AST.Fun
type_hint=AST.TypeHint.str_(),
)

oauth_params = base_params + [client_id_param, client_secret_param]
extra_oauth_params = [
AST.NamedFunctionParameter(
name=extra_name,
type_hint=AST.TypeHint.str_(),
)
for extra_name in self._get_additional_oauth_param_names(oauth)
]
oauth_params = base_params + [client_id_param, client_secret_param] + extra_oauth_params
oauth_signature = AST.FunctionSignature(named_parameters=oauth_params)

# Overload 2: Direct token (token required)
Expand All @@ -943,8 +1014,13 @@ def _get_constructor_overloads(self, *, is_async: bool) -> Optional[List[AST.Fun
return [oauth_signature, token_signature]

def _get_non_oauth_constructor_parameters(self, *, is_async: bool) -> List[AST.NamedFunctionParameter]:
"""Get constructor parameters excluding OAuth-specific ones (client_id, client_secret, token)."""
"""Get constructor parameters excluding OAuth-specific ones (client_id, client_secret, token, extra oauth params)."""
oauth_param_names = {"client_id", "client_secret", self.TOKEN_PARAMETER_NAME, self.TOKEN_GETTER_PARAM_NAME}
if self._oauth_scheme is not None:
oauth_config = self._oauth_scheme.configuration.get_as_union()
if oauth_config.type == "clientCredentials":
for extra_name in self._get_additional_oauth_param_names(oauth_config):
oauth_param_names.add(extra_name)
all_params = self._get_constructor_parameters(is_async=is_async)
return [
AST.NamedFunctionParameter(
Expand Down Expand Up @@ -1044,30 +1120,36 @@ def _write_constructor_body(writer: AST.NodeWriter) -> None:
if is_async
else self._context.core_utilities.get_oauth_token_provider()
)
oauth_tp_kwargs = [
(
"client_id",
AST.Expression("client_id"),
),
(
"client_secret",
AST.Expression("client_secret"),
),
]
if is_oauth_client_credentials and oauth_union is not None:
for extra_param_name in self._get_additional_oauth_param_names(oauth_union):
oauth_tp_kwargs.append(
(extra_param_name, AST.Expression(extra_param_name)),
)
oauth_tp_kwargs.append(
(
"client_wrapper",
AST.Expression(
AST.ClassInstantiation(
class_=self._context.core_utilities.get_reference_to_client_wrapper(is_async=is_async),
kwargs=client_wrapper_constructor_kwargs,
),
),
),
)
writer.write_node(
AST.ClassInstantiation(
class_=oauth_token_provider_class,
kwargs=[
(
"client_id",
AST.Expression("client_id"),
),
(
"client_secret",
AST.Expression("client_secret"),
),
(
"client_wrapper",
AST.Expression(
AST.ClassInstantiation(
class_=self._context.core_utilities.get_reference_to_client_wrapper(
is_async=is_async
),
kwargs=client_wrapper_constructor_kwargs,
),
),
),
],
kwargs=oauth_tp_kwargs,
)
)
writer.write_newline_if_last_line_not()
Expand Down Expand Up @@ -1285,6 +1367,12 @@ def _write_oauth_token_override_constructor_body(
# elif client_id is not None and client_secret is not None:
writer.write_line("elif client_id is not None and client_secret is not None:")
with writer.indent():
# Narrow additional OAuth params from Optional[str] to str
if self._oauth_scheme is not None:
override_oauth_config = self._oauth_scheme.configuration.get_as_union()
if override_oauth_config.type == "clientCredentials":
for extra_name in self._get_additional_oauth_param_names(override_oauth_config):
writer.write_line(f"assert {extra_name} is not None")
# OAuth client credentials mode
oauth_client_wrapper_kwargs = self._get_client_wrapper_kwargs(
client_wrapper_generator=client_wrapper_generator,
Expand All @@ -1302,30 +1390,38 @@ def _write_oauth_token_override_constructor_body(
if is_async
else self._context.core_utilities.get_oauth_token_provider()
)
override_oauth_tp_kwargs = [
(
"client_id",
AST.Expression("client_id"),
),
(
"client_secret",
AST.Expression("client_secret"),
),
]
if self._oauth_scheme is not None:
override_oauth = self._oauth_scheme.configuration.get_as_union()
if override_oauth.type == "clientCredentials":
for extra_param_name in self._get_additional_oauth_param_names(override_oauth):
override_oauth_tp_kwargs.append(
(extra_param_name, AST.Expression(extra_param_name)),
)
override_oauth_tp_kwargs.append(
(
"client_wrapper",
AST.Expression(
AST.ClassInstantiation(
class_=self._context.core_utilities.get_reference_to_client_wrapper(is_async=is_async),
kwargs=oauth_client_wrapper_kwargs,
),
),
),
)
writer.write_node(
AST.ClassInstantiation(
class_=oauth_token_provider_class,
kwargs=[
(
"client_id",
AST.Expression("client_id"),
),
(
"client_secret",
AST.Expression("client_secret"),
),
(
"client_wrapper",
AST.Expression(
AST.ClassInstantiation(
class_=self._context.core_utilities.get_reference_to_client_wrapper(
is_async=is_async
),
kwargs=oauth_client_wrapper_kwargs,
),
),
),
],
kwargs=override_oauth_tp_kwargs,
)
)
writer.write_newline_if_last_line_not()
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading