Skip to content

Commit 4e06845

Browse files
committed
Split _update_imports
1 parent 5791619 commit 4e06845

File tree

1 file changed

+36
-16
lines changed

1 file changed

+36
-16
lines changed

ariadne_codegen/contrib/client_forward_refs.py

+36-16
Original file line numberDiff line numberDiff line change
@@ -251,18 +251,13 @@ def _get_class_from_call(self, call: ast.Call) -> Optional[ast.Name]:
251251

252252
return call.func.value
253253

254-
def _update_imports(self, module: ast.Module):
254+
def _update_imports(self, module: ast.Module) -> None:
255255
"""Update all imports.
256256
257257
Iterate over all imports and remove the aliases that we use as input or
258258
return value. These will be moved and added to an `if TYPE_CHECKING`
259259
block.
260260
261-
**NOTE** If an `ast.ImportFrom` ends up without any names we must remove
262-
it completely otherwise formatting will not work (it would remove the
263-
empty `import from` but not format the rest of the code without running
264-
it twice).
265-
266261
We do this by storing all imports that we want to keep in an array, we
267262
then drop all from the body and re-insert the ones to keep. Lastly we
268263
import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING`
@@ -286,12 +281,29 @@ def _update_imports(self, module: ast.Module):
286281
if len(return_types_not_used_as_input) == 0:
287282
return None
288283

289-
# We sadly have to iterate over all imports again and remove the imports
290-
# we will do conditionally.
291-
# It's very important that we get this right, if we keep any
292-
# `ImportFrom` that ends up without any names, the formatting will not
293-
# work! It will only remove the empty `import from` but not other unused
294-
# imports.
284+
non_empty_imports = self._update_existing_imports(
285+
module, return_types_not_used_as_input
286+
)
287+
self._add_forward_ref_imports(module, non_empty_imports)
288+
289+
return None
290+
291+
def _update_existing_imports(
292+
self, module: ast.Module, return_types_not_used_as_input: set[str]
293+
) -> List[Union[ast.Import, ast.ImportFrom]]:
294+
"""Update existing imports.
295+
296+
Remove all import or import from statements that would otherwise be
297+
useless after moving them to forward refs.
298+
299+
It's very important that we get this right, if we keep any `ImportFrom`
300+
that ends up without any names, the formatting will not work! It will
301+
only remove the empty `import from` but not other unused imports.
302+
303+
:param module: The ast module to update
304+
:param return_types_not_used_as_input: Set of return types not used as
305+
input
306+
"""
295307
non_empty_imports: List[Union[ast.Import, ast.ImportFrom]] = []
296308
last_import_at = 0
297309
for i, node in enumerate(module.body):
@@ -316,8 +328,18 @@ def _update_imports(self, module: ast.Module):
316328
# We can now remove all imports and re-insert the ones that's not empty.
317329
module.body = non_empty_imports + module.body[last_import_at + 1 :]
318330

319-
# Create import to use for type checking. These will be put in an `if
320-
# TYPE_CHECKING` block.
331+
return non_empty_imports
332+
333+
def _add_forward_ref_imports(
334+
self,
335+
module: ast.Module,
336+
non_empty_imports: List[Union[ast.Import, ast.ImportFrom]],
337+
) -> None:
338+
"""Add forward ref imports.
339+
340+
Add all the forward ref imports meaning all the types needed for type
341+
checking under the `if TYPE_CHECKING` condition.
342+
"""
321343
type_checking_imports = {}
322344
for cls in self.input_and_return_types:
323345
module_name = self.imported_classes[cls]
@@ -345,8 +367,6 @@ def _update_imports(self, module: ast.Module):
345367
),
346368
)
347369

348-
return None
349-
350370
def _update_name_to_constant(self, node: ast.expr) -> ast.expr:
351371
"""Update return types.
352372

0 commit comments

Comments
 (0)