Skip to content

Commit ebbf9ae

Browse files
committed
Change the logic to generate a Literal type annotation to bypass the caching mechanism of typing module
1 parent 0bbf2e8 commit ebbf9ae

File tree

4 files changed

+115
-28
lines changed

4 files changed

+115
-28
lines changed

src/openapi_test_client/libraries/api/api_functions/utils/param_type.py

+33-25
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,20 @@
1010
from common_libs.logging import get_logger
1111

1212
import openapi_test_client.libraries.api.api_functions.utils.param_model as param_model_util
13-
from openapi_test_client.libraries.api.types import Alias, Constraint, Format, ParamAnnotationType, ParamDef
13+
from openapi_test_client.libraries.api.types import (
14+
Alias,
15+
Constraint,
16+
Format,
17+
ParamAnnotationType,
18+
ParamDef,
19+
UncacheableLiteralArg,
20+
)
21+
from openapi_test_client.libraries.api.types import Optional as Optional_
1422
from openapi_test_client.libraries.common.constants import BACKSLASH
1523
from openapi_test_client.libraries.common.misc import dedup
1624

1725
if TYPE_CHECKING:
18-
from typing import _AnnotatedAlias # type: ignore
26+
from typing import _AnnotatedAlias, _LiteralGenericAlias # type: ignore[attr-defined]
1927

2028

