Skip to content

Commit 52c89ae

Browse files
committed
Refactor endpoint model creation
1 parent 13c6619 commit 52c89ae

File tree

4 files changed

+84
-56
lines changed

4 files changed

+84
-56
lines changed

src/openapi_test_client/libraries/api/api_functions/utils/endpoint_model.py

+50-33
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,22 @@
33
import inspect
44
import json
55
import re
6+
from collections.abc import Mapping, Sequence
67
from copy import deepcopy
7-
from dataclasses import MISSING, Field, field, make_dataclass
8+
from dataclasses import MISSING, field, make_dataclass
89
from typing import TYPE_CHECKING, Any, cast
910

1011
from common_libs.logging import get_logger
1112

1213
from openapi_test_client.libraries.api.api_functions.utils import param_model as param_model_util
1314
from openapi_test_client.libraries.api.api_functions.utils import param_type as param_type_util
14-
from openapi_test_client.libraries.api.types import EndpointModel, File, ParamDef, Unset
15+
from openapi_test_client.libraries.api.types import (
16+
DataclassModelField,
17+
EndpointModel,
18+
File,
19+
ParamDef,
20+
Unset,
21+
)
1522

1623
if TYPE_CHECKING:
1724
from openapi_test_client.libraries.api import EndpointFunc
@@ -27,8 +34,8 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
2734
:param api_spec: Create a model from the OpenAPI spec. Otherwise the model be created from the existing endpoint
2835
function signatures
2936
"""
30-
path_param_fields = []
31-
body_or_query_param_fields = []
37+
path_param_fields: list[DataclassModelField] = []
38+
body_or_query_param_fields: list[DataclassModelField] = []
3239
model_name = f"{type(endpoint_func).__name__.replace('EndpointFunc', EndpointModel.__name__)}"
3340
content_type = None
3441
if api_spec:
@@ -48,11 +55,10 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
4855
continue
4956
elif param_obj.default == inspect.Parameter.empty:
5057
# Positional arguments (path parameters)
51-
path_param_fields.append((name, param_obj.annotation))
58+
path_param_fields.append(DataclassModelField(name, param_obj.annotation))
5259
else:
5360
# keyword arguments (body/query parameters)
54-
param_field = (name, param_obj.annotation, field(default=Unset))
55-
body_or_query_param_fields.append(param_field)
61+
_add_body_or_query_param_field(body_or_query_param_fields, name, param_obj.annotation)
5662

5763
if hasattr(endpoint_func, "endpoint"):
5864
method = endpoint_func.endpoint.method
@@ -64,13 +70,13 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
6470
# Some OpenAPI specs don't properly document path parameters at all, or path parameters could be documented
6571
# as incorrect "in" like "query". We fix this by adding the missing path parameters, and remove them from
6672
# body/query params if any
67-
path_param_fields = [(x, str) for x in expected_path_params]
73+
path_param_fields = [DataclassModelField(x, str) for x in expected_path_params]
6874
body_or_query_param_fields = [x for x in body_or_query_param_fields if x[0] not in expected_path_params]
6975

7076
# Address the case where a path param name conflicts with body/query param name
71-
for i, (field_name, field_type) in enumerate(path_param_fields):
77+
for i, (field_name, field_type, _) in enumerate(path_param_fields):
7278
if field_name in [x[0] for x in body_or_query_param_fields]:
73-
path_param_fields[i] = (f"{field_name}_", field_type)
79+
path_param_fields[i] = DataclassModelField(f"{field_name}_", field_type)
7480

7581
# Some OpenAPI specs define a parameter name using characters we can't use as a python variable name.
7682
# We will use the cleaned name as the model field and annotate it as `Annotated[field_type, Alias(<original_val>)]`
@@ -83,7 +89,7 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
8389
type[EndpointModel],
8490
make_dataclass(
8591
model_name,
86-
fields,
92+
fields, # type: ignore
8793
bases=(EndpointModel,),
8894
namespace={"content_type": content_type, "endpoint_func": endpoint_func},
8995
kw_only=True,
@@ -130,8 +136,8 @@ def generate_func_signature_in_str(model: type[EndpointModel]) -> str:
130136
def _parse_parameter_objects(
131137
method: str,
132138
parameter_objects: list[dict[str, Any]],
133-
path_param_fields: list[tuple[str, Any]],
134-
body_or_query_param_fields: list[tuple[str, Any, Field]],
139+
path_param_fields: list[DataclassModelField],
140+
body_or_query_param_fields: list[DataclassModelField],
135141
):
136142
"""Parse parameter objects
137143
@@ -148,12 +154,13 @@ def _parse_parameter_objects(
148154
param_type_annotation = param_type_util.resolve_type_annotation(
149155
param_name, param_def, _is_required=is_required
150156
)
151-
152157
if param_location in ["header", "cookies"]:
153158
# We currently don't support these
154159
continue
155160
elif param_location == "path":
156-
path_param_fields.append((param_name, param_type_annotation))
161+
if param_name not in [x[0] for x in path_param_fields]:
162+
# Handle duplicates. Some API specs incorrectly document duplicated parameters
163+
path_param_fields.append(DataclassModelField(param_name, param_type_annotation))
157164
elif param_location == "query":
158165
if method.upper() != "GET":
159166
# Annotate query params for non GET endpoints
@@ -180,19 +187,14 @@ def _parse_parameter_objects(
180187
method, parameter_objects, path_param_fields, body_or_query_param_fields
181188
)
182189
else:
183-
if param_name not in [x[0] for x in body_or_query_param_fields]:
184-
body_or_query_param_fields.append(
185-
(
186-
param_name,
187-
param_type_annotation,
188-
field(default=Unset, metadata=param_obj),
189-
)
190-
)
191-
else:
192-
if param_name not in [x[0] for x in body_or_query_param_fields]:
193-
body_or_query_param_fields.append(
194-
(param_name, param_type_annotation, field(default=Unset, metadata=param_obj))
190+
_add_body_or_query_param_field(
191+
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
195192
)
193+
194+
else:
195+
_add_body_or_query_param_field(
196+
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
197+
)
196198
else:
197199
raise NotImplementedError(f"Unsupported param 'in': {param_location}")
198200
except Exception:
@@ -205,7 +207,7 @@ def _parse_parameter_objects(
205207

206208

207209
def _parse_request_body_object(
208-
request_body_obj: dict[str, Any], body_or_query_param_fields: list[tuple[str, Any, Field]]
210+
request_body_obj: dict[str, Any], body_or_query_param_fields: list[DataclassModelField]
209211
) -> str | None:
210212
"""Parse request body object
211213
@@ -250,7 +252,9 @@ def parse_schema_obj(obj: dict[str, Any]):
250252
param_type = File
251253
if not param_def.is_required:
252254
param_type = param_type | None
253-
body_or_query_param_fields.append((param_name, param_type, field(default=Unset)))
255+
_add_body_or_query_param_field(
256+
body_or_query_param_fields, param_name, param_type, param_obj=param_obj
257+
)
254258
else:
255259
existing_param_names = [x[0] for x in body_or_query_param_fields]
256260
if param_name in existing_param_names:
@@ -259,16 +263,17 @@ def parse_schema_obj(obj: dict[str, Any]):
259263
for _, t, m in duplicated_param_fields:
260264
param_type_annotations.append(t)
261265
param_type_annotation = param_type_util.generate_union_type(param_type_annotations)
262-
merged_param_field = (
266+
merged_param_field = DataclassModelField(
263267
param_name,
264268
param_type_annotation,
265-
field(default=Unset, metadata=param_obj),
269+
default=field(default=Unset, metadata=param_obj),
266270
)
267271
body_or_query_param_fields[existing_param_names.index(param_name)] = merged_param_field
268272
else:
269273
param_type_annotation = param_type_util.resolve_type_annotation(param_name, param_def)
270-
param_field = (param_name, param_type_annotation, field(default=Unset, metadata=param_obj))
271-
body_or_query_param_fields.append(param_field)
274+
_add_body_or_query_param_field(
275+
body_or_query_param_fields, param_name, param_type_annotation, param_obj=param_obj
276+
)
272277
except Exception:
273278
logger.error(
274279
"Encountered an error while processing the param object in 'requestBody':\n"
@@ -283,6 +288,18 @@ def parse_schema_obj(obj: dict[str, Any]):
283288
return content_type
284289

285290

291+
def _add_body_or_query_param_field(
292+
param_fields: list[DataclassModelField],
293+
param_name: str,
294+
param_type_annotation: Any,
295+
param_obj: Mapping[str, Any] | dict[str, Any] | Sequence[dict[str, Any]] | None = None,
296+
):
297+
if param_name not in [x[0] for x in param_fields]:
298+
param_fields.append(
299+
DataclassModelField(param_name, param_type_annotation, default=field(default=Unset, metadata=param_obj))
300+
)
301+
302+
286303
def _is_file_param(
287304
content_type: str,
288305
param_def: ParamDef | ParamDef.ParamGroup | ParamDef.UnknownType,

src/openapi_test_client/libraries/api/api_functions/utils/param_model.py

+24-21
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from openapi_test_client.libraries.api.types import (
1818
Alias,
1919
DataclassModel,
20+
DataclassModelField,
2021
EndpointModel,
2122
File,
2223
ParamAnnotationType,
@@ -117,15 +118,22 @@ def create_model_from_param_def(
117118
return _merge_models([create_model_from_param_def(model_name, p) for p in param_def])
118119
else:
119120
fields = [
120-
(
121+
DataclassModelField(
121122
inner_param_name,
122123
param_type_util.resolve_type_annotation(inner_param_name, ParamDef.from_param_obj(inner_param_obj)),
123-
field(default=Unset, metadata=inner_param_obj),
124+
default=field(default=Unset, metadata=inner_param_obj),
124125
)
125126
for inner_param_name, inner_param_obj in param_def.get("properties", {}).items()
126127
]
127128
alias_illegal_model_field_names(fields)
128-
return cast(type[ParamModel], make_dataclass(model_name, fields, bases=(ParamModel,)))
129+
return cast(
130+
type[ParamModel],
131+
make_dataclass(
132+
model_name,
133+
fields, # type: ignore
134+
bases=(ParamModel,),
135+
),
136+
)
129137

130138

131139
def generate_imports_code_from_model(
@@ -286,10 +294,10 @@ def visit(model_name: str):
286294
return sorted(models, key=lambda x: sorted_models_names.index(x.__name__))
287295

288296

289-
def alias_illegal_model_field_names(param_fields: list[tuple[str, Any] | tuple[str, Any, Field]]):
297+
def alias_illegal_model_field_names(model_fields: list[DataclassModelField]):
290298
"""Clean illegal model field name and annotate the field type with Alias class
291299
292-
:param param_fields: fields value to be passed to make_dataclass()
300+
:param model_fields: fields value to be passed to make_dataclass()
293301
"""
294302

295303
def make_alias(name: str, param_type: Any) -> str:
@@ -330,22 +338,17 @@ def make_alias(name: str, param_type: Any) -> str:
330338
name += "_"
331339
return name
332340

333-
if param_fields:
334-
for i, param_field in enumerate(param_fields):
335-
if len(param_field) == 2:
336-
# path parameters
337-
field_name, field_type = param_field
338-
field_obj = object
339-
else:
340-
# body or query parameters
341-
field_name, field_type, field_obj = param_field
342-
343-
if (alias_name := make_alias(field_name, field_type)) != field_name:
344-
if isinstance(field_obj, Field) and field_obj.metadata:
345-
logger.warning(f"Converted parameter name '{field_name}' to '{alias_name}'")
346-
new_fields = [alias_name, param_type_util.generate_annotated_type(field_type, Alias(field_name))]
347-
new_fields.append(field_obj)
348-
param_fields[i] = tuple(new_fields)
341+
if model_fields:
342+
for i, model_field in enumerate(model_fields):
343+
if (alias_name := make_alias(model_field.name, model_field.type)) != model_field.name:
344+
if isinstance(model_field.default, Field) and model_field.default.metadata:
345+
logger.warning(f"Converted parameter name '{model_field.name}' to '{alias_name}'")
346+
new_fields = (
347+
alias_name,
348+
param_type_util.generate_annotated_type(model_field.type, Alias(model_field.name)),
349+
model_field.default,
350+
)
351+
model_fields[i] = DataclassModelField(*new_fields)
349352

350353

351354
def _merge_models(models: list[type[ParamModel]]) -> type[ParamModel]:

src/openapi_test_client/libraries/api/api_spec.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def has_reference(obj: Any) -> bool:
116116
return "'$ref':" in str(obj)
117117

118118
def resolve_recursive(reference: Any, schemas_seen: list[str] | None = None):
119-
if not schemas_seen:
119+
if schemas_seen is None:
120120
schemas_seen = []
121121
if isinstance(reference, dict):
122122
for k, v in copy.deepcopy(reference).items():

src/openapi_test_client/libraries/api/types.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
make_dataclass,
1515
)
1616
from functools import lru_cache
17-
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
17+
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, TypeVar, cast
1818

1919
from common_libs.decorators import freeze_args
2020
from common_libs.hash import HashableDict
@@ -219,6 +219,14 @@ def to_pydantic(cls) -> type[PydanticModel]:
219219
)
220220

221221

222+
class DataclassModelField(NamedTuple):
223+
"""Dataclass model field"""
224+
225+
name: str
226+
type: Any
227+
default: Field | type[MISSING] = MISSING
228+
229+
222230
class EndpointModel(DataclassModel):
223231
content_type: str | None
224232
endpoint_func: EndpointFunc

0 commit comments

Comments
 (0)