2
2
3
3
import pathlib
4
4
import re
5
+ from functools import cached_property
5
6
from typing import (
6
7
Any ,
7
8
Callable ,
42
43
ResponseObject ,
43
44
)
44
45
from datamodel_code_generator .types import DataType , DataTypeManager , StrictTypes
45
- from datamodel_code_generator .util import cached_property
46
46
from pydantic import BaseModel , ValidationInfo
47
47
48
48
RE_APPLICATION_JSON_PATTERN : Pattern [str ] = re .compile (r'^application/.*json$' )
@@ -93,16 +93,43 @@ class Argument(CachedPropertyModel):
93
93
type_hint : UsefulStr
94
94
default : Optional [UsefulStr ] = None
95
95
default_value : Optional [UsefulStr ] = None
96
+ field : Union [DataModelField , list [DataModelField ], None ] = None
96
97
required : bool
97
98
98
99
def __str__ (self ) -> str :
99
100
return self .argument
100
101
101
- @cached_property
102
+ @property
102
103
def argument (self ) -> str :
104
+ if self .field is None :
105
+ type_hint = self .type_hint
106
+ else :
107
+ type_hint = (
108
+ UsefulStr (self .field .type_hint )
109
+ if not isinstance (self .field , list )
110
+ else UsefulStr (
111
+ f"Union[{ ', ' .join (field .type_hint for field in self .field )} ]"
112
+ )
113
+ )
114
+ if self .default is None and self .required :
115
+ return f'{ self .name } : { type_hint } '
116
+ return f'{ self .name } : { type_hint } = { self .default } '
117
+
118
+ @property
119
+ def snakecase (self ) -> str :
120
+ if self .field is None :
121
+ type_hint = self .type_hint
122
+ else :
123
+ type_hint = (
124
+ UsefulStr (self .field .type_hint )
125
+ if not isinstance (self .field , list )
126
+ else UsefulStr (
127
+ f"Union[{ ', ' .join (field .type_hint for field in self .field )} ]"
128
+ )
129
+ )
103
130
if self .default is None and self .required :
104
- return f'{ self .name } : { self . type_hint } '
105
- return f'{ self .name } : { self . type_hint } = { self .default } '
131
+ return f'{ stringcase . snakecase ( self .name ) } : { type_hint } '
132
+ return f'{ stringcase . snakecase ( self .name ) } : { type_hint } = { self .default } '
106
133
107
134
108
135
class Operation (CachedPropertyModel ):
@@ -114,16 +141,39 @@ class Operation(CachedPropertyModel):
114
141
parameters : List [Dict [str , Any ]] = []
115
142
responses : Dict [UsefulStr , Any ] = {}
116
143
deprecated : bool = False
117
- imports : List [Import ] = []
118
144
security : Optional [List [Dict [str , List [str ]]]] = None
119
145
tags : Optional [List [str ]] = []
120
- arguments : str = ''
121
- snake_case_arguments : str = ''
122
146
request : Optional [Argument ] = None
123
147
response : str = ''
124
148
additional_responses : Dict [Union [str , int ], Dict [str , str ]] = {}
125
149
return_type : str = ''
126
150
callbacks : Dict [UsefulStr , List ["Operation" ]] = {}
151
+ arguments_list : List [Argument ] = []
152
+
153
+ @classmethod
154
+ def merge_arguments_with_union (cls , arguments : List [Argument ]) -> List [Argument ]:
155
+ grouped_arguments : DefaultDict [str , List [Argument ]] = DefaultDict (list )
156
+ for argument in arguments :
157
+ grouped_arguments [argument .name ].append (argument )
158
+
159
+ merged_arguments = []
160
+ for argument_list in grouped_arguments .values ():
161
+ if len (argument_list ) == 1 :
162
+ merged_arguments .append (argument_list [0 ])
163
+ else :
164
+ argument = argument_list [0 ]
165
+ fields = [
166
+ item
167
+ for arg in argument_list
168
+ if arg .field is not None
169
+ for item in (
170
+ arg .field if isinstance (arg .field , list ) else [arg .field ]
171
+ )
172
+ if item is not None
173
+ ]
174
+ argument .field = fields
175
+ merged_arguments .append (argument )
176
+ return merged_arguments
127
177
128
178
@cached_property
129
179
def type (self ) -> UsefulStr :
@@ -132,6 +182,27 @@ def type(self) -> UsefulStr:
132
182
"""
133
183
return self .method
134
184
185
+ @property
186
+ def arguments (self ) -> str :
187
+ sorted_arguments = Operation .merge_arguments_with_union (self .arguments_list )
188
+ return ", " .join (argument .argument for argument in sorted_arguments )
189
+
190
+ @property
191
+ def snake_case_arguments (self ) -> str :
192
+ sorted_arguments = Operation .merge_arguments_with_union (self .arguments_list )
193
+ return ", " .join (argument .snakecase for argument in sorted_arguments )
194
+
195
+ @property
196
+ def imports (self ) -> Imports :
197
+ imports = Imports ()
198
+ for argument in self .arguments_list :
199
+ if isinstance (argument .field , list ):
200
+ for field in argument .field :
201
+ imports .append (field .data_type .import_ )
202
+ elif argument .field :
203
+ imports .append (argument .field .data_type .import_ )
204
+ return imports
205
+
135
206
@cached_property
136
207
def root_path (self ) -> UsefulStr :
137
208
paths = self .path .split ("/" )
@@ -153,7 +224,7 @@ def function_name(self) -> str:
153
224
return stringcase .snakecase (name )
154
225
155
226
156
- @snooper_to_methods (max_variable_length = None )
227
+ @snooper_to_methods ()
157
228
class OpenAPIParser (OpenAPIModelParser ):
158
229
def __init__ (
159
230
self ,
@@ -166,7 +237,7 @@ def __init__(
166
237
base_class : Optional [str ] = None ,
167
238
custom_template_dir : Optional [pathlib .Path ] = None ,
168
239
extra_template_data : Optional [DefaultDict [str , Dict [str , Any ]]] = None ,
169
- target_python_version : PythonVersion = PythonVersion .PY_37 ,
240
+ target_python_version : PythonVersion = PythonVersion .PY_39 ,
170
241
dump_resolve_reference_action : Optional [Callable [[Iterable [str ]], str ]] = None ,
171
242
validation : bool = False ,
172
243
field_constraints : bool = False ,
@@ -314,6 +385,7 @@ def get_parameter_type(
314
385
default = default , # type: ignore
315
386
default_value = schema .default ,
316
387
required = field .required ,
388
+ field = field ,
317
389
)
318
390
319
391
def get_arguments (self , snake_case : bool , path : List [str ]) -> str :
@@ -347,6 +419,10 @@ def get_argument_list(self, snake_case: bool, path: List[str]) -> List[Argument]
347
419
or argument .type_hint .startswith ('Optional[' )
348
420
)
349
421
422
+ # check if there are duplicate argument.name
423
+ argument_names = [argument .name for argument in arguments ]
424
+ if len (argument_names ) != len (set (argument_names )):
425
+ self .imports_for_fastapi .append (Import (from_ = 'typing' , import_ = "Union" ))
350
426
return arguments
351
427
352
428
def parse_request_body (
@@ -466,10 +542,7 @@ def parse_operation(
466
542
resolved_path = self .model_resolver .resolve_ref (path )
467
543
path_name , method = path [- 2 :]
468
544
469
- self ._temporary_operation ['arguments' ] = self .get_arguments (
470
- snake_case = False , path = path
471
- )
472
- self ._temporary_operation ['snake_case_arguments' ] = self .get_arguments (
545
+ self ._temporary_operation ['arguments_list' ] = self .get_argument_list (
473
546
snake_case = True , path = path
474
547
)
475
548
main_operation = self ._temporary_operation
@@ -499,11 +572,8 @@ def parse_operation(
499
572
self ._temporary_operation = {'_parameters' : []}
500
573
cb_path = path + ['callbacks' , key , route , method ]
501
574
super ().parse_operation (cb_op , cb_path )
502
- self ._temporary_operation ['arguments' ] = self .get_arguments (
503
- snake_case = False , path = cb_path
504
- )
505
- self ._temporary_operation ['snake_case_arguments' ] = (
506
- self .get_arguments (snake_case = True , path = cb_path )
575
+ self ._temporary_operation ['arguments_list' ] = (
576
+ self .get_argument_list (snake_case = True , path = cb_path )
507
577
)
508
578
509
579
callbacks [key ].append (
@@ -527,13 +597,16 @@ def _collapse_root_model(self, data_type: DataType) -> DataType:
527
597
reference = data_type .reference
528
598
import functools
529
599
530
- if not (
531
- reference
532
- and (
533
- len (reference .children ) == 1
534
- or functools .reduce (lambda a , b : a == b , reference .children )
535
- )
536
- ):
600
+ try :
601
+ if not (
602
+ reference
603
+ and (
604
+ len (reference .children ) == 0
605
+ or functools .reduce (lambda a , b : a == b , reference .children )
606
+ )
607
+ ):
608
+ return data_type
609
+ except RecursionError :
537
610
return data_type
538
611
source = reference .source
539
612
if not isinstance (source , CustomRootType ):
0 commit comments