Skip to content

Commit ac101ec

Browse files
committed
Add minor adjustments in param model/type utils
1 parent 51f9582 commit ac101ec

File tree

4 files changed

+57
-44
lines changed

4 files changed

+57
-44
lines changed

src/openapi_test_client/libraries/api/api_functions/utils/param_model.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,24 @@ def has_param_model(annotated_type: Any) -> bool:
4848
:param annotated_type: Annotated type for a field to check whether it contains a param model or not
4949
"""
5050

51-
inner_type = param_type_util.get_inner_type(annotated_type)
52-
if param_type_util.is_union_type(inner_type):
53-
return any(is_param_model(o) for o in get_args(inner_type))
51+
base_type = param_type_util.get_base_type(annotated_type)
52+
if param_type_util.is_union_type(base_type):
53+
return any(is_param_model(o) for o in get_args(base_type))
5454
else:
55-
return is_param_model(inner_type)
55+
return is_param_model(base_type)
5656

5757

5858
def get_param_model(annotated_type: Any) -> ParamModel | list[ParamModel] | None:
5959
"""Returns a param model from the annotated type, if there is any
6060
6161
:param annotated_type: Annotated type
6262
"""
63-
inner_type = param_type_util.get_inner_type(annotated_type)
64-
if has_param_model(inner_type):
65-
if param_type_util.is_union_type(inner_type):
66-
return [x for x in get_args(inner_type) if has_param_model(x)]
63+
base_type = param_type_util.get_base_type(annotated_type)
64+
if has_param_model(base_type):
65+
if param_type_util.is_union_type(base_type):
66+
return [x for x in get_args(base_type) if has_param_model(x)]
6767
else:
68-
return inner_type
68+
return base_type
6969

7070

7171
@lru_cache

src/openapi_test_client/libraries/api/api_functions/utils/param_type.py

+29-30
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import openapi_test_client.libraries.api.api_functions.utils.param_model as param_model_util
2020
from openapi_test_client.libraries.api.types import Alias, Constraint, Format, ParamAnnotationType, ParamDef
2121
from openapi_test_client.libraries.common.constants import BACKSLASH
22+
from openapi_test_client.libraries.common.misc import dedup
2223

2324
logger = get_logger(__name__)
2425

@@ -37,6 +38,8 @@ def get_type_annotation_as_str(tp: Any) -> str:
3738
"""
3839
if isinstance(tp, str):
3940
return repr(tp)
41+
elif tp is Ellipsis:
42+
return "..."
4043
elif get_origin(tp) is Annotated:
4144
orig_type = get_type_annotation_as_str(tp.__origin__)
4245
metadata_types = ", ".join(get_type_annotation_as_str(m) for m in tp.__metadata__)
@@ -68,7 +71,7 @@ def get_type_annotation_as_str(tp: Any) -> str:
6871
if v is not None
6972
)
7073
return f"{type(tp).__name__}({const})"
71-
elif tp is NoneType:
74+
elif tp in [NoneType, None]:
7275
return "None"
7376
else:
7477
if inspect.isclass(tp):
@@ -192,8 +195,8 @@ def resolve(param_type: str, param_format: str | None = None):
192195
return type_annotation
193196

194197

195-
def get_inner_type(tp: Any, return_if_container_type: bool = False) -> Any | list[Any]:
196-
"""Get the inner type (=actual type) from the type annotation
198+
def get_base_type(tp: Any, return_if_container_type: bool = False) -> Any | list[Any]:
199+
"""Get the base type from the type annotation
197200
198201
eg:
199202
Optional[str] -> str
@@ -214,26 +217,26 @@ def get_inner_type(tp: Any, return_if_container_type: bool = False) -> Any | lis
214217

215218
if is_union_type(tp):
216219
args_without_nonetype = [x for x in get_args(tp) if x is not NoneType]
217-
return generate_union_type([get_inner_type(x) for x in args_without_nonetype])
220+
return generate_union_type([get_base_type(x) for x in args_without_nonetype])
218221
elif origin_type is Annotated:
219-
return get_inner_type(tp.__origin__)
222+
return get_base_type(tp.__origin__)
220223
elif origin_type is list:
221224
if return_if_container_type:
222225
return tp
223226
else:
224-
return get_inner_type(get_args(tp)[0])
227+
return get_base_type(get_args(tp)[0])
225228
return tp
226229

227230

