|
2 | 2 | from collections.abc import Sequence
|
3 | 3 | from dataclasses import asdict
|
4 | 4 | from functools import reduce
|
5 |
| -from operator import or_ |
6 | 5 | from types import NoneType, UnionType
|
7 | 6 | from typing import (
|
8 | 7 | Annotated,
|
|
18 | 17 | from common_libs.logging import get_logger
|
19 | 18 |
|
20 | 19 | import openapi_test_client.libraries.api.api_functions.utils.param_model as param_model_util
|
21 |
| -from openapi_test_client.libraries.api.types import Alias, Constraint, Format, ParamDef |
| 20 | +from openapi_test_client.libraries.api.types import Alias, Constraint, Format, ParamAnnotationType, ParamDef |
22 | 21 | from openapi_test_client.libraries.common.constants import BACKSLASH
|
23 | 22 |
|
24 | 23 | logger = get_logger(__name__)
|
@@ -291,8 +290,7 @@ def is_type_of(tp: str | Any, type_to_check: Any) -> bool:
|
291 | 290 | return is_type_of(tp.__origin__, type_to_check)
|
292 | 291 | elif is_union_type(tp):
|
293 | 292 | return any([is_type_of(x, type_to_check) for x in get_args(tp)])
|
294 |
| - else: |
295 |
| - return tp is type_to_check |
| 293 | + return tp is type_to_check |
296 | 294 |
|
297 | 295 |
|
298 | 296 | def is_optional_type(tp: Any) -> bool:
|
@@ -332,17 +330,12 @@ def is_deprecated_param(tp: Any) -> bool:
|
332 | 330 | return get_origin(tp) is Annotated and "deprecated" in tp.__metadata__
|
333 | 331 |
|
334 | 332 |
|
335 |
| -def generate_union_type(type_annotations: Sequence[Any]) -> Any: |
| 333 | +def generate_union_type(type_annotations: Sequence[Any]) -> UnionType: |
336 | 334 | """Convert multiple annotations to a Union type
|
337 | 335 |
|
338 | 336 | :param type_annotations: type annotations
|
339 | 337 | """
|
340 |
| - if len(set(repr(x) for x in type_annotations)) == 1: |
341 |
| - # All annotations are the same |
342 |
| - type_annotation = type_annotations[0] |
343 |
| - else: |
344 |
| - type_annotation = reduce(or_, type_annotations) |
345 |
| - return type_annotation |
| 338 | + return reduce(or_, type_annotations) |
346 | 339 |
|
347 | 340 |
|
348 | 341 | def generate_optional_type(tp: Any) -> Any:
|
@@ -388,3 +381,111 @@ def get_annotated_type(tp: Any) -> _AnnotatedAlias | None:
|
388 | 381 | else:
|
389 | 382 | if get_origin(tp) is Annotated:
|
390 | 383 | return tp
|
| 384 | + |
| 385 | + |
| 386 | +def merge_annotation_types(tp1: Any, tp2: Any) -> Any: |
| 387 | + """Merge type annotations |
| 388 | +
|
| 389 | + :param tp1: annotated type1 |
| 390 | + :param tp2: annotated type2 |
| 391 | +
|
| 392 | + Note: This is still experimental |
| 393 | + """ |
| 394 | + |
| 395 | + def dedup(*args: Any) -> tuple[Any, ...]: |
| 396 | + """Deduplicate items by retaining the order""" |
| 397 | + seen = set() |
| 398 | + deduped_args = [] |
| 399 | + for arg in args: |
| 400 | + if arg not in seen: |
| 401 | + deduped_args.append(arg) |
| 402 | + seen.add(arg) |
| 403 | + return tuple(deduped_args) |
| 404 | + |
| 405 | + def merge_args_per_origin(args: Sequence[Any]) -> tuple[Any, ...]: |
| 406 | + """Merge type annotations per its origiin type""" |
| 407 | + origin_type_order = {Literal: 1, Annotated: 2, Union: 3, UnionType: 4, list: 5, dict: 6, None: 10} |
| 408 | + args_per_origin = {} |
| 409 | + for arg in args: |
| 410 | + args_per_origin.setdefault(get_origin(arg), []).append(arg) |
| 411 | + return tuple( |
| 412 | + reduce( |
| 413 | + merge_annotation_types, |
| 414 | + sorted(args_, key=lambda x: origin_type_order.get(get_origin(x), 99)), |
| 415 | + ) |
| 416 | + for args_ in args_per_origin.values() |
| 417 | + ) |
| 418 | + |
| 419 | + origin = get_origin(tp1) |
| 420 | + origin2 = get_origin(tp2) |
| 421 | + if origin or origin2: |
| 422 | + if origin == origin2: |
| 423 | + args1 = get_args(tp1) |
| 424 | + args2 = get_args(tp2) |
| 425 | + # stop using set here |
| 426 | + combined_args = dedup(*args1, *args2) |
| 427 | + if origin is Literal: |
| 428 | + return Literal[*combined_args] |
| 429 | + elif origin is Annotated: |
| 430 | + # If two Annotated types have different set of ParamAnnotationType objects in metadata, treat them as |
| 431 | + # different types as a union type. Otherwise merge them |
| 432 | + # TODO: revisit this part |
| 433 | + annotation_types1 = [x for x in tp1.__metadata__ if isinstance(x, ParamAnnotationType)] |
| 434 | + annotation_types2 = [x for x in tp2.__metadata__ if isinstance(x, ParamAnnotationType)] |
| 435 | + if ( |
| 436 | + annotation_types1 |
| 437 | + and annotation_types1 |
| 438 | + and ([asdict(x) for x in annotation_types1] == [asdict(y) for y in annotation_types2]) |
| 439 | + ) or not (annotation_types1 or annotation_types2): |
| 440 | + combined_type = merge_annotation_types(get_args(tp1)[0], get_args(tp2)[0]) |
| 441 | + combined_metadata = dedup(*tp1.__metadata__, *tp2.__metadata__) |
| 442 | + return Annotated[combined_type, *combined_metadata] |
| 443 | + else: |
| 444 | + return generate_union_type([tp1, tp2]) |
| 445 | + elif origin is dict: |
| 446 | + key_type, val_type = args1 |
| 447 | + key_type2, val_type2 = args2 |
| 448 | + if key_type == key_type2: |
| 449 | + if val_type == val_type2: |
| 450 | + return dict[key_type, val_type] |
| 451 | + else: |
| 452 | + return dict[key_type, merge_annotation_types(val_type, val_type2)] |
| 453 | + else: |
| 454 | + if val_type == val_type2: |
| 455 | + return dict[generate_union_type((key_type, key_type2)), val_type] |
| 456 | + elif origin is list: |
| 457 | + return list[generate_union_type(merge_args_per_origin(combined_args))] |
| 458 | + elif origin in [Union, UnionType]: |
| 459 | + return Union[*merge_args_per_origin(combined_args)] |
| 460 | + return generate_union_type((tp1, tp2)) |
| 461 | + |
| 462 | + |
| 463 | +def or_(x: Any, y: Any) -> Any: |
| 464 | + """Customized version of operator.or_ that treats our dynamically created param model classes with the same |
| 465 | + name as duplicates |
| 466 | +
|
| 467 | + eg. operator.or_ v.s our or_ |
| 468 | + >>> import operator |
| 469 | + >>> from openapi_test_client.libraries.api.types import ParamModel |
| 470 | + >>> Model1 = type("MyModel", (ParamModel,), {}) |
| 471 | + >>> Model2 = type("MyModel", (ParamModel,), {}) |
| 472 | + >>> reduce(operator.or_, [Model1 | None, Model2]) |
| 473 | + __main__.MyModel | None | __main__.MyModel |
| 474 | + >>> reduce(or_, [Model1 | None, Model2]) |
| 475 | + __main__.MyModel | None |
| 476 | + """ |
| 477 | + if param_model_util.is_param_model(x) and param_model_util.is_param_model(y) and x.__name__ == y.__name__: |
| 478 | + return x |
| 479 | + else: |
| 480 | + is_x_union = is_union_type(x) |
| 481 | + is_y_union = is_union_type(y) |
| 482 | + if is_x_union: |
| 483 | + if is_y_union: |
| 484 | + return reduce(or_, (*get_args(x), *get_args(y))) |
| 485 | + elif param_model_util.is_param_model(y): |
| 486 | + param_model_names_in_x = [x.__name__ for x in get_args(x) if param_model_util.is_param_model(x)] |
| 487 | + if y.__name__ in param_model_names_in_x: |
| 488 | + return x |
| 489 | + elif is_y_union: |
| 490 | + return reduce(or_, (x, *get_args(y))) |
| 491 | + return x | y |
0 commit comments