Skip to content

Commit 0454cce

Browse files
committed
Split _update_imports
1 parent 5791619 commit 0454cce

File tree

1 file changed

+33
-13
lines changed

1 file changed

+33
-13
lines changed

ariadne_codegen/contrib/client_forward_refs.py

+33-13
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,6 @@ def _update_imports(self, module: ast.Module):
258258
return value. These will be moved and added to an `if TYPE_CHECKING`
259259
block.
260260
261-
**NOTE** If an `ast.ImportFrom` ends up without any names we must remove
262-
it completely otherwise formatting will not work (it would remove the
263-
empty `import from` but not format the rest of the code without running
264-
it twice).
265-
266261
We do this by storing all imports that we want to keep in an array, we
267262
then drop all from the body and re-insert the ones to keep. Lastly we
268263
import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING`
@@ -286,12 +281,27 @@ def _update_imports(self, module: ast.Module):
286281
if len(return_types_not_used_as_input) == 0:
287282
return None
288283

289-
# We sadly have to iterate over all imports again and remove the imports
290-
# we will do conditionally.
291-
# It's very important that we get this right, if we keep any
292-
# `ImportFrom` that ends up without any names, the formatting will not
293-
# work! It will only remove the empty `import from` but not other unused
294-
# imports.
284+
non_empty_imports = self._update_existing_imports(
285+
module, return_types_not_used_as_input
286+
)
287+
self._add_forward_ref_imports(module, non_empty_imports)
288+
289+
def _update_existing_imports(
290+
self, module: ast.Module, return_types_not_used_as_input: set[str]
291+
) -> List[Union[ast.Import, ast.ImportFrom]]:
292+
"""Update existing imports.
293+
294+
Remove all import or import from statements that would otherwise be
295+
useless after moving them to forward refs.
296+
297+
It's very important that we get this right, if we keep any `ImportFrom`
298+
that ends up without any names, the formatting will not work! It will
299+
only remove the empty `import from` but not other unused imports.
300+
301+
:param module: The ast module to update
302+
:param return_types_not_used_as_input: Set of return types not used as
303+
input
304+
"""
295305
non_empty_imports: List[Union[ast.Import, ast.ImportFrom]] = []
296306
last_import_at = 0
297307
for i, node in enumerate(module.body):
@@ -316,8 +326,18 @@ def _update_imports(self, module: ast.Module):
316326
# We can now remove all imports and re-insert the ones that's not empty.
317327
module.body = non_empty_imports + module.body[last_import_at + 1 :]
318328

319-
# Create import to use for type checking. These will be put in an `if
320-
# TYPE_CHECKING` block.
329+
return non_empty_imports
330+
331+
def _add_forward_ref_imports(
332+
self,
333+
module: ast.Module,
334+
non_empty_imports: List[Union[ast.Import, ast.ImportFrom]],
335+
) -> None:
336+
"""Add forward ref imports.
337+
338+
Add all the forward ref imports meaning all the types needed for type
339+
checking under the `if TYPE_CHECKING` condition.
340+
"""
321341
type_checking_imports = {}
322342
for cls in self.input_and_return_types:
323343
module_name = self.imported_classes[cls]

0 commit comments

Comments
 (0)