Skip to content

Commit 513f65a

Browse files
committed
Update custom or_ to support union of a param model and its forward ref
Add 'self' to the illegal param name list
1 parent 38593ce commit 513f65a

File tree

4 files changed

+31
-6
lines changed

4 files changed

+31
-6
lines changed

src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
7575
# Some OpenAPI specs define a parameter name using characters we can't use as a python variable name.
7676
# We will use the cleaned name as the model field and annotate it as `Annotated[field_type, Alias(<original_val>)]`
7777
# When calling an endpoint function, the actual name will be automatically resolved in the payload/query parameters
78-
param_model_util.alias_illegal_model_field_names(path_param_fields)
79-
param_model_util.alias_illegal_model_field_names(body_or_query_param_fields)
78+
param_model_util.alias_illegal_model_field_names(model_name, path_param_fields)
79+
param_model_util.alias_illegal_model_field_names(model_name, body_or_query_param_fields)
8080

8181
fields = path_param_fields + body_or_query_param_fields
8282
return cast(

src/openapi_test_client/libraries/api/api_functions/utils/param_model.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ def get_param_model(annotated_type: Any) -> type[ParamModel] | ForwardRef | list
6969
return base_type
7070

7171

72+
def get_param_model_name(param_model: type[ParamModel] | ForwardRef) -> str:
73+
"""Get the model name
74+
75+
:param param_model: Param model. This can be a forward ref
76+
"""
77+
if not is_param_model(param_model):
78+
raise ValueError(f"{param_model} is not a param model")
79+
80+
if isinstance(param_model, ForwardRef):
81+
return param_model.__forward_arg__
82+
else:
83+
return param_model.__name__
84+
85+
7286
@lru_cache
7387
def generate_model_name(field_name: str, field_type: str | Any) -> str:
7488
"""Generate model name from the given field
@@ -130,7 +144,7 @@ def create_model_from_param_def(
130144
)
131145
for inner_param_name, inner_param_obj in param_def.get("properties", {}).items()
132146
]
133-
alias_illegal_model_field_names(fields)
147+
alias_illegal_model_field_names(model_name, fields)
134148
return cast(
135149
type[ParamModel],
136150
make_dataclass(
@@ -299,9 +313,10 @@ def visit(model_name: str) -> None:
299313
return sorted(models, key=lambda x: sorted_models_names.index(x.__name__))
300314

301315

302-
def alias_illegal_model_field_names(model_fields: list[DataclassModelField]) -> str:
316+
def alias_illegal_model_field_names(model_name: str, model_fields: list[DataclassModelField]) -> None:
303317
"""Clean illegal model field name and annotate the field type with Alias class
304318
319+
:param model_name: Model name
305320
:param model_fields: fields value to be passed to make_dataclass()
306321
"""
307322

@@ -315,7 +330,7 @@ def make_alias(name: str, param_type: Any) -> str:
315330
else:
316331
name = clean_obj_name(name)
317332
# NOTE: The escaping of kwargs is already is handled in endpoint model
318-
reserved_param_names = [*get_supported_request_parameters(), "validate"]
333+
reserved_param_names = ["self", "validate", *get_supported_request_parameters()]
319334
if name in get_reserved_model_names() + reserved_param_names:
320335
# The field name conflicts with one of reserved names
321336
name += "_"

src/openapi_test_client/libraries/api/api_functions/utils/param_type.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,11 @@ def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
463463
for args_ in args_per_origin.values()
464464
)
465465

466+
if isinstance(tp1, str):
467+
tp1 = ForwardRef(tp1)
468+
if isinstance(tp2, str):
469+
tp2 = ForwardRef(tp2)
470+
466471
origin = get_origin(tp1)
467472
origin2 = get_origin(tp2)
468473
if origin or origin2:
@@ -528,7 +533,11 @@ def or_(x: Any, y: Any) -> Any:
528533
>>> reduce(or_, [Model1 | None, Model2])
529534
__main__.MyModel | None
530535
"""
531-
if param_model_util.is_param_model(x) and param_model_util.is_param_model(y) and x.__name__ == y.__name__:
536+
if (
537+
param_model_util.is_param_model(x)
538+
and param_model_util.is_param_model(y)
539+
and param_model_util.get_param_model_name(x) == param_model_util.get_param_model_name(y)
540+
):
532541
return x
533542
else:
534543
is_x_union = is_union_type(x)

tests/unit/test_param_type.py

+1
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ def test_merge_annotation_types(tp1: Any, tp2: Any, expected_type: Any) -> None:
468468
(MyParamModel, MyParamModel2 | None, MyParamModel | None),
469469
(MyParamModel2 | None, MyParamModel3, MyParamModel2 | None),
470470
(MyParamModel2, MyParamModel3 | None, MyParamModel2 | None),
471+
(MyParamModel, ForwardRef(MyParamModel.__name__), MyParamModel),
471472
],
472473
)
473474
def test_custom_or_(tp1: Any, tp2: Any, expected_type: Any) -> None:

0 commit comments

Comments
 (0)