3
3
import inspect
4
4
import json
5
5
import re
6
+ from collections .abc import Mapping , Sequence
6
7
from copy import deepcopy
7
- from dataclasses import MISSING , Field , field , make_dataclass
8
+ from dataclasses import MISSING , field , make_dataclass
8
9
from typing import TYPE_CHECKING , Any , cast
9
10
10
11
from common_libs .logging import get_logger
11
12
12
13
from openapi_test_client .libraries .api .api_functions .utils import param_model as param_model_util
13
14
from openapi_test_client .libraries .api .api_functions .utils import param_type as param_type_util
14
- from openapi_test_client .libraries .api .types import EndpointModel , File , ParamDef , Unset
15
+ from openapi_test_client .libraries .api .types import DataclassModelField , EndpointModel , File , ParamDef , Unset
15
16
16
17
if TYPE_CHECKING :
17
18
from openapi_test_client .libraries .api import EndpointFunc
@@ -27,8 +28,8 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
27
28
:param api_spec: Create a model from the OpenAPI spec. Otherwise the model be created from the existing endpoint
28
29
function signatures
29
30
"""
30
- path_param_fields = []
31
- body_or_query_param_fields = []
31
+ path_param_fields : list [ DataclassModelField ] = []
32
+ body_or_query_param_fields : list [ DataclassModelField ] = []
32
33
model_name = f"{ type (endpoint_func ).__name__ .replace ('EndpointFunc' , EndpointModel .__name__ )} "
33
34
content_type = None
34
35
if api_spec :
@@ -48,11 +49,10 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
48
49
continue
49
50
elif param_obj .default == inspect .Parameter .empty :
50
51
# Positional arguments (path parameters)
51
- path_param_fields .append ((name , param_obj .annotation ))
52
+ path_param_fields .append (DataclassModelField (name , param_obj .annotation ))
52
53
else :
53
54
# keyword arguments (body/query parameters)
54
- param_field = (name , param_obj .annotation , field (default = Unset ))
55
- body_or_query_param_fields .append (param_field )
55
+ _add_body_or_query_param_field (body_or_query_param_fields , name , param_obj .annotation )
56
56
57
57
if hasattr (endpoint_func , "endpoint" ):
58
58
method = endpoint_func .endpoint .method
@@ -64,13 +64,13 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
64
64
# Some OpenAPI specs don't properly document path parameters at all, or path parameters could be documented
65
65
# as incorrect "in" like "query". We fix this by adding the missing path parameters, and remove them from
66
66
# body/query params if any
67
- path_param_fields = [(x , str ) for x in expected_path_params ]
67
+ path_param_fields = [DataclassModelField (x , str ) for x in expected_path_params ]
68
68
body_or_query_param_fields = [x for x in body_or_query_param_fields if x [0 ] not in expected_path_params ]
69
69
70
70
# Address the case where a path param name conflicts with body/query param name
71
- for i , (field_name , field_type ) in enumerate (path_param_fields ):
71
+ for i , (field_name , field_type , _ ) in enumerate (path_param_fields ):
72
72
if field_name in [x [0 ] for x in body_or_query_param_fields ]:
73
- path_param_fields [i ] = (f"{ field_name } _" , field_type )
73
+ path_param_fields [i ] = DataclassModelField (f"{ field_name } _" , field_type )
74
74
75
75
# Some OpenAPI specs define a parameter name using characters we can't use as a python variable name.
76
76
# We will use the cleaned name as the model field and annotate it as `Annotated[field_type, Alias(<original_val>)]`
@@ -83,7 +83,7 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
83
83
type [EndpointModel ],
84
84
make_dataclass (
85
85
model_name ,
86
- fields ,
86
+ fields , # type: ignore
87
87
bases = (EndpointModel ,),
88
88
namespace = {"content_type" : content_type , "endpoint_func" : endpoint_func },
89
89
kw_only = True ,
@@ -130,8 +130,8 @@ def generate_func_signature_in_str(model: type[EndpointModel]) -> str:
130
130
def _parse_parameter_objects (
131
131
method : str ,
132
132
parameter_objects : list [dict [str , Any ]],
133
- path_param_fields : list [tuple [ str , Any ] ],
134
- body_or_query_param_fields : list [tuple [ str , Any , Field ] ],
133
+ path_param_fields : list [DataclassModelField ],
134
+ body_or_query_param_fields : list [DataclassModelField ],
135
135
):
136
136
"""Parse parameter objects
137
137
@@ -148,12 +148,13 @@ def _parse_parameter_objects(
148
148
param_type_annotation = param_type_util .resolve_type_annotation (
149
149
param_name , param_def , _is_required = is_required
150
150
)
151
-
152
151
if param_location in ["header" , "cookies" ]:
153
152
# We currently don't support these
154
153
continue
155
154
elif param_location == "path" :
156
- path_param_fields .append ((param_name , param_type_annotation ))
155
+ if param_name not in [x [0 ] for x in path_param_fields ]:
156
+ # Handle duplicates. Some API specs incorrectly document duplicated parameters
157
+ path_param_fields .append (DataclassModelField (param_name , param_type_annotation ))
157
158
elif param_location == "query" :
158
159
if method .upper () != "GET" :
159
160
# Annotate query params for non GET endpoints
@@ -180,19 +181,14 @@ def _parse_parameter_objects(
180
181
method , parameter_objects , path_param_fields , body_or_query_param_fields
181
182
)
182
183
else :
183
- if param_name not in [x [0 ] for x in body_or_query_param_fields ]:
184
- body_or_query_param_fields .append (
185
- (
186
- param_name ,
187
- param_type_annotation ,
188
- field (default = Unset , metadata = param_obj ),
189
- )
190
- )
191
- else :
192
- if param_name not in [x [0 ] for x in body_or_query_param_fields ]:
193
- body_or_query_param_fields .append (
194
- (param_name , param_type_annotation , field (default = Unset , metadata = param_obj ))
184
+ _add_body_or_query_param_field (
185
+ body_or_query_param_fields , param_name , param_type_annotation , param_obj = param_obj
195
186
)
187
+
188
+ else :
189
+ _add_body_or_query_param_field (
190
+ body_or_query_param_fields , param_name , param_type_annotation , param_obj = param_obj
191
+ )
196
192
else :
197
193
raise NotImplementedError (f"Unsupported param 'in': { param_location } " )
198
194
except Exception :
@@ -205,7 +201,7 @@ def _parse_parameter_objects(
205
201
206
202
207
203
def _parse_request_body_object (
208
- request_body_obj : dict [str , Any ], body_or_query_param_fields : list [tuple [ str , Any , Field ] ]
204
+ request_body_obj : dict [str , Any ], body_or_query_param_fields : list [DataclassModelField ]
209
205
) -> str | None :
210
206
"""Parse request body object
211
207
@@ -250,7 +246,9 @@ def parse_schema_obj(obj: dict[str, Any]):
250
246
param_type = File
251
247
if not param_def .is_required :
252
248
param_type = param_type | None
253
- body_or_query_param_fields .append ((param_name , param_type , field (default = Unset )))
249
+ _add_body_or_query_param_field (
250
+ body_or_query_param_fields , param_name , param_type , param_obj = param_obj
251
+ )
254
252
else :
255
253
existing_param_names = [x [0 ] for x in body_or_query_param_fields ]
256
254
if param_name in existing_param_names :
@@ -259,16 +257,17 @@ def parse_schema_obj(obj: dict[str, Any]):
259
257
for _ , t , m in duplicated_param_fields :
260
258
param_type_annotations .append (t )
261
259
param_type_annotation = param_type_util .generate_union_type (param_type_annotations )
262
- merged_param_field = (
260
+ merged_param_field = DataclassModelField (
263
261
param_name ,
264
262
param_type_annotation ,
265
- field (default = Unset , metadata = param_obj ),
263
+ default = field (default = Unset , metadata = param_obj ),
266
264
)
267
265
body_or_query_param_fields [existing_param_names .index (param_name )] = merged_param_field
268
266
else :
269
267
param_type_annotation = param_type_util .resolve_type_annotation (param_name , param_def )
270
- param_field = (param_name , param_type_annotation , field (default = Unset , metadata = param_obj ))
271
- body_or_query_param_fields .append (param_field )
268
+ _add_body_or_query_param_field (
269
+ body_or_query_param_fields , param_name , param_type_annotation , param_obj = param_obj
270
+ )
272
271
except Exception :
273
272
logger .error (
274
273
"Encountered an error while processing the param object in 'requestBody':\n "
@@ -283,6 +282,18 @@ def parse_schema_obj(obj: dict[str, Any]):
283
282
return content_type
284
283
285
284
285
+ def _add_body_or_query_param_field (
286
+ param_fields : list [DataclassModelField ],
287
+ param_name : str ,
288
+ param_type_annotation : Any ,
289
+ param_obj : Mapping [str , Any ] | dict [str , Any ] | Sequence [dict [str , Any ]] | None = None ,
290
+ ):
291
+ if param_name not in [x [0 ] for x in param_fields ]:
292
+ param_fields .append (
293
+ DataclassModelField (param_name , param_type_annotation , default = field (default = Unset , metadata = param_obj ))
294
+ )
295
+
296
+
286
297
def _is_file_param (
287
298
content_type : str ,
288
299
param_def : ParamDef | ParamDef .ParamGroup | ParamDef .UnknownType ,
0 commit comments