Skip to content

Commit 12dae41

Browse files
committed
Refactor endpoint model creation
1 parent 13c6619 commit 12dae41

File tree

4 files changed

+78
-56
lines changed

4 files changed

+78
-56
lines changed

src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py

+44-33
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
import inspect
44
import json
55
import re
6+
from collections.abc import Mapping, Sequence
67
from copy import deepcopy
7-
from dataclasses import MISSING, Field, field, make_dataclass
8+
from dataclasses import MISSING, field, make_dataclass
89
from typing import TYPE_CHECKING, Any, cast
910

1011
from common_libs.logging import get_logger
1112

1213
from openapi_test_client.libraries.api.api_functions.utils import param_model as param_model_util
1314
from openapi_test_client.libraries.api.api_functions.utils import param_type as param_type_util
14-
from openapi_test_client.libraries.api.types import EndpointModel, File, ParamDef, Unset
15+
from openapi_test_client.libraries.api.types import DataclassModelField, EndpointModel, File, ParamDef, Unset
1516

1617
if TYPE_CHECKING:
1718
from openapi_test_client.libraries.api import EndpointFunc
@@ -27,8 +28,8 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
2728
:param api_spec: Create a model from the OpenAPI spec. Otherwise the model be created from the existing endpoint
2829
function signatures
2930
"""
30-
path_param_fields = []
31-
body_or_query_param_fields = []
31+
path_param_fields: list[DataclassModelField] = []
32+
body_or_query_param_fields: list[DataclassModelField] = []
3233
model_name = f"{type(endpoint_func).__name__.replace('EndpointFunc', EndpointModel.__name__)}"
3334
content_type = None
3435
if api_spec:
@@ -48,11 +49,10 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
4849
continue
4950
elif param_obj.default == inspect.Parameter.empty:
5051
# Positional arguments (path parameters)
51-
path_param_fields.append((name, param_obj.annotation))
52+
path_param_fields.append(DataclassModelField(name, param_obj.annotation))
5253
else:
5354
# keyword arguments (body/query parameters)
54-
param_field = (name, param_obj.annotation, field(default=Unset))
55-
body_or_query_param_fields.append(param_field)
55+
_add_body_or_query_param_field(body_or_query_param_fields, name, param_obj.annotation)
5656

5757
if hasattr(endpoint_func, "endpoint"):
5858
method = endpoint_func.endpoint.method
@@ -64,13 +64,13 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
6464
# Some OpenAPI specs don't properly document path parameters at all, or path parameters could be documented
6565
# as incorrect "in" like "query". We fix this by adding the missing path parameters, and remove them from
6666
# body/query params if any
67-
path_param_fields = [(x, str) for x in expected_path_params]
67+
path_param_fields = [DataclassModelField(x, str) for x in expected_path_params]
6868
body_or_query_param_fields = [x for x in body_or_query_param_fields if x[0] not in expected_path_params]
6969

7070
# Address the case where a path param name conflicts with body/query param name
71-
for i, (field_name, field_type) in enumerate(path_param_fields):
71+
for i, (field_name, field_type, _) in enumerate(path_param_fields):
7272
if field_name in [x[0] for x in body_or_query_param_fields]:
73-
path_param_fields[i] = (f"{field_name}_", field_type)
73+
path_param_fields[i] = DataclassModelField(f"{field_name}_", field_type)
7474

7575
# Some OpenAPI specs define a parameter name using characters we can't use as a python variable name.
7676
# We will use the cleaned name as the model field and annotate it as `Annotated[field_type, Alias(<original_val>)]`
@@ -83,7 +83,7 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
8383
type[EndpointModel],
8484
make_dataclass(
8585
model_name,
86-
fields,
86+
fields, # type: ignore
8787
bases=(EndpointModel,),
8888
namespace={"content_type": content_type, "endpoint_func": endpoint_func},
8989
kw_only=True,
@@ -130,8 +130,8 @@ def generate_func_signature_in_str(model: type[EndpointModel]) -> str:
130130
def _parse_parameter_objects(
131131
method: str,
132132
parameter_objects: list[dict[str, Any]],
133-
path_param_fields: list[tuple[str, Any]],
134-
body_or_query_param_fields: list[tuple[str, Any, Field]],
133+
path_param_fields: list[DataclassModelField],
134+
body_or_query_param_fields: list[DataclassModelField],
135135
):
136136
"""Parse parameter objects
137137
@@ -148,12 +148,13 @@ def _parse_parameter_objects(
148148
param_type_annotation = param_type_util.resolve_type_annotation(
149149
param_name, param_def, _is_required=is_required
150150
)
151-
152151
if param_location in ["header", "cookies"]:
153152
# We currently don't support these
154153
continue
155154
elif param_location == "path":
156-
path_param_fields.append((param_name, param_type_annotation))
155+
if param_name not in [x[0] for x in path_param_fields]:
156+
# Handle duplicates. Some API specs incorrectly document duplicated parameters
157+
path_param_fields.append(DataclassModelField(param_name, param_type_annotation))
157158
elif param_location == "query":
158159
if method.upper() != "GET":
159160
# Annotate query params for non GET endpoints
@@ -180,19 +181,14 @@ def _parse_parameter_objects(
180181
method, parameter_objects, path_param_fields, body_or_query_param_fields
181182
)
182183
else:
183-
if param_name not in [x[0] for x in body_or_query_param_fields]:
184-
body_or_query_param_fields.append(
185-
(
186-
param_name,
187-
param_type_annotation,
188-
field(default=Unset, metadata=param_obj),
189-
)
190-
)
191-
else:
192-
if param_name not in [x[0] for x in body_or_query_param_fields]:
193-
body_or_query_param_fields.append(
194-
(param_name, param_type_annotation, field(default=Unset, metadata=param_obj))
184+
_add_body_or_query_param_field(
185+
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
195186
)
187+
188+
else:
189+
_add_body_or_query_param_field(
190+
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
191+
)
196192
else:
197193
raise NotImplementedError(f"Unsupported param 'in': {param_location}")
198194
except Exception:
@@ -205,7 +201,7 @@ def _parse_parameter_objects(
205201

206202

207203
def _parse_request_body_object(
208-
request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Field]]
204+
request_body_obj: dict[str, Any], body_or_query_param_fields: list[DataclassModelField]
209205
) -> str | None:
210206
"""Parse request body object
211207
@@ -250,7 +246,9 @@ def parse_schema_obj(obj: dict[str, Any]):
250246
param_type = File
251247
if not param_def.is_required:
252248
param_type = param_type | None
253-
body_or_query_param_fields.append((param_name, param_type, field(default=Unset)))
249+
_add_body_or_query_param_field(
250+
body_or_query_param_fields, param_name, param_type, param_obj=param_obj
251+
)
254252
else:
255253
existing_param_names = [x[0] for x in body_or_query_param_fields]
256254
if param_name in existing_param_names:
@@ -259,16 +257,17 @@ def parse_schema_obj(obj: dict[str, Any]):
259257
for _, t, m in duplicated_param_fields:
260258
param_type_annotations.append(t)
261259
param_type_annotation = param_type_util.generate_union_type(param_type_annotations)
262-
merged_param_field = (
260+
merged_param_field = DataclassModelField(
263261
param_name,
264262
param_type_annotation,
265-
field(default=Unset, metadata=param_obj),
263+
default=field(default=Unset, metadata=param_obj),
266264
)
267265
body_or_query_param_fields[existing_param_names.index(param_name)] = merged_param_field
268266
else:
269267
param_type_annotation = param_type_util.resolve_type_annotation(param_name, param_def)
270-
param_field = (param_name, param_type_annotation, field(default=Unset, metadata=param_obj))
271-
body_or_query_param_fields.append(param_field)
268+
_add_body_or_query_param_field(
269+
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
270+
)
272271
except Exception:
273272
logger.error(
274273
"Encountered an error while processing the param object in 'requestBody':\n"
@@ -283,6 +282,18 @@ def parse_schema_obj(obj: dict[str, Any]):
283282
return content_type
284283

285284

285+
def _add_body_or_query_param_field(
286+
param_fields: list[DataclassModelField],
287+
param_name: str,
288+
param_type_annotation: Any,
289+
param_obj: Mapping[str, Any] | dict[str, Any] | Sequence[dict[str, Any]] | None = None,
290+
):
291+
if param_name not in [x[0] for x in param_fields]:
292+
param_fields.append(
293+
DataclassModelField(param_name, param_type_annotation, default=field(default=Unset, metadata=param_obj))
294+
)
295+
296+
286297
def _is_file_param(
287298
content_type: str,
288299
param_def: ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType,

src/openapi_test_client/libraries/api/api_functions/utils/param_model.py

+24-21
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from openapi_test_client.libraries.api.types import (
1818
Alias,
1919
DataclassModel,
20+
DataclassModelField,
2021
EndpointModel,
2122
File,
2223
ParamAnnotationType,
@@ -117,15 +118,22 @@ def create_model_from_param_def(
117118
return _merge_models([create_model_from_param_def(model_name, p) for p in param_def])
118119
else:
119120
fields = [
120-
(
121+
DataclassModelField(
121122
inner_param_name,
122123
param_type_util.resolve_type_annotation(inner_param_name, ParamDef.from_param_obj(inner_param_obj)),
123-
field(default=Unset, metadata=inner_param_obj),
124+
default=field(default=Unset, metadata=inner_param_obj),
124125
)
125126
for inner_param_name, inner_param_obj in param_def.get("properties", {}).items()
126127
]
127128
alias_illegal_model_field_names(fields)
128-
return cast(type[ParamModel], make_dataclass(model_name, fields, bases=(ParamModel,)))
129+
return cast(
130+
type[ParamModel],
131+
make_dataclass(
132+
model_name,
133+
fields, # type: ignore
134+
bases=(ParamModel,),
135+
),
136+
)
129137

130138

131139
def generate_imports_code_from_model(
@@ -286,10 +294,10 @@ def visit(model_name: str):
286294
return sorted(models, key=lambda x: sorted_models_names.index(x.__name__))
287295

288296

289-
def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Field]]):
297+
def alias_illegal_model_field_names(model_fields: list[DataclassModelField]):
290298
"""Clean illegal model field name and annotate the field type with Alias class
291299
292-
:param param_fields: fields value to be passed to make_dataclass()
300+
:param model_fields: fields value to be passed to make_dataclass()
293301
"""
294302

295303
def make_alias(name: str, param_type: Any) -> str:
@@ -330,22 +338,17 @@ def make_alias(name: str, param_type: Any) -> str:
330338
name += "_"
331339
return name
332340

333-
if param_fields:
334-
for i, param_field in enumerate(param_fields):
335-
if len(param_field) == 2:
336-
# path parameters
337-
field_name, field_type = param_field
338-
field_obj = object
339-
else:
340-
# body or query parameters
341-
field_name, field_type, field_obj = param_field
342-
343-
if (alias_name := make_alias(field_name, field_type)) != field_name:
344-
if isinstance(field_obj, Field) and field_obj.metadata:
345-
logger.warning(f"Converted parameter name '{field_name}' to '{alias_name}'")
346-
new_fields = [alias_name, param_type_util.generate_annotated_type(field_type, Alias(field_name))]
347-
new_fields.append(field_obj)
348-
param_fields[i] = tuple(new_fields)
341+
if model_fields:
342+
for i, model_field in enumerate(model_fields):
343+
if (alias_name := make_alias(model_field.name, model_field.type)) != model_field.name:
344+
if isinstance(model_field.default, Field) and model_field.default.metadata:
345+
logger.warning(f"Converted parameter name '{model_field.name}' to '{alias_name}'")
346+
new_fields = (
347+
alias_name,
348+
param_type_util.generate_annotated_type(model_field.type, Alias(model_field.name)),
349+
model_field.default,
350+
)
351+
model_fields[i] = DataclassModelField(*new_fields)
349352

350353

351354
def _merge_models(models: list[type[ParamModel]]) -> type[ParamModel]:

src/openapi_test_client/libraries/api/api_spec.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def has_reference(obj: Any) -> bool:
116116
return "'$ref':" in str(obj)
117117

118118
def resolve_recursive(reference: Any, schemas_seen: list[str] | None = None):
119-
if not schemas_seen:
119+
if schemas_seen is None:
120120
schemas_seen = []
121121
if isinstance(reference, dict):
122122
for k, v in copy.deepcopy(reference).items():

src/openapi_test_client/libraries/api/types.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
make_dataclass,
1515
)
1616
from functools import lru_cache
17-
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
17+
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, TypeVar, cast
1818

1919
from common_libs.decorators import freeze_args
2020
from common_libs.hash import HashableDict
@@ -219,6 +219,14 @@ def to_pydantic(cls) -> type[PydanticModel]:
219219
)
220220

221221

222+
class DataclassModelField(NamedTuple):
223+
"""Dataclass model field"""
224+
225+
name: str
226+
type: Any
227+
default: Field | type[MISSING] = MISSING
228+
229+
222230
class EndpointModel(DataclassModel):
223231
content_type: str | None
224232
endpoint_func: EndpointFunc

0 commit comments

Comments
 (0)