@@ -258,11 +258,6 @@ def _update_imports(self, module: ast.Module):
258
258
return value. These will be moved and added to an `if TYPE_CHECKING`
259
259
block.
260
260
261
- **NOTE** If an `ast.ImportFrom` ends up without any names we must remove
262
- it completely otherwise formatting will not work (it would remove the
263
- empty `import from` but not format the rest of the code without running
264
- it twice).
265
-
266
261
We do this by storing all imports that we want to keep in an array, we
267
262
then drop all from the body and re-insert the ones to keep. Lastly we
268
263
import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING`
@@ -286,12 +281,27 @@ def _update_imports(self, module: ast.Module):
286
281
if len (return_types_not_used_as_input ) == 0 :
287
282
return None
288
283
289
- # We sadly have to iterate over all imports again and remove the imports
290
- # we will do conditionally.
291
- # It's very important that we get this right, if we keep any
292
- # `ImportFrom` that ends up without any names, the formatting will not
293
- # work! It will only remove the empty `import from` but not other unused
294
- # imports.
284
+ non_empty_imports = self ._update_existing_imports (
285
+ module , return_types_not_used_as_input
286
+ )
287
+ self ._add_forward_ref_imports (module , non_empty_imports )
288
+
289
+ def _update_existing_imports (
290
+ self , module : ast .Module , return_types_not_used_as_input : set [str ]
291
+ ) -> List [Union [ast .Import , ast .ImportFrom ]]:
292
+ """Update existing imports.
293
+
294
+ Remove all import or import from statements that would otherwise be
295
+ useless after moving them to forward refs.
296
+
297
+ It's very important that we get this right, if we keep any `ImportFrom`
298
+ that ends up without any names, the formatting will not work! It will
299
+ only remove the empty `import from` but not other unused imports.
300
+
301
+ :param module: The ast module to update
302
+ :param return_types_not_used_as_input: Set of return types not used as
303
+ input
304
+ """
295
305
non_empty_imports : List [Union [ast .Import , ast .ImportFrom ]] = []
296
306
last_import_at = 0
297
307
for i , node in enumerate (module .body ):
@@ -316,8 +326,18 @@ def _update_imports(self, module: ast.Module):
316
326
# We can now remove all imports and re-insert the ones that's not empty.
317
327
module .body = non_empty_imports + module .body [last_import_at + 1 :]
318
328
319
- # Create import to use for type checking. These will be put in an `if
320
- # TYPE_CHECKING` block.
329
+ return non_empty_imports
330
+
331
+ def _add_forward_ref_imports (
332
+ self ,
333
+ module : ast .Module ,
334
+ non_empty_imports : List [Union [ast .Import , ast .ImportFrom ]],
335
+ ) -> None :
336
+ """Add forward ref imports.
337
+
338
+ Add all the forward ref imports meaning all the types needed for type
339
+ checking under the `if TYPE_CHECKING` condition.
340
+ """
321
341
type_checking_imports = {}
322
342
for cls in self .input_and_return_types :
323
343
module_name = self .imported_classes [cls ]
0 commit comments