Skip to content

Commit 4afec88

Browse files
committed
Add support of merging model field types as part of model merge
1 parent 12dae41 commit 4afec88

File tree

2 files changed

+150
-26
lines changed

2 files changed

+150
-26
lines changed

src/openapi_test_client/libraries/api/api_functions/utils/param_model.py

+38-15
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import defaultdict
55
from dataclasses import Field, field, make_dataclass
66
from functools import lru_cache
7-
from types import NoneType, UnionType
7+
from types import MappingProxyType, NoneType, UnionType
88
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union, cast, get_args, get_origin
99

1010
import inflect
@@ -35,20 +35,24 @@
3535
logger = get_logger(__name__)
3636

3737

38+
def is_param_model(annotated_type: Any) -> bool:
39+
"""Check if the givin annotated type is a custom parammodel
40+
:param annotated_type: Annotated type to check whether it is a param model or not
41+
"""
42+
return inspect.isclass(annotated_type) and issubclass(annotated_type, ParamModel)
43+
44+
3845
def has_param_model(annotated_type: Any) -> bool:
3946
"""Check if the given annotated type contains a custom param model
4047
4148
:param annotated_type: Annotated type for a field to check whether it contains a param model or not
4249
"""
4350

44-
def _is_param_model(obj: Any) -> bool:
45-
return inspect.isclass(obj) and issubclass(obj, ParamModel)
46-
4751
inner_type = param_type_util.get_inner_type(annotated_type)
4852
if param_type_util.is_union_type(inner_type):
49-
return any(_is_param_model(o) for o in get_args(inner_type))
53+
return any(is_param_model(o) for o in get_args(inner_type))
5054
else:
51-
return _is_param_model(inner_type)
55+
return is_param_model(inner_type)
5256

5357

