3
3
import inspect
4
4
import json
5
5
import re
6
+ from collections .abc import Mapping , Sequence
6
7
from copy import deepcopy
7
- from dataclasses import MISSING , Field , field , make_dataclass
8
+ from dataclasses import MISSING , field , make_dataclass
8
9
from typing import TYPE_CHECKING , Any , cast
9
10
10
11
from common_libs .logging import get_logger
11
12
12
13
from openapi_test_client .libraries .api .api_functions .utils import param_model as param_model_util
13
14
from openapi_test_client .libraries .api .api_functions .utils import param_type as param_type_util
14
- from openapi_test_client .libraries .api .types import EndpointModel , File , ParamDef , Unset
15
+ from openapi_test_client .libraries .api .types import (
16
+ DataclassModelField ,
17
+ EndpointModel ,
18
+ File ,
19
+ ParamDef ,
20
+ Unset ,
21
+ )
15
22
16
23
if TYPE_CHECKING :
17
24
from openapi_test_client .libraries .api import EndpointFunc
@@ -27,8 +34,8 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
27
34
:param api_spec: Create a model from the OpenAPI spec. Otherwise the model be created from the existing endpoint
28
35
function signatures
29
36
"""
30
- path_param_fields = []
31
- body_or_query_param_fields = []
37
+ path_param_fields : list [ DataclassModelField ] = []
38
+ body_or_query_param_fields : list [ DataclassModelField ] = []
32
39
model_name = f"{ type (endpoint_func ).__name__ .replace ('EndpointFunc' , EndpointModel .__name__ )} "
33
40
content_type = None
34
41
if api_spec :
@@ -48,11 +55,10 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
48
55
continue
49
56
elif param_obj .default == inspect .Parameter .empty :
50
57
# Positional arguments (path parameters)
51
- path_param_fields .append ((name , param_obj .annotation ))
58
+ path_param_fields .append (DataclassModelField (name , param_obj .annotation ))
52
59
else :
53
60
# keyword arguments (body/query parameters)
54
- param_field = (name , param_obj .annotation , field (default = Unset ))
55
- body_or_query_param_fields .append (param_field )
61
+ _add_body_or_query_param_field (body_or_query_param_fields , name , param_obj .annotation )
56
62
57
63
if hasattr (endpoint_func , "endpoint" ):
58
64
method = endpoint_func .endpoint .method
@@ -64,13 +70,13 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
64
70
# Some OpenAPI specs don't properly document path parameters at all, or path parameters could be documented
65
71
# as incorrect "in" like "query". We fix this by adding the missing path parameters, and remove them from
66
72
# body/query params if any
67
- path_param_fields = [(x , str ) for x in expected_path_params ]
73
+ path_param_fields = [DataclassModelField (x , str ) for x in expected_path_params ]
68
74
body_or_query_param_fields = [x for x in body_or_query_param_fields if x [0 ] not in expected_path_params ]
69
75
70
76
# Address the case where a path param name conflicts with body/query param name
71
- for i , (field_name , field_type ) in enumerate (path_param_fields ):
77
+ for i , (field_name , field_type , _ ) in enumerate (path_param_fields ):
72
78
if field_name in [x [0 ] for x in body_or_query_param_fields ]:
73
- path_param_fields [i ] = (f"{ field_name } _" , field_type )
79
+ path_param_fields [i ] = DataclassModelField (f"{ field_name } _" , field_type )
74
80
75
81
# Some OpenAPI specs define a parameter name using characters we can't use as a python variable name.
76
82
# We will use the cleaned name as the model field and annotate it as `Annotated[field_type, Alias(<original_val>)]`
@@ -83,7 +89,7 @@ def create_endpoint_model(endpoint_func: EndpointFunc, api_spec: dict[str, Any]
83
89
type [EndpointModel ],
84
90
make_dataclass (
85
91
model_name ,
86
- fields ,
92
+ fields , # type: ignore
87
93
bases = (EndpointModel ,),
88
94
namespace = {"content_type" : content_type , "endpoint_func" : endpoint_func },
89
95
kw_only = True ,
@@ -130,8 +136,8 @@ def generate_func_signature_in_str(model: type[EndpointModel]) -> str:
130
136
def _parse_parameter_objects (
131
137
method : str ,
132
138
parameter_objects : list [dict [str , Any ]],
133
- path_param_fields : list [tuple [ str , Any ] ],
134
- body_or_query_param_fields : list [tuple [ str , Any , Field ] ],
139
+ path_param_fields : list [DataclassModelField ],
140
+ body_or_query_param_fields : list [DataclassModelField ],
135
141
):
136
142
"""Parse parameter objects
137
143
@@ -148,12 +154,13 @@ def _parse_parameter_objects(
148
154
param_type_annotation = param_type_util .resolve_type_annotation (
149
155
param_name , param_def , _is_required = is_required
150
156
)
151
-
152
157
if param_location in ["header" , "cookies" ]:
153
158
# We currently don't support these
154
159
continue
155
160
elif param_location == "path" :
156
- path_param_fields .append ((param_name , param_type_annotation ))
161
+ if param_name not in [x [0 ] for x in path_param_fields ]:
162
+ # Handle duplicates. Some API specs incorrectly document duplicated parameters
163
+ path_param_fields .append (DataclassModelField (param_name , param_type_annotation ))
157
164
elif param_location == "query" :
158
165
if method .upper () != "GET" :
159
166
# Annotate query params for non GET endpoints
@@ -180,19 +187,14 @@ def _parse_parameter_objects(
180
187
method , parameter_objects , path_param_fields , body_or_query_param_fields
181
188
)
182
189
else :
183
- if param_name not in [x [0 ] for x in body_or_query_param_fields ]:
184
- body_or_query_param_fields .append (
185
- (
186
- param_name ,
187
- param_type_annotation ,
188
- field (default = Unset , metadata = param_obj ),
189
- )
190
- )
191
- else :
192
- if param_name not in [x [0 ] for x in body_or_query_param_fields ]:
193
- body_or_query_param_fields .append (
194
- (param_name , param_type_annotation , field (default = Unset , metadata = param_obj ))
190
+ _add_body_or_query_param_field (
191
+ body_or_query_param_fields , param_name , param_type_annotation , param_obj = param_obj
195
192
)
193
+
194
+ else :
195
+ _add_body_or_query_param_field (
196
+ body_or_query_param_fields , param_name , param_type_annotation , param_obj = param_obj
197
+ )
196
198
else :
197
199
raise NotImplementedError (f"Unsupported param 'in': { param_location } " )
198
200
except Exception :
@@ -205,7 +207,7 @@ def _parse_parameter_objects(
205
207
206
208
207
209
def _parse_request_body_object (
208
- request_body_obj : dict [str , Any ], body_or_query_param_fields : list [tuple [ str , Any , Field ] ]
210
+ request_body_obj : dict [str , Any ], body_or_query_param_fields : list [DataclassModelField ]
209
211
) -> str | None :
210
212
"""Parse request body object
211
213
@@ -250,7 +252,9 @@ def parse_schema_obj(obj: dict[str, Any]):
250
252
param_type = File
251
253
if not param_def .is_required :
252
254
param_type = param_type | None
253
- body_or_query_param_fields .append ((param_name , param_type , field (default = Unset )))
255
+ _add_body_or_query_param_field (
256
+ body_or_query_param_fields , param_name , param_type , param_obj = param_obj
257
+ )
254
258
else :
255
259
existing_param_names = [x [0 ] for x in body_or_query_param_fields ]
256
260
if param_name in existing_param_names :
@@ -259,16 +263,17 @@ def parse_schema_obj(obj: dict[str, Any]):
259
263
for _ , t , m in duplicated_param_fields :
260
264
param_type_annotations .append (t )
261
265
param_type_annotation = param_type_util .generate_union_type (param_type_annotations )
262
- merged_param_field = (
266
+ merged_param_field = DataclassModelField (
263
267
param_name ,
264
268
param_type_annotation ,
265
- field (default = Unset , metadata = param_obj ),
269
+ default = field (default = Unset , metadata = param_obj ),
266
270
)
267
271
body_or_query_param_fields [existing_param_names .index (param_name )] = merged_param_field
268
272
else :
269
273
param_type_annotation = param_type_util .resolve_type_annotation (param_name , param_def )
270
- param_field = (param_name , param_type_annotation , field (default = Unset , metadata = param_obj ))
271
- body_or_query_param_fields .append (param_field )
274
+ _add_body_or_query_param_field (
275
+ body_or_query_param_fields , param_name , param_type_annotation , param_obj = param_obj
276
+ )
272
277
except Exception :
273
278
logger .error (
274
279
"Encountered an error while processing the param object in 'requestBody':\n "
@@ -283,6 +288,18 @@ def parse_schema_obj(obj: dict[str, Any]):
283
288
return content_type
284
289
285
290
291
+ def _add_body_or_query_param_field (
292
+ param_fields : list [DataclassModelField ],
293
+ param_name : str ,
294
+ param_type_annotation : Any ,
295
+ param_obj : Mapping [str , Any ] | dict [str , Any ] | Sequence [dict [str , Any ]] | None = None ,
296
+ ):
297
+ if param_name not in [x [0 ] for x in param_fields ]:
298
+ param_fields .append (
299
+ DataclassModelField (param_name , param_type_annotation , default = field (default = Unset , metadata = param_obj ))
300
+ )
301
+
302
+
286
303
def _is_file_param (
287
304
content_type : str ,
288
305
param_def : ParamDef | ParamDef .ParamGroup | ParamDef .UnknownType ,
0 commit comments