Skip to content

Commit 794743b

Browse files
committed
Add plugin to reduce client imports
* Convert input types and return types to `ast.Constant` * Import `TYPE_CHECKING` and import types if set * Import required types inside each method * Add tests * Update CHANGELOG * Update README
1 parent c14dd92 commit 794743b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+2865
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## 0.14.0 (Unreleased)
44

5+
- Added `NoGlobalImportsPlugin` to standard plugins.
56
- Re-added `model_rebuild` calls for input types with forward references.
67

78

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ Ariadne Codegen ships with optional plugins importable from the `ariadne_codegen
9696

9797
- [`ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin`](ariadne_codegen/contrib/extract_operations.py) - This extracts query strings from generated client's methods into separate `operations.py` module. It also modifies the generated client to import these definitions. Generated module name can be customized by adding `operations_module_name="custom_name"` to the `[tool.ariadne-codegen.operations]` section in config. Eg.:
9898

99+
- [`ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin`](ariadne_codegen/contrib/no_global_imports.py) - This plugin processes generated client module and convert all input arguments and return types to strings. The types will be imported only for type checking.
100+
99101
```toml
100102
[tool.ariadne-codegen]
101103
...
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
"""
2+
Plugin to only import types when you call methods
3+
4+
This will massively reduce import times for larger projects since you only have
5+
to load the input types when loading the client.
6+
7+
All result types that's used to process the server response will only be
8+
imported when the method is called.
9+
"""
10+
11+
import ast
12+
from typing import Dict
13+
14+
from graphql import GraphQLSchema
15+
16+
from ariadne_codegen import Plugin
17+
18+
19+
class NoGlobalImportsPlugin(Plugin):
20+
"""Only import types when you call an endpoint needing it"""
21+
22+
def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None:
23+
"""Constructor"""
24+
# Types that should only be imported in a `TYPE_CHECKING` context. This
25+
# is all the types used as arguments to a method or as a return type,
26+
# i.e. for type checking.
27+
self.input_and_return_types: set[str] = set()
28+
29+
# Imported classes are classes imported from local imports. We keep a
30+
# map between name and module so we know how to import them in each
31+
# method.
32+
self.imported_classes: dict[str, str] = {}
33+
34+
# Imported classes in each method definition.
35+
self.imported_in_method: set[str] = set()
36+
37+
super().__init__(schema, config_dict)
38+
39+
def generate_client_module(self, module: ast.Module) -> ast.Module:
40+
"""
41+
Update the generated client.
42+
43+
This will parse all current imports to map them to a path. It will then
44+
traverse all methods and look for the actual return type. The return
45+
node will be converted to an `ast.Constant` if it's an `ast.Name` and
46+
the return type will be imported only under `if TYPE_CHECKING`
47+
conditions.
48+
49+
It will also move all imports of the types used to parse the response
50+
inside each method since that's the only place where they're used. The
51+
result will be that we end up with imports in the global scope only for
52+
types used as input types.
53+
54+
:param module: The ast for the module
55+
"""
56+
self._store_imported_classes(module.body)
57+
58+
# Find the actual client class so we can grab all input and output
59+
# types. We also ensure to manipulate the ast while we do this.
60+
client_class_def = next(
61+
filter(lambda o: isinstance(o, ast.ClassDef), module.body), None
62+
)
63+
if not client_class_def or not isinstance(client_class_def, ast.ClassDef):
64+
return super().generate_client_module(module)
65+
66+
for method_def in [
67+
m
68+
for m in client_class_def.body
69+
if isinstance(m, (ast.FunctionDef, ast.AsyncFunctionDef))
70+
]:
71+
method_def = self._rewrite_input_args_to_constants(method_def)
72+
73+
# If the method returns anything, update whatever it returns.
74+
if method_def.returns:
75+
method_def.returns = self._update_name_to_constant(method_def.returns)
76+
77+
self._insert_import_statement_in_method(method_def)
78+
79+
self._update_imports(module)
80+
81+
return super().generate_client_module(module)
82+
83+
def _store_imported_classes(self, module_body: list[ast.stmt]):
84+
"""Fetch and store imported classes.
85+
86+
Grab all imported classes with level 1 or starting with `.` because
87+
these are the ones generated by us. We store a map between the class and
88+
which module it was imported from so we can easily import it when
89+
needed. This can be in a `TYPE_CHECKING` condition or inside a method.
90+
91+
:param module_body: The body of an `ast.Module`
92+
"""
93+
for node in module_body:
94+
if not isinstance(node, ast.ImportFrom):
95+
continue
96+
97+
if node.module is None:
98+
continue
99+
100+
# We only care about local imports from our generated code.
101+
if node.level != 1 and not node.module.startswith("."):
102+
continue
103+
104+
for name in node.names:
105+
from_ = "." * node.level + node.module
106+
if isinstance(name, ast.alias):
107+
self.imported_classes[name.name] = from_
108+
109+
def _rewrite_input_args_to_constants(
110+
self, method_def: ast.FunctionDef | ast.AsyncFunctionDef
111+
) -> ast.FunctionDef | ast.AsyncFunctionDef:
112+
"""Rewrite the arguments to a method.
113+
114+
For any `ast.Name` that requires an import convert it to an
115+
`ast.Constant` instead. The actual class will be noted and imported
116+
in a `TYPE_CHECKING` context.
117+
118+
:param method_def: Method definition
119+
:returns: The same definition but updated
120+
"""
121+
if not isinstance(method_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
122+
return method_def
123+
124+
for i, input_arg in enumerate(method_def.args.args):
125+
annotation = input_arg.annotation
126+
if isinstance(annotation, (ast.Name, ast.Subscript, ast.Tuple)):
127+
method_def.args.args[i].annotation = self._update_name_to_constant(
128+
annotation
129+
)
130+
131+
return method_def
132+
133+
def _insert_import_statement_in_method(
134+
self, method_def: ast.FunctionDef | ast.AsyncFunctionDef
135+
):
136+
"""Insert import statement in method.
137+
138+
Each method will eventually pass the returned value to a class we've
139+
generated. Since we only need it in the scope of the method ensure we
140+
add it at the top of the method only. It will be removed from the global
141+
scope.
142+
143+
:param method_def: The method definition to updated
144+
"""
145+
# Find the last statement in the body, the call to this class is
146+
# what we need to import first.
147+
return_stmt = method_def.body[-1]
148+
if isinstance(return_stmt, ast.Return):
149+
call = self._get_call_arg_from_return(return_stmt)
150+
elif isinstance(return_stmt, ast.AsyncFor):
151+
call = self._get_call_arg_from_async_for(return_stmt)
152+
else:
153+
return
154+
155+
if call is None:
156+
return
157+
158+
import_class = self._get_class_from_call(call)
159+
if import_class is None:
160+
return
161+
162+
import_class_id = import_class.id
163+
164+
# We add the class to our set of imported in methods - these classes
165+
# don't need to be imported at all in the global scope.
166+
self.imported_in_method.add(import_class.id)
167+
method_def.body.insert(
168+
0,
169+
ast.ImportFrom(
170+
module=self.imported_classes[import_class_id],
171+
names=[import_class],
172+
),
173+
)
174+
175+
def _get_call_arg_from_return(self, return_stmt: ast.Return) -> ast.Call | None:
176+
"""Get the class used in the return statement.
177+
178+
:param return_stmt: The statement used for return
179+
"""
180+
# If it's a call of the class like produced by
181+
# `ShorterResultsPlugin` we have an attribute.
182+
if isinstance(return_stmt.value, ast.Attribute) and isinstance(
183+
return_stmt.value.value, ast.Call
184+
):
185+
return return_stmt.value.value
186+
187+
# If not it's just a call statement to the generated class.
188+
if isinstance(return_stmt.value, ast.Call):
189+
return return_stmt.value
190+
191+
return None
192+
193+
def _get_call_arg_from_async_for(self, last_stmt: ast.AsyncFor) -> ast.Call | None:
194+
"""Get the class used in the yield expression.
195+
196+
:param last_stmt: The statement used in `ast.AsyncFor`
197+
"""
198+
if isinstance(last_stmt.body, list) and isinstance(last_stmt.body[0], ast.Expr):
199+
body = last_stmt.body[0]
200+
elif isinstance(last_stmt.body, ast.Expr):
201+
body = last_stmt.body
202+
else:
203+
return None
204+
205+
if not isinstance(body, ast.Expr):
206+
return None
207+
208+
if not isinstance(body.value, ast.Yield):
209+
return None
210+
211+
# If it's a call of the class like produced by
212+
# `ShorterResultsPlugin` we have an attribute.
213+
if isinstance(body.value.value, ast.Attribute) and isinstance(
214+
body.value.value.value, ast.Call
215+
):
216+
return body.value.value.value
217+
218+
# If not it's just a call statement to the generated class.
219+
if isinstance(body.value.value, ast.Call):
220+
return body.value.value
221+
222+
return None
223+
224+
def _get_class_from_call(self, call: ast.Call) -> ast.Name | None:
225+
"""Get the class from an `ast.Call`.
226+
227+
:param call: The `ast.Call` arg
228+
:returns: `ast.Name` or `None`
229+
"""
230+
if not isinstance(call.func, ast.Attribute):
231+
return None
232+
233+
if not isinstance(call.func.value, ast.Name):
234+
return None
235+
236+
return call.func.value
237+
238+
def _update_imports(self, module: ast.Module) -> ast.Name | None:
239+
"""Update all imports.
240+
241+
Iterate over all imports and remove the aliases that we use as input or
242+
return value. These will be moved and added to an `if TYPE_CHECKING`
243+
block.
244+
245+
**NOTE** If an `ast.ImportFrom` ends up without any names we must remove
246+
it completely otherwise formatting will not work (it would remove the
247+
empty `import from` but not format the rest of the code without running
248+
it twice).
249+
250+
We do this by storing all imports that we want to keep in an array, we
251+
then drop all from the body and re-insert the ones to keep. Lastly we
252+
import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING`
253+
block.
254+
255+
:param module: The ast for the whole module.
256+
"""
257+
# We now know all our input types and all our return types. The return
258+
# types that are _not_ used as import types should be in an `if
259+
# TYPE_CHECKING` import block.
260+
return_types_not_used_as_input = set(self.input_and_return_types)
261+
262+
# The ones we import in the method don't need to be imported at all -
263+
# unless that's the type we return. This behaviour can differ if you use
264+
# a plugin such as `ShorterResultsPlugin` that will import a type that
265+
# is different from the type returned.
266+
return_types_not_used_as_input.update(
267+
{k for k in self.imported_in_method if k not in self.input_and_return_types}
268+
)
269+
270+
if len(return_types_not_used_as_input) == 0:
271+
return None
272+
273+
# We sadly have to iterate over all imports again and remove the imports
274+
# we will do conditionally.
275+
# It's very important that we get this right, if we keep any
276+
# `ImportFrom` that ends up without any names, the formatting will not
277+
# work! It will only remove the empty `import from` but not other unused
278+
# imports.
279+
non_empty_imports: list[ast.Import | ast.ImportFrom] = []
280+
last_import_at = 0
281+
for i, node in enumerate(module.body):
282+
if isinstance(node, ast.Import):
283+
last_import_at = i
284+
non_empty_imports.append(node)
285+
286+
if not isinstance(node, ast.ImportFrom):
287+
continue
288+
289+
last_import_at = i
290+
reduced_names = []
291+
for name in node.names:
292+
if name.name not in return_types_not_used_as_input:
293+
reduced_names.append(name)
294+
295+
node.names = reduced_names
296+
297+
if len(reduced_names) > 0:
298+
non_empty_imports.append(node)
299+
300+
# We can now remove all imports and re-insert the ones that's not empty.
301+
module.body = non_empty_imports + module.body[last_import_at + 1 :]
302+
303+
# Create import to use for type checking. These will be put in an `if
304+
# TYPE_CHECKING` block.
305+
type_checking_imports = {}
306+
for cls in self.input_and_return_types:
307+
module_name = self.imported_classes[cls]
308+
if module_name not in type_checking_imports:
309+
type_checking_imports[module_name] = ast.ImportFrom(
310+
module=module_name, names=[]
311+
)
312+
313+
type_checking_imports[module_name].names.append(ast.alias(cls))
314+
315+
import_if_type_checking = ast.If(
316+
test=ast.Name(id="TYPE_CHECKING"),
317+
body=list(type_checking_imports.values()),
318+
orelse=[],
319+
)
320+
321+
module.body.insert(len(non_empty_imports), import_if_type_checking)
322+
323+
# Import `TYPE_CHECKING`.
324+
module.body.insert(
325+
len(non_empty_imports),
326+
ast.ImportFrom(
327+
module="typing",
328+
names=[ast.Name("TYPE_CHECKING")],
329+
),
330+
)
331+
332+
return None
333+
334+
def _update_name_to_constant(self, node: ast.expr) -> ast.expr:
335+
"""Update return types.
336+
337+
If the return type contains any type that resolves to an `ast.Name`,
338+
convert it to an `ast.Constant`. We only need the type for type checking
339+
and can avoid importing the type in the global scope unless needed.
340+
341+
:param node: The ast node used as return type
342+
:returns: A modified ast node
343+
"""
344+
if isinstance(node, ast.Name):
345+
if node.id in self.imported_classes:
346+
self.input_and_return_types.add(node.id)
347+
return ast.Constant(value=node.id)
348+
349+
if isinstance(node, ast.Subscript):
350+
node.slice = self._update_name_to_constant(node.slice)
351+
return node
352+
353+
if isinstance(node, ast.Tuple):
354+
for i, _ in enumerate(node.elts):
355+
node.elts[i] = self._update_name_to_constant(node.elts[i])
356+
357+
return node
358+
359+
return node
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
SimpleScalar = str
2+
3+
4+
class ComplexScalar:
5+
def __init__(self, value: str) -> None:
6+
self.value = value
7+
8+
9+
def parse_complex_scalar(value: str) -> ComplexScalar:
10+
return ComplexScalar(value)
11+
12+
13+
def serialize_complex_scalar(value: ComplexScalar) -> str:
14+
return value.value

0 commit comments

Comments
 (0)