2129
logger = get_logger(__name__)
@@ -161,7 +169,7 @@ def resolve(param_type: str, param_format: str | None = None) -> Any:
161169
)
162170
else:
163171
if enum := param_def.get("enum"):
164-
type_annotation = Literal[*enum]
172+
type_annotation = generate_literal_type(*enum)
165173
elif isinstance(param_def, ParamDef.UnknownType):
166174
logger.warning(
167175
f"Param '{param_name}': Unable to locate a parameter type in the following parameter object. "
@@ -248,11 +256,11 @@ def replace_base_type(tp: Any, new_type: Any, replace_container_type: bool = Fal
248256
args = get_args(tp)
249257
if is_union_type(tp):
250258
if is_optional_type(tp):
251-
return Optional[replace_base_type(args[0], new_type)] # noqa: UP007
259+
return generate_optional_type(replace_base_type(args[0], new_type))
252260
else:
253261
return replace_base_type(args, new_type)
254262
elif origin_type is Annotated:
255-
return Annotated[replace_base_type(tp.__origin__, new_type), *tp.__metadata__]
263+
return annotate_type(replace_base_type(tp.__origin__, new_type), *tp.__metadata__)
256264
elif origin_type in [list, tuple]:
257265
if replace_container_type:
258266
return new_type
@@ -290,7 +298,7 @@ def is_type_of(param_type: str | Any, type_to_check: Any) -> bool:
290298
# Add if needed
291299
raise NotImplementedError
292300
elif origin_type := get_origin(param_type):
293-
if origin_type is type_to_check:
301+
if (origin_type is type_to_check) or (origin_type is Union and type_to_check in [Optional, Optional_]):
294302
return True
295303
elif origin_type is Annotated:
296304
return is_type_of(param_type.__origin__, type_to_check)
@@ -345,23 +353,23 @@ def generate_union_type(type_annotations: Sequence[Any]) -> Any:
345353

346354

347355
def generate_optional_type(tp: Any) -> Any:
348-
"""Convert the type annotation to Optional[tp]
349-
350-
Wrap the type with `Optional[]`, but using `Union` with None instead as there seems to be a cache issue
351-
where `Optional[Literal['val1', 'val2']]` with change the order of Literal parameters due to the cache.
352-
353-
eg. The issue seen in Python 3.11
354-
>>> t1 = Literal["foo", "bar"]
355-
>>> Optional[t1]
356-
typing.Optional[typing.Literal['foo', 'bar']]
357-
>>> t2 = Literal["bar", "foo"]
358-
>>> Optional[t2]
359-
typing.Optional[typing.Literal['foo', 'bar']] <--- HERE
360-
"""
356+
"""Convert the type annotation to Optional[tp]"""
361357
if is_optional_type(tp):
362358
return tp
363359
else:
364-
return Union[tp, None] # noqa: UP007
360+
return Optional[tp] # noqa: UP007
361+
362+
363+
def generate_literal_type(*args: Any, uncacheable: bool = True) -> _LiteralGenericAlias:
364+
"""Generate a Literal type annotation using given args
365+
366+
:param args: Literal args
367+
:param uncacheable: Make this Literal type uncacheable
368+
"""
369+
if uncacheable:
370+
cacheable_args = tuple(arg.obj if isinstance(arg, UncacheableLiteralArg) else arg for arg in args)
371+
args = tuple(UncacheableLiteralArg(arg) for arg in dedup(*cacheable_args))
372+
return Literal[*args]
365373

366374

367375
def annotate_type(tp: Any, *metadata: Any) -> Any:
@@ -378,7 +386,7 @@ def annotate_type(tp: Any, *metadata: Any) -> Any:
378386
return modify_annotated_metadata(tp, *metadata, action="add")
379387
elif is_optional_type(tp):
380388
inner_type = generate_union_type([x for x in get_args(tp) if x is not NoneType])
381-
return Optional[annotate_type(inner_type, *metadata)] # noqa: UP007
389+
return generate_optional_type(annotate_type(inner_type, *metadata))
382390
else:
383391
return Annotated[tp, *metadata]
384392

@@ -411,7 +419,7 @@ def modify_metadata(tp: Any) -> Any:
411419
raise ValueError("At least one metadata must exist after the action is performed")
412420
else:
413421
new_metadata = dedup(*metadata)
414-
return Annotated[get_args(tp)[0], *new_metadata]
422+
return annotate_type(get_args(tp)[0], *new_metadata)
415423
else:
416424
if is_union_type(tp):
417425
return generate_union_type([modify_metadata(arg) for arg in get_args(tp)])
@@ -478,7 +486,7 @@ def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
478486
# stop using set here
479487
combined_args = dedup(*args1, *args2)
480488
if origin is Literal:
481-
return Literal[*combined_args]
489+
return generate_literal_type(*combined_args)
482490
elif origin is Annotated:
483491
# If two Annotated types have different set of ParamAnnotationType objects in metadata, treat them as
484492
# different types as a union type. Otherwise merge them
@@ -492,7 +500,7 @@ def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
492500
) or not (annotation_types1 or annotation_types2):
493501
combined_type = merge_annotation_types(get_args(tp1)[0], get_args(tp2)[0])
494502
combined_metadata = dedup(*tp1.__metadata__, *tp2.__metadata__)
495-
return Annotated[combined_type, *combined_metadata]
503+
return annotate_type(combined_type, *combined_metadata)
496504
else:
497505
return generate_union_type([tp1, tp2])
498506
elif origin is dict:
@@ -509,7 +517,7 @@ def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
509517
elif origin is list:
510518
return list[generate_union_type(merge_args_per_origin(combined_args))]
511519
elif origin in [Union, UnionType]:
512-
return Union[*merge_args_per_origin(combined_args)]
520+
return generate_union_type(merge_args_per_origin(combined_args))
513521

514522
# TODO: Needs improvements to cover more cases
515523
if is_optional_type(tp1):

src/openapi_test_client/libraries/api/types.py

+44
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,47 @@ class Constraint(ParamAnnotationType):
485485
# but Pydantic currently treats them as an integer
486486
exclusive_minimum: int | None = None
487487
exclusive_maximum: int | None = None
488+
489+
490+
class UncacheableLiteralArg:
491+
"""Make args for typing.Literal uncacheable
492+
493+
Due to the default cache mechanism implemented in the typing module, the order of arguments for the generated
494+
Literal type annotation can be unexpected if there's a cache. This behavior causes our dynamic code generation
495+
unstable if API specs define more than one param objects that have the exact same enum values but in different
496+
orders. Wrapping each Literal arg value with this class ensures the cached behavior will not happen during the code
497+
generation.
498+
499+
eg.
500+
1. The default behavior of typing module with cache
501+
>>> from typing import Literal, Optional
502+
>>> t1 = Literal["foo", "bar"]
503+
>>> Optional[t1]
504+
typing.Optional[typing.Literal['foo', 'bar']]
505+
>>> t2 = Literal["bar", "foo"]
506+
>>> Optional[t2]
507+
typing.Optional[typing.Literal['foo', 'bar']] <--- HERE (Unexpected order due to the cached result)
508+
509+
2. Uncached behavior
510+
>>> t1 = Literal[UncacheableLiteralArg("foo"), UncacheableLiteralArg("bar")]
511+
>>> Optional[t1]
512+
typing.Optional[typing.Literal['foo', 'bar']]
513+
>>> t2 = Literal[UncacheableLiteralArg("bar"), UncacheableLiteralArg("foo")]
514+
>>> Optional[t2]
515+
typing.Optional[typing.Literal['bar', 'foo']] <--- HERE (Expected order)
516+
"""
517+
518+
def __init__(self, obj: Any):
519+
self.obj = obj
520+
521+
def __repr__(self) -> str:
522+
return repr(self.obj)
523+
524+
def __eq__(self, other: Any) -> bool:
525+
if isinstance(other, UncacheableLiteralArg):
526+
return self.obj == other.obj
527+
else:
528+
return self.obj == other
529+
530+
def __hash__(self) -> int:
531+
return id(self)

tests/integration/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,5 @@ def temp_app_client(
149149

150150

151151
@pytest.fixture(autouse=True)
152-
def _steram_cmd_output(request: FixtureRequest):
152+
def _steram_cmd_output(request: FixtureRequest) -> None:
153153
os.environ["IS_CAPTURING_OUTPUT"] = str(request.config.option.capture != "no").lower()

tests/unit/test_param_type.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,15 @@
66
import pytest
77

88
import openapi_test_client.libraries.api.api_functions.utils.param_type as param_type_util
9-
from openapi_test_client.libraries.api.types import Alias, Constraint, Format, Optional, ParamModel, Unset
9+
from openapi_test_client.libraries.api.types import (
10+
Alias,
11+
Constraint,
12+
Format,
13+
Optional,
14+
ParamModel,
15+
UncacheableLiteralArg,
16+
Unset,
17+
)
1018

1119

1220
class MyClass: ...
@@ -37,7 +45,9 @@ class MyParamModel(ParamModel):
3745
(dict, "dict"),
3846
(dict[str, Any], "dict[str, Any]"),
3947
(Literal[None], "Literal[None]"),
48+
(Literal[UncacheableLiteralArg(None)], "Literal[None]"),
4049
(Literal["1", "2"], "Literal['1', '2']"),
50+
(Literal[UncacheableLiteralArg("1"), UncacheableLiteralArg("2")], "Literal['1', '2']"),
4151
(MyClass, MyClass.__name__),
4252
(MyParamModel, MyParamModel.__name__),
4353
(ForwardRef(MyParamModel.__name__), MyParamModel.__name__),
@@ -182,6 +192,9 @@ def test_replace_baser_type(tp: Any, replace_with: Any, expected_type: Any) -> N
182192
(list[str], str, False),
183193
(Annotated[list[str], "meta"], str, False),
184194
(Optional[Annotated[list[str], "meta"]], str, False),
195+
(Optional[Annotated[Literal[1], "meta"]], Optional, True),
196+
(Optional[Annotated[Literal[1], "meta"]], Annotated, True),
197+
(Optional[Annotated[Literal[1], "meta"]], Literal, True),
185198
],
186199
)
187200
def test_is_type_of(param_type: Any, type_to_check: Any, is_type_of: bool) -> None:
@@ -293,6 +306,24 @@ def test_generate_optional_type(tp: Any, expected_type: Any) -> None:
293306
assert param_type_util.generate_optional_type(tp) == expected_type
294307

