@@ -138,32 +138,32 @@ def generate_imports_code_from_model(
138
138
:param exclude_nested_models: Skip imports for nested models (to avoid define imports for models in the same file)
139
139
"""
140
140
imports_code = ""
141
- module_and_name_pairs : list [ tuple [ str , str ]] = []
141
+ module_and_name_pairs = set ()
142
142
primitive_types = [int , float , str , bool ]
143
143
from openapi_test_client .libraries .api .api_client_generator import API_MODEL_CLASS_DIR_NAME
144
144
145
145
def generate_imports_code (obj_type : Any ):
146
146
if obj_type not in [* primitive_types , None , NoneType ] and not isinstance (obj_type , tuple (primitive_types )):
147
147
if typing_origin := get_origin (obj_type ):
148
148
if typing_origin is Annotated :
149
- module_and_name_pairs .append (("typing" , Annotated .__name__ ))
149
+ module_and_name_pairs .add (("typing" , Annotated .__name__ ))
150
150
[generate_imports_code (m ) for m in get_args (obj_type )]
151
151
elif typing_origin is Literal :
152
- module_and_name_pairs .append (("typing" , Literal .__name__ ))
152
+ module_and_name_pairs .add (("typing" , Literal .__name__ ))
153
153
elif typing_origin in [list , dict , tuple ]:
154
154
[generate_imports_code (m ) for m in [x for x in get_args (obj_type )]]
155
155
elif typing_origin in [UnionType , Union ]:
156
156
if param_type_util .is_optional_type (obj_type ):
157
157
# NOTE: We will use our alias version of typing.Optional for now
158
- # module_and_name_pairs.append (("typing", Optional.__name__))
159
- module_and_name_pairs .append ((types_module .__name__ , Optional .__name__ ))
158
+ # module_and_name_pairs.add (("typing", Optional.__name__))
159
+ module_and_name_pairs .add ((types_module .__name__ , Optional .__name__ ))
160
160
[generate_imports_code (m ) for m in get_args (obj_type )]
161
161
else :
162
162
raise NotImplementedError (f"Unsupported typing origin: { typing_origin } " )
163
163
elif has_param_model (obj_type ):
164
164
if not exclude_nested_models :
165
165
api_cls_module , model_file_name = api_class .__module__ .rsplit ("." , 1 )
166
- module_and_name_pairs .append (
166
+ module_and_name_pairs .add (
167
167
(
168
168
f"..{ API_MODEL_CLASS_DIR_NAME } .{ model_file_name } " ,
169
169
# Using the original field type here to detect list or not
@@ -175,7 +175,7 @@ def generate_imports_code(obj_type: Any):
175
175
name = obj_type .__name__
176
176
else :
177
177
name = type (obj_type ).__name__
178
- module_and_name_pairs .append ((obj_type .__module__ , name ))
178
+ module_and_name_pairs .add ((obj_type .__module__ , name ))
179
179
180
180
has_unset_field = False
181
181
for field_name , field_obj in model .__dataclass_fields__ .items ():
@@ -187,9 +187,8 @@ def generate_imports_code(obj_type: Any):
187
187
if has_unset_field :
188
188
imports_code = _add_unset_import_code (imports_code )
189
189
190
- if module_and_name_pairs :
191
- for module , name in set (module_and_name_pairs ):
192
- imports_code += f"from { module } import { name } \n "
190
+ for module , name in module_and_name_pairs :
191
+ imports_code += f"from { module } import { name } \n "
193
192
194
193
return imports_code
195
194
0 commit comments