@@ -251,18 +251,13 @@ def _get_class_from_call(self, call: ast.Call) -> Optional[ast.Name]:
251
251
252
252
return call .func .value
253
253
254
- def _update_imports (self , module : ast .Module ):
254
+ def _update_imports (self , module : ast .Module ) -> None :
255
255
"""Update all imports.
256
256
257
257
Iterate over all imports and remove the aliases that we use as input or
258
258
return value. These will be moved and added to an `if TYPE_CHECKING`
259
259
block.
260
260
261
- **NOTE** If an `ast.ImportFrom` ends up without any names we must remove
262
- it completely otherwise formatting will not work (it would remove the
263
- empty `import from` but not format the rest of the code without running
264
- it twice).
265
-
266
261
We do this by storing all imports that we want to keep in an array, we
267
262
then drop all from the body and re-insert the ones to keep. Lastly we
268
263
import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING`
@@ -286,12 +281,29 @@ def _update_imports(self, module: ast.Module):
286
281
if len (return_types_not_used_as_input ) == 0 :
287
282
return None
288
283
289
- # We sadly have to iterate over all imports again and remove the imports
290
- # we will do conditionally.
291
- # It's very important that we get this right, if we keep any
292
- # `ImportFrom` that ends up without any names, the formatting will not
293
- # work! It will only remove the empty `import from` but not other unused
294
- # imports.
284
+ non_empty_imports = self ._update_existing_imports (
285
+ module , return_types_not_used_as_input
286
+ )
287
+ self ._add_forward_ref_imports (module , non_empty_imports )
288
+
289
+ return None
290
+
291
+ def _update_existing_imports (
292
+ self , module : ast .Module , return_types_not_used_as_input : set [str ]
293
+ ) -> List [Union [ast .Import , ast .ImportFrom ]]:
294
+ """Update existing imports.
295
+
296
+ Remove all import or import from statements that would otherwise be
297
+ useless after moving them to forward refs.
298
+
299
+ It's very important that we get this right, if we keep any `ImportFrom`
300
+ that ends up without any names, the formatting will not work! It will
301
+ only remove the empty `import from` but not other unused imports.
302
+
303
+ :param module: The ast module to update
304
+ :param return_types_not_used_as_input: Set of return types not used as
305
+ input
306
+ """
295
307
non_empty_imports : List [Union [ast .Import , ast .ImportFrom ]] = []
296
308
last_import_at = 0
297
309
for i , node in enumerate (module .body ):
@@ -316,8 +328,18 @@ def _update_imports(self, module: ast.Module):
316
328
# We can now remove all imports and re-insert the ones that's not empty.
317
329
module .body = non_empty_imports + module .body [last_import_at + 1 :]
318
330
319
- # Create import to use for type checking. These will be put in an `if
320
- # TYPE_CHECKING` block.
331
+ return non_empty_imports
332
+
333
+ def _add_forward_ref_imports (
334
+ self ,
335
+ module : ast .Module ,
336
+ non_empty_imports : List [Union [ast .Import , ast .ImportFrom ]],
337
+ ) -> None :
338
+ """Add forward ref imports.
339
+
340
+ Add all the forward ref imports meaning all the types needed for type
341
+ checking under the `if TYPE_CHECKING` condition.
342
+ """
321
343
type_checking_imports = {}
322
344
for cls in self .input_and_return_types :
323
345
module_name = self .imported_classes [cls ]
@@ -345,8 +367,6 @@ def _update_imports(self, module: ast.Module):
345
367
),
346
368
)
347
369
348
- return None
349
-
350
370
def _update_name_to_constant (self , node : ast .expr ) -> ast .expr :
351
371
"""Update return types.
352
372
0 commit comments