228-
def replace_inner_type(tp: Any, new_type: Any, replace_container_type: bool = False) -> Any:
229-
"""Replace an inner type of in the type annotation to something else
231+
def replace_base_type(tp: Any, new_type: Any, replace_container_type: bool = False) -> Any:
232+
"""Replace the base type of the type annotation to something else
230233
231234
:param tp: The original type annotation
232-
:param new_type: A new type to replace the inner type with
233-
:param replace_container_type: Treat container types like list and tuple as an inner type
235+
:param new_type: A new type to replace the base type with
236+
:param replace_container_type: Treat container types like list and tuple as an base type
234237
235238
>>> tp = Optional[Annotated[int, "metadata"]]
236-
>>> new_tp = replace_inner_type(tp, str)
239+
>>> new_tp = replace_base_type(tp, str)
237240
>>> print(new_tp)
238241
typing.Optional[typing.Annotated[str, 'metadata']]
239242
"""
@@ -242,16 +245,16 @@ def replace_inner_type(tp: Any, new_type: Any, replace_container_type: bool = Fa
242245
args = get_args(tp)
243246
if is_union_type(tp):
244247
if is_optional_type(tp):
245-
return Optional[replace_inner_type(args[0], new_type)] # noqa: UP007
248+
return Optional[replace_base_type(args[0], new_type)] # noqa: UP007
246249
else:
247-
return replace_inner_type(args, new_type)
250+
return replace_base_type(args, new_type)
248251
elif origin_type is Annotated:
249-
return Annotated[replace_inner_type(tp.__origin__, new_type), *tp.__metadata__]
252+
return Annotated[replace_base_type(tp.__origin__, new_type), *tp.__metadata__]
250253
elif origin_type in [list, tuple]:
251254
if replace_container_type:
252255
return new_type
253256
else:
254-
return origin_type[replace_inner_type(args, new_type)]
257+
return origin_type[replace_base_type(args, new_type)]
255258
else:
256259
return new_type
257260
else:
@@ -358,16 +361,22 @@ def generate_optional_type(tp: Any) -> Any:
358361
return Union[tp, None] # noqa: UP007
359362

360363

361-
def generate_annotated_type(tp: Any, metadata: Any):
362-
"""Add `Annotated` type to the type annotation with the metadata
364+
def generate_annotated_type(tp: Any, *metadata: Any):
365+
"""Add `Annotated` type to the type annotation with the metadata.
366+
367+
If the given type is already annotated, the specified metadata will be merged into the existing annotated type
363368
364369
:param tp: Type annotation
370+
:param metadata: Metadata to add to Annotated[]
365371
"""
366-
if is_optional_type(tp):
372+
if get_origin(tp) is Annotated:
373+
# merge metadata
374+
return generate_annotated_type(get_args(tp)[0], *(dedup(*tp.__metadata__, *metadata)))
375+
elif is_optional_type(tp):
367376
inner_type = get_args(tp)[0]
368-
return Optional[Annotated[inner_type, metadata]] # noqa: UP007
377+
return Optional[generate_annotated_type(inner_type, *metadata)] # noqa: UP007
369378
else:
370-
return Annotated[tp, metadata]
379+
return Annotated[tp, *metadata]
371380

372381

373382
def get_annotated_type(tp: Any) -> _AnnotatedAlias | None:
@@ -392,16 +401,6 @@ def merge_annotation_types(tp1: Any, tp2: Any) -> Any:
392401
Note: This is still experimental
393402
"""
394403

395-
def dedup(*args: Any) -> tuple[Any, ...]:
396-
"""Deduplicate items by retaining the order"""
397-
seen = set()
398-
deduped_args = []
399-
for arg in args:
400-
if arg not in seen:
401-
deduped_args.append(arg)
402-
seen.add(arg)
403-
return tuple(deduped_args)
404-
405404
def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
406405
"""Merge type annotations per its origiin type"""
407406
origin_type_order = {Literal: 1, Annotated: 2, Union: 3, UnionType: 4, list: 5, dict: 6, None: 10}

src/openapi_test_client/libraries/api/api_functions/utils/pydantic_model.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def generate_pydantic_model_fields(
9393
pydantic_model_type = param_type_util.generate_union_type(models)
9494
else:
9595
pydantic_model_type = param_model.to_pydantic()
96-
field_type = param_type_util.replace_inner_type(field_type, pydantic_model_type)
96+
field_type = param_type_util.replace_base_type(field_type, pydantic_model_type)
9797

9898
# Adjust field type and value
9999
if param_type_util.is_optional_type(field_type):
@@ -148,9 +148,9 @@ def generate_pydantic_model_fields(
148148
if is_query_param or (
149149
issubclass(original_model, EndpointModel) and original_model.endpoint_func.method.upper() == "GET"
150150
):
151-
inner_type = param_type_util.get_inner_type(field_type)
152-
if get_origin(inner_type) is not list:
153-
field_type = param_type_util.replace_inner_type(field_type, inner_type | list[inner_type])
151+
base_type = param_type_util.get_base_type(field_type)
152+
if get_origin(base_type) is not list:
153+
field_type = param_type_util.replace_base_type(field_type, base_type | list[base_type])
154154

155155
return (field_type, field_value)
156156

@@ -175,6 +175,6 @@ def convert_type_from_param_format(field_type: Any, format: str) -> Any:
175175
:param format: OpenAPI parameter format
176176
"""
177177
if pydantic_type := PARAM_FORMAT_AND_TYPE_MAP.get(format):
178-
return param_type_util.replace_inner_type(field_type, pydantic_type)
178+
return param_type_util.replace_base_type(field_type, pydantic_type)
179179
else:
180180
return field_type

src/openapi_test_client/libraries/common/misc.py

+14
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,17 @@ def reload_recursively(file_or_dir: Path):
115115
reload(file_or_dir)
116116

117117
reload_recursively(root_dir)
118+
119+
120+
def dedup(*objects: Any) -> tuple[Any, ...]:
121+
"""Deduplicate objects by retaining the order
122+
123+
:param objects: Objects to perform deduplication with
124+
"""
125+
seen = set()
126+
deduped = []
127+
for obj in objects:
128+
if obj not in seen:
129+
deduped.append(obj)
130+
seen.add(obj)
131+
return tuple(deduped)

0 commit comments

Comments
 (0)