10
10
from common_libs .logging import get_logger
11
11
12
12
import openapi_test_client .libraries .api .api_functions .utils .param_model as param_model_util
13
- from openapi_test_client .libraries .api .types import Alias , Constraint , Format , ParamAnnotationType , ParamDef
13
+ from openapi_test_client .libraries .api .types import (
14
+ Alias ,
15
+ Constraint ,
16
+ Format ,
17
+ ParamAnnotationType ,
18
+ ParamDef ,
19
+ UncacheableLiteralArg ,
20
+ )
21
+ from openapi_test_client .libraries .api .types import Optional as Optional_
14
22
from openapi_test_client .libraries .common .constants import BACKSLASH
15
23
from openapi_test_client .libraries .common .misc import dedup
16
24
17
25
if TYPE_CHECKING :
18
- from typing import _AnnotatedAlias # type: ignore
26
+ from typing import _AnnotatedAlias , _LiteralGenericAlias # type: ignore[attr-defined]
19
27
20
28
21
29
logger = get_logger (__name__ )
@@ -161,7 +169,7 @@ def resolve(param_type: str, param_format: str | None = None) -> Any:
161
169
)
162
170
else :
163
171
if enum := param_def .get ("enum" ):
164
- type_annotation = Literal [ * enum ]
172
+ type_annotation = generate_literal_type ( * enum )
165
173
elif isinstance (param_def , ParamDef .UnknownType ):
166
174
logger .warning (
167
175
f"Param '{ param_name } ': Unable to locate a parameter type in the following parameter object. "
@@ -248,11 +256,11 @@ def replace_base_type(tp: Any, new_type: Any, replace_container_type: bool = Fal
248
256
args = get_args (tp )
249
257
if is_union_type (tp ):
250
258
if is_optional_type (tp ):
251
- return Optional [ replace_base_type (args [0 ], new_type )] # noqa: UP007
259
+ return generate_optional_type ( replace_base_type (args [0 ], new_type ))
252
260
else :
253
261
return replace_base_type (args , new_type )
254
262
elif origin_type is Annotated :
255
- return Annotated [ replace_base_type (tp .__origin__ , new_type ), * tp .__metadata__ ]
263
+ return annotate_type ( replace_base_type (tp .__origin__ , new_type ), * tp .__metadata__ )
256
264
elif origin_type in [list , tuple ]:
257
265
if replace_container_type :
258
266
return new_type
@@ -290,7 +298,7 @@ def is_type_of(param_type: str | Any, type_to_check: Any) -> bool:
290
298
# Add if needed
291
299
raise NotImplementedError
292
300
elif origin_type := get_origin (param_type ):
293
- if origin_type is type_to_check :
301
+ if ( origin_type is type_to_check ) or ( origin_type is Union and type_to_check in [ Optional , Optional_ ]) :
294
302
return True
295
303
elif origin_type is Annotated :
296
304
return is_type_of (param_type .__origin__ , type_to_check )
@@ -345,23 +353,23 @@ def generate_union_type(type_annotations: Sequence[Any]) -> Any:
345
353
346
354
347
355
def generate_optional_type (tp : Any ) -> Any :
348
- """Convert the type annotation to Optional[tp]
349
-
350
- Wrap the type with `Optional[]`, but using `Union` with None instead as there seems to be a cache issue
351
- where `Optional[Literal['val1', 'val2']]` with change the order of Literal parameters due to the cache.
352
-
353
- eg. The issue seen in Python 3.11
354
- >>> t1 = Literal["foo", "bar"]
355
- >>> Optional[t1]
356
- typing.Optional[typing.Literal['foo', 'bar']]
357
- >>> t2 = Literal["bar", "foo"]
358
- >>> Optional[t2]
359
- typing.Optional[typing.Literal['foo', 'bar']] <--- HERE
360
- """
356
+ """Convert the type annotation to Optional[tp]"""
361
357
if is_optional_type (tp ):
362
358
return tp
363
359
else :
364
- return Union [tp , None ] # noqa: UP007
360
+ return Optional [tp ] # noqa: UP007
361
+
362
+
363
+ def generate_literal_type (* args : Any , uncacheable : bool = True ) -> _LiteralGenericAlias :
364
+ """Generate a Literal type annotation using given args
365
+
366
+ :param args: Literal args
367
+ :param uncacheable: Make this Literal type uncacheable
368
+ """
369
+ if uncacheable :
370
+ cacheable_args = tuple (arg .obj if isinstance (arg , UncacheableLiteralArg ) else arg for arg in args )
371
+ args = tuple (UncacheableLiteralArg (arg ) for arg in dedup (* cacheable_args ))
372
+ return Literal [* args ]
365
373
366
374
367
375
def annotate_type (tp : Any , * metadata : Any ) -> Any :
@@ -378,7 +386,7 @@ def annotate_type(tp: Any, *metadata: Any) -> Any:
378
386
return modify_annotated_metadata (tp , * metadata , action = "add" )
379
387
elif is_optional_type (tp ):
380
388
inner_type = generate_union_type ([x for x in get_args (tp ) if x is not NoneType ])
381
- return Optional [ annotate_type (inner_type , * metadata )] # noqa: UP007
389
+ return generate_optional_type ( annotate_type (inner_type , * metadata ))
382
390
else :
383
391
return Annotated [tp , * metadata ]
384
392
@@ -411,7 +419,7 @@ def modify_metadata(tp: Any) -> Any:
411
419
raise ValueError ("At least one metadata must exist after the action is performed" )
412
420
else :
413
421
new_metadata = dedup (* metadata )
414
- return Annotated [ get_args (tp )[0 ], * new_metadata ]
422
+ return annotate_type ( get_args (tp )[0 ], * new_metadata )
415
423
else :
416
424
if is_union_type (tp ):
417
425
return generate_union_type ([modify_metadata (arg ) for arg in get_args (tp )])
@@ -478,7 +486,7 @@ def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
478
486
# stop using set here
479
487
combined_args = dedup (* args1 , * args2 )
480
488
if origin is Literal :
481
- return Literal [ * combined_args ]
489
+ return generate_literal_type ( * combined_args )
482
490
elif origin is Annotated :
483
491
# If two Annotated types have different set of ParamAnnotationType objects in metadata, treat them as
484
492
# different types as a union type. Otherwise merge them
@@ -492,7 +500,7 @@ def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
492
500
) or not (annotation_types1 or annotation_types2 ):
493
501
combined_type = merge_annotation_types (get_args (tp1 )[0 ], get_args (tp2 )[0 ])
494
502
combined_metadata = dedup (* tp1 .__metadata__ , * tp2 .__metadata__ )
495
- return Annotated [ combined_type , * combined_metadata ]
503
+ return annotate_type ( combined_type , * combined_metadata )
496
504
else :
497
505
return generate_union_type ([tp1 , tp2 ])
498
506
elif origin is dict :
@@ -509,7 +517,7 @@ def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
509
517
elif origin is list :
510
518
return list [generate_union_type (merge_args_per_origin (combined_args ))]
511
519
elif origin in [Union , UnionType ]:
512
- return Union [ * merge_args_per_origin (combined_args )]
520
+ return generate_union_type ( merge_args_per_origin (combined_args ))
513
521
514
522
# TODO: Needs improvements to cover more cases
515
523
if is_optional_type (tp1 ):
0 commit comments