295308

309+
@pytest.mark.parametrize("uncacheable", [True, False])
310+
def test_generate_literal_type(uncacheable: bool) -> None:
311+
"""Verify that a literal type annotation can be generated with/without the typing module's caching mechanism"""
312+
args1 = ["1", "2", "2", "3"]
313+
args2 = ["3", "2", "1", "2"]
314+
tp1 = param_type_util.generate_literal_type(*args1, uncacheable=uncacheable)
315+
tp2 = param_type_util.generate_literal_type(*args2, uncacheable=uncacheable)
316+
if uncacheable:
317+
assert tp1 != tp2
318+
assert repr(Optional[tp1]) == "typing.Optional[typing.Literal['1', '2', '3']]"
319+
assert repr(Optional[tp2]) == "typing.Optional[typing.Literal['3', '2', '1']]"
320+
else:
321+
# NOTE: This is the default behavior of typing.Literal
322+
assert Literal[*args1] == tp1 == tp2
323+
assert repr(Optional[tp1]) == "typing.Optional[typing.Literal['1', '2', '3']]"
324+
assert repr(Optional[tp2]) == "typing.Optional[typing.Literal['1', '2', '3']]"
325+
326+
296327
@pytest.mark.parametrize(
297328
("tp", "metadata", "expected_type"),
298329
[
@@ -451,7 +482,11 @@ def test_get_annotated_type(tp: Any, annotated_type: Any) -> None:
451482
)
452483
def test_merge_annotation_types(tp1: Any, tp2: Any, expected_type: Any) -> None:
453484
"""Verify that two annotation types acn be merged"""
454-
assert param_type_util.merge_annotation_types(tp1, tp2) == expected_type
485+
if get_origin(tp1) is Literal or get_origin(tp2) is Literal:
486+
assert param_type_util.merge_annotation_types(tp1, tp2) != expected_type
487+
assert repr(param_type_util.merge_annotation_types(tp1, tp2)) == repr(expected_type)
488+
else:
489+
assert param_type_util.merge_annotation_types(tp1, tp2) == expected_type
455490

456491

457492
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)