5458
def get_param_model(annotated_type: Any) -> ParamModel | list[ParamModel] | None:
@@ -360,29 +364,48 @@ def _merge_models(models: list[type[ParamModel]]) -> type[ParamModel]:
360364
361365
- Model1
362366
@dataclass
363-
class MyModel:
367+
class MyModel(ParamModel):
364368
param_1: str = Unset
365369
param_2: int = Unset
370+
param_3: Literal["1", "2"] = Unset
366371
367372
- Model2
368373
@dataclass
369-
class MyModel:
374+
class MyModel(ParamModel):
370375
param_1: str = Unset
371-
param_3: int = Unset
376+
param_2: str = Unset
377+
param_3: Literal["2", "3"] = Unset
378+
param_4: int = Unset
372379
373380
- Merged model
374381
@dataclass
375-
class MyModel:
382+
class MyModel(ParamModel):
376383
param_1: str = Unset
377-
param_2: int = Unset
378-
param_3: int = Unset
379-
384+
param_2: int | str = Unset
385+
param_3: Literal["1", "2", "3"] = Unset
386+
param_4: int = Unset
380387
"""
381388
assert models
382389
assert len(set(m.__name__ for m in models)) == 1
383-
merged_dataclass_fields = {}
390+
merged_dataclass_fields: dict[str, Field] = {}
384391
for model in models:
385-
merged_dataclass_fields.update(model.__dataclass_fields__)
392+
for field_name, field_obj in model.__dataclass_fields__.items():
393+
if field_name in merged_dataclass_fields:
394+
merged_field_obj = merged_dataclass_fields[field_name]
395+
if merged_field_obj.type != field_obj.type:
396+
# merge field types and metadata
397+
merged_field_obj.type = param_type_util.merge_annotation_types(
398+
merged_field_obj.type, field_obj.type
399+
)
400+
if "anyOf" in merged_field_obj.metadata:
401+
merged_field_obj.metadata["anyOf"].append(field_obj.metadata)
402+
else:
403+
merged_field_obj.metadata = MappingProxyType(
404+
{"anyOf": [merged_field_obj.metadata, field_obj.metadata]}
405+
)
406+
else:
407+
merged_dataclass_fields[field_name] = field_obj
408+
386409
new_fields = [
387410
(field_name, field_obj.type, field(default=Unset, metadata=field_obj.metadata))
388411
for field_name, field_obj in merged_dataclass_fields.items()

src/openapi_test_client/libraries/api/api_functions/utils/param_type.py

+112-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from collections.abc import Sequence
33
from dataclasses import asdict
44
from functools import reduce
5-
from operator import or_
65
from types import NoneType, UnionType
76
from typing import (
87
Annotated,
@@ -18,7 +17,7 @@
1817
from common_libs.logging import get_logger
1918

2019
import openapi_test_client.libraries.api.api_functions.utils.param_model as param_model_util
21-
from openapi_test_client.libraries.api.types import Alias, Constraint, Format, ParamDef
20+
from openapi_test_client.libraries.api.types import Alias, Constraint, Format, ParamAnnotationType, ParamDef
2221
from openapi_test_client.libraries.common.constants import BACKSLASH
2322

2423
logger = get_logger(__name__)
@@ -291,8 +290,7 @@ def is_type_of(tp: str | Any, type_to_check: Any) -> bool:
291290
return is_type_of(tp.__origin__, type_to_check)
292291
elif is_union_type(tp):
293292
return any([is_type_of(x, type_to_check) for x in get_args(tp)])
294-
else:
295-
return tp is type_to_check
293+
return tp is type_to_check
296294

297295

298296
def is_optional_type(tp: Any) -> bool:
@@ -332,17 +330,12 @@ def is_deprecated_param(tp: Any) -> bool:
332330
return get_origin(tp) is Annotated and "deprecated" in tp.__metadata__
333331

334332

335-
def generate_union_type(type_annotations: Sequence[Any]) -> Any:
333+
def generate_union_type(type_annotations: Sequence[Any]) -> UnionType:
336334
"""Convert multiple annotations to a Union type
337335
338336
:param type_annotations: type annotations
339337
"""
340-
if len(set(repr(x) for x in type_annotations)) == 1:
341-
# All annotations are the same
342-
type_annotation = type_annotations[0]
343-
else:
344-
type_annotation = reduce(or_, type_annotations)
345-
return type_annotation
338+
return reduce(or_, type_annotations)
346339

347340

348341
def generate_optional_type(tp: Any) -> Any:
@@ -388,3 +381,111 @@ def get_annotated_type(tp: Any) -> _AnnotatedAlias | None:
388381
else:
389382
if get_origin(tp) is Annotated:
390383
return tp
384+
385+
386+
def merge_annotation_types(tp1: Any, tp2: Any) -> Any:
387+
"""Merge type annotations
388+
389+
:param tp1: annotated type1
390+
:param tp2: annotated type2
391+
392+
Note: This is still experimental
393+
"""
394+
395+
def dedup(*args: Any) -> tuple[Any, ...]:
396+
"""Deduplicate items by retaining the order"""
397+
seen = set()
398+
deduped_args = []
399+
for arg in args:
400+
if arg not in seen:
401+
deduped_args.append(arg)
402+
seen.add(arg)
403+
return tuple(deduped_args)
404+
405+
def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]:
406+
"""Merge type annotations per its origiin type"""
407+
origin_type_order = {Literal: 1, Annotated: 2, Union: 3, UnionType: 4, list: 5, dict: 6, None: 10}
408+
args_per_origin = {}
409+
for arg in args:
410+
args_per_origin.setdefault(get_origin(arg), []).append(arg)
411+
return tuple(
412+
reduce(
413+
merge_annotation_types,
414+
sorted(args_, key=lambda x: origin_type_order.get(get_origin(x), 99)),
415+
)
416+
for args_ in args_per_origin.values()
417+
)
418+
419+
origin = get_origin(tp1)
420+
origin2 = get_origin(tp2)
421+
if origin or origin2:
422+
if origin == origin2:
423+
args1 = get_args(tp1)
424+
args2 = get_args(tp2)
425+
# stop using set here
426+
combined_args = dedup(*args1, *args2)
427+
if origin is Literal:
428+
return Literal[*combined_args]
429+
elif origin is Annotated:
430+
# If two Annotated types have different set of ParamAnnotationType objects in metadata, treat them as
431+
# different types as a union type. Otherwise merge them
432+
# TODO: revisit this part
433+
annotation_types1 = [x for x in tp1.__metadata__ if isinstance(x, ParamAnnotationType)]
434+
annotation_types2 = [x for x in tp2.__metadata__ if isinstance(x, ParamAnnotationType)]
435+
if (
436+
annotation_types1
437+
and annotation_types1
438+
and ([asdict(x) for x in annotation_types1] == [asdict(y) for y in annotation_types2])
439+
) or not (annotation_types1 or annotation_types2):
440+
combined_type = merge_annotation_types(get_args(tp1)[0], get_args(tp2)[0])
441+
combined_metadata = dedup(*tp1.__metadata__, *tp2.__metadata__)
442+
return Annotated[combined_type, *combined_metadata]
443+
else:
444+
return generate_union_type([tp1, tp2])
445+
elif origin is dict:
446+
key_type, val_type = args1
447+
key_type2, val_type2 = args2
448+
if key_type == key_type2:
449+
if val_type == val_type2:
450+
return dict[key_type, val_type]
451+
else:
452+
return dict[key_type, merge_annotation_types(val_type, val_type2)]
453+
else:
454+
if val_type == val_type2:
455+
return dict[generate_union_type((key_type, key_type2)), val_type]
456+
elif origin is list:
457+
return list[generate_union_type(merge_args_per_origin(combined_args))]
458+
elif origin in [Union, UnionType]:
459+
return Union[*merge_args_per_origin(combined_args)]
460+
return generate_union_type((tp1, tp2))
461+
462+
463+
def or_(x: Any, y: Any) -> Any:
464+
"""Customized version of operator.or_ that treats our dynamically created param model classes with the same
465+
name as duplicates
466+
467+
eg. operator.or_ v.s our or_
468+
>>> import operator
469+
>>> from openapi_test_client.libraries.api.types import ParamModel
470+
>>> Model1 = type("MyModel", (ParamModel,), {})
471+
>>> Model2 = type("MyModel", (ParamModel,), {})
472+
>>> reduce(operator.or_, [Model1 | None, Model2])
473+
__main__.MyModel | None | __main__.MyModel
474+
>>> reduce(or_, [Model1 | None, Model2])
475+
__main__.MyModel | None
476+
"""
477+
if param_model_util.is_param_model(x) and param_model_util.is_param_model(y) and x.__name__ == y.__name__:
478+
return x
479+
else:
480+
is_x_union = is_union_type(x)
481+
is_y_union = is_union_type(y)
482+
if is_x_union:
483+
if is_y_union:
484+
return reduce(or_, (*get_args(x), *get_args(y)))
485+
elif param_model_util.is_param_model(y):
486+
param_model_names_in_x = [x.__name__ for x in get_args(x) if param_model_util.is_param_model(x)]
487+
if y.__name__ in param_model_names_in_x:
488+
return x
489+
elif is_y_union:
490+
return reduce(or_, (x, *get_args(y)))
491+
return x | y

0 commit comments

Comments
 (0)