Skip to content

Commit e25f4a1

Browse files
committed
fix: Fix ClientForwardRefsPlugin imports and add tests
- Add missing tests for `ClientForwardRefsPlugin` with and without combining with `ShorterResultsPlugin`. - Fix faulty imports - Store name and level separate to allow dots specified either on the module or via level. When importing just lookup what level and name to use. - Always use level 0 for `TYPE_CHECKING_MODULE`. Fixes #314
1 parent 11bfe35 commit e25f4a1

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

ariadne_codegen/contrib/client_forward_refs.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None:
4242
# Imported classes are classes imported from local imports. We keep a
4343
# map between name and module so we know how to import them in each
4444
# method.
45-
self.imported_classes: Dict[str, str] = {}
45+
self.imported_classes: Dict[str, tuple[int, str]] = {}
4646

4747
# Imported classes in each method definition.
4848
self.imported_in_method: Set[str] = set()
@@ -116,9 +116,8 @@ def _store_imported_classes(self, module_body: List[ast.stmt]):
116116
continue
117117

118118
for name in node.names:
119-
from_ = "." * node.level + node.module
120119
if isinstance(name, ast.alias):
121-
self.imported_classes[name.name] = from_
120+
self.imported_classes[name.name] = (node.level, node.module)
122121

123122
def _rewrite_input_args_to_constants(
124123
self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef]
@@ -178,12 +177,14 @@ def _insert_import_statement_in_method(
178177
# We add the class to our set of imported in methods - these classes
179178
# don't need to be imported at all in the global scope.
180179
self.imported_in_method.add(import_class_name)
180+
181+
level, module_name = self.imported_classes[import_class_name]
181182
method_def.body.insert(
182183
0,
183184
ast.ImportFrom(
184-
module=self.imported_classes[import_class_name],
185+
module=module_name,
185186
names=[import_class],
186-
level=1,
187+
level=level,
187188
),
188189
)
189190

@@ -342,10 +343,12 @@ def _add_forward_ref_imports(
342343
"""
343344
type_checking_imports = {}
344345
for cls in self.input_and_return_types:
345-
module_name = self.imported_classes[cls]
346+
level, module_name = self.imported_classes[cls]
346347
if module_name not in type_checking_imports:
347348
type_checking_imports[module_name] = ast.ImportFrom(
348-
module=module_name, names=[], level=1
349+
module=module_name,
350+
names=[],
351+
level=level,
349352
)
350353

351354
type_checking_imports[module_name].names.append(ast.alias(cls))
@@ -364,7 +367,7 @@ def _add_forward_ref_imports(
364367
ast.ImportFrom(
365368
module=TYPE_CHECKING_MODULE,
366369
names=[ast.alias(TYPE_CHECKING_FLAG)],
367-
level=1,
370+
level=0,
368371
),
369372
)
370373

tests/main/test_main.py

+30
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,36 @@ def test_main_shows_version():
213213
"example_client",
214214
CLIENTS_PATH / "custom_sync_query_builder" / "expected_client",
215215
),
216+
(
217+
(
218+
CLIENTS_PATH / "client_forward_refs" / "pyproject.toml",
219+
(
220+
CLIENTS_PATH / "client_forward_refs" / "queries.graphql",
221+
CLIENTS_PATH / "client_forward_refs" / "schema.graphql",
222+
CLIENTS_PATH / "client_forward_refs" / "custom_scalars.py",
223+
),
224+
),
225+
"client_forward_refs",
226+
CLIENTS_PATH / "client_forward_refs" / "expected_client",
227+
),
228+
(
229+
(
230+
CLIENTS_PATH / "client_forward_refs_shorter_results" / "pyproject.toml",
231+
(
232+
CLIENTS_PATH
233+
/ "client_forward_refs_shorter_results"
234+
/ "queries.graphql",
235+
CLIENTS_PATH
236+
/ "client_forward_refs_shorter_results"
237+
/ "schema.graphql",
238+
CLIENTS_PATH
239+
/ "client_forward_refs_shorter_results"
240+
/ "custom_scalars.py",
241+
),
242+
),
243+
"client_forward_refs_shorter_results",
244+
CLIENTS_PATH / "client_forward_refs_shorter_results" / "expected_client",
245+
),
216246
],
217247
indirect=["project_dir"],
218248
)

0 commit comments

Comments
 (0)