Skip to content

Commit d311712

Browse files
update ast usage in code for python 3.12 changes (#306)
update-typing-to-satisfy-mypy
1 parent a064d22 commit d311712

21 files changed

+264
-172
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# CHANGELOG
22

3+
## 0.14.1 (UNRELEASED)
4+
5+
- Changed code typing to satisfy MyPy 1.11.0 version
6+
7+
38
## 0.14.0 (2024-07-17)
49

510
- Added `ClientForwardRefsPlugin` to standard plugins.

ariadne_codegen/client_generators/client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ def _generate_variables_assign(
850850
self, variable_names: Dict[str, str], arguments_dict: ast.Dict, lineno: int = 1
851851
) -> ast.AnnAssign:
852852
return generate_ann_assign(
853-
target=variable_names[self._variables_dict_variable],
853+
target=generate_name(variable_names[self._variables_dict_variable]),
854854
annotation=generate_subscript(
855855
generate_name(DICT),
856856
generate_tuple([generate_name("str"), generate_name("object")]),

ariadne_codegen/client_generators/custom_arguments.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def generate_clear_arguments_section(
221221
) -> Tuple[List[ast.stmt], List[ast.keyword]]:
222222
arguments_body = [
223223
generate_ann_assign(
224-
"arguments",
224+
generate_name("arguments"),
225225
generate_subscript(
226226
generate_name(DICT),
227227
generate_tuple(
@@ -240,8 +240,8 @@ def generate_clear_arguments_section(
240240
),
241241
),
242242
generate_dict(
243-
return_arguments_keys,
244-
return_arguments_values, # type: ignore
243+
return_arguments_keys, # type: ignore
244+
return_arguments_values,
245245
),
246246
),
247247
generate_assign(

ariadne_codegen/client_generators/custom_fields.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _generate_class_field(
219219
name, field_name, getattr(field, "args")
220220
)
221221
return generate_ann_assign(
222-
target=name,
222+
target=generate_name(name),
223223
annotation=generate_name(f'"{field_name}"'),
224224
value=generate_call(
225225
func=generate_name(field_name), args=[generate_constant(org_name)]

ariadne_codegen/client_generators/input_fields.py

+25-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import ast
2-
from typing import Dict, Optional, Tuple
2+
from typing import Dict, List, Optional, Tuple, cast
33

44
from graphql import (
55
BooleanValueNode,
@@ -142,15 +142,18 @@ def parse_input_const_value_node(
142142

143143
if isinstance(node, ListValueNode):
144144
list_ = generate_list(
145-
[
146-
parse_input_const_value_node(
147-
node=v,
148-
field_type=field_type,
149-
nested_object=nested_object,
150-
nested_list=True,
151-
)
152-
for v in node.values
153-
]
145+
cast(
146+
List[ast.expr],
147+
[
148+
parse_input_const_value_node(
149+
node=v,
150+
field_type=field_type,
151+
nested_object=nested_object,
152+
nested_list=True,
153+
)
154+
for v in node.values
155+
],
156+
)
154157
)
155158
if not nested_list:
156159
return generate_call(
@@ -166,15 +169,18 @@ def parse_input_const_value_node(
166169
if isinstance(node, ObjectValueNode):
167170
dict_ = generate_dict(
168171
keys=[generate_constant(f.name.value) for f in node.fields],
169-
values=[
170-
parse_input_const_value_node(
171-
node=f.value,
172-
field_type=field_type,
173-
nested_object=True,
174-
nested_list=True,
175-
)
176-
for f in node.fields
177-
],
172+
values=cast(
173+
List[ast.expr],
174+
[
175+
parse_input_const_value_node(
176+
node=f.value,
177+
field_type=field_type,
178+
nested_object=True,
179+
nested_list=True,
180+
)
181+
for f in node.fields
182+
],
183+
),
178184
)
179185
if not nested_object:
180186
return generate_call(

ariadne_codegen/client_generators/input_types.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
generate_keyword,
1919
generate_method_call,
2020
generate_module,
21+
generate_name,
2122
generate_pydantic_field,
2223
model_has_forward_refs,
2324
)
@@ -172,7 +173,7 @@ def _parse_input_definition(
172173
field.type, custom_scalars=self.custom_scalars
173174
)
174175
field_implementation = generate_ann_assign(
175-
target=name,
176+
target=generate_name(name),
176177
annotation=annotation,
177178
value=parse_input_field_default_value(
178179
node=field.ast_node, annotation=annotation, field_type=field_type

ariadne_codegen/client_generators/result_fields.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,9 @@ def parse_interface_type(
216216
)
217217
context.abstract_type = True
218218
if inline_fragments or fragments_on_subtypes:
219-
types = [generate_annotation_name('"' + class_name + type_.name + '"', False)]
219+
types: List[ast.expr] = [
220+
generate_annotation_name('"' + class_name + type_.name + '"', False)
221+
]
220222
context.related_classes.append(
221223
RelatedClassData(class_name=class_name + type_.name, type_name=type_.name)
222224
)
@@ -275,7 +277,7 @@ def parse_union_type(
275277
class_name: str,
276278
) -> Annotation:
277279
context.abstract_type = True
278-
sub_annotations = [
280+
sub_annotations: List[ast.expr] = [
279281
parse_operation_field_type(
280282
type_=subtype,
281283
context=context,

ariadne_codegen/client_generators/result_types.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
generate_import_from,
3939
generate_method_call,
4040
generate_module,
41+
generate_name,
4142
generate_pass,
4243
generate_pydantic_field,
4344
model_has_forward_refs,
@@ -264,7 +265,7 @@ def _parse_type_definition(
264265
)
265266

266267
field_implementation = generate_ann_assign(
267-
target=name,
268+
target=generate_name(name),
268269
annotation=annotation,
269270
lineno=lineno,
270271
value=default_value,

ariadne_codegen/codegen.py

+59-42
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ast
2-
from typing import Any, Dict, List, Optional, Union
2+
import sys
3+
from typing import Any, Dict, List, Optional, Union, cast
34

45
from graphql import (
56
GraphQLEnumType,
@@ -24,11 +25,6 @@
2425
from .exceptions import ParsingError
2526

2627

27-
def generate_import(names: List[str], level: int = 0) -> ast.Import:
28-
"""Generate import statement."""
29-
return ast.Import(names=[ast.alias(n) for n in names], level=level)
30-
31-
3228
def generate_import_from(
3329
names: List[str], from_: Optional[str] = None, level: int = 0
3430
) -> ast.ImportFrom:
@@ -94,17 +90,21 @@ def generate_async_method_definition(
9490
return_type: Union[ast.Name, ast.Subscript],
9591
body: Optional[List[ast.stmt]] = None,
9692
lineno: int = 1,
97-
decorator_list: Optional[List[ast.Name]] = None,
93+
decorator_list: Optional[List[ast.expr]] = None,
9894
) -> ast.AsyncFunctionDef:
9995
"""Generate async function."""
100-
return ast.AsyncFunctionDef(
101-
name=name,
102-
args=arguments,
103-
body=body if body else [ast.Pass()],
104-
decorator_list=decorator_list if decorator_list else [],
105-
returns=return_type,
106-
lineno=lineno,
107-
)
96+
params: Dict[str, Any] = {
97+
"name": name,
98+
"args": arguments,
99+
"body": body if body else [ast.Pass()],
100+
"decorator_list": decorator_list if decorator_list else [],
101+
"returns": return_type,
102+
"lineno": lineno,
103+
}
104+
if sys.version_info >= (3, 12):
105+
params["type_params"] = []
106+
107+
return ast.AsyncFunctionDef(**params)
108108

109109

110110
def generate_class_def(
@@ -113,14 +113,20 @@ def generate_class_def(
113113
body: Optional[List[ast.stmt]] = None,
114114
) -> ast.ClassDef:
115115
"""Generate class definition."""
116-
bases = [ast.Name(id=name) for name in base_names] if base_names else []
117-
return ast.ClassDef(
118-
name=name,
119-
bases=bases,
120-
keywords=[],
121-
body=body if body else [],
122-
decorator_list=[],
116+
bases = cast(
117+
List[ast.expr], [ast.Name(id=name) for name in base_names] if base_names else []
123118
)
119+
params: Dict[str, Any] = {
120+
"name": name,
121+
"bases": bases,
122+
"keywords": [],
123+
"body": body if body else [],
124+
"decorator_list": [],
125+
}
126+
if sys.version_info >= (3, 12):
127+
params["type_params"] = []
128+
129+
return ast.ClassDef(**params)
124130

125131

126132
def generate_name(name: str) -> ast.Name:
@@ -153,37 +159,39 @@ def generate_assign(
153159
) -> ast.Assign:
154160
"""Generate assign object."""
155161
return ast.Assign(
156-
targets=[ast.Name(t) for t in targets], value=value, lineno=lineno
162+
targets=[ast.Name(t) for t in targets],
163+
value=value, # type:ignore
164+
lineno=lineno,
157165
)
158166

159167

160168
def generate_ann_assign(
161-
target: Union[str, ast.expr],
169+
target: Union[ast.Name, ast.Attribute, ast.Subscript],
162170
annotation: Annotation,
163171
value: Optional[ast.expr] = None,
164172
lineno: int = 1,
165173
) -> ast.AnnAssign:
166174
"""Generate ann assign object."""
167175
return ast.AnnAssign(
168-
target=target if isinstance(target, ast.expr) else ast.Name(id=target),
176+
target=target,
169177
annotation=annotation,
170-
simple=1,
171178
value=value,
179+
simple=1,
172180
lineno=lineno,
173181
)
174182

175183

176184
def generate_union_annotation(
177-
types: List[Union[ast.Name, ast.Subscript]], nullable: bool = True
185+
types: List[ast.expr], nullable: bool = True
178186
) -> ast.Subscript:
179187
"""Generate union annotation."""
180188
result = ast.Subscript(value=ast.Name(id=UNION), slice=ast.Tuple(elts=types))
181189
return result if not nullable else generate_nullable_annotation(result)
182190

183191

184192
def generate_dict(
185-
keys: Optional[List[ast.expr]] = None,
186-
values: Optional[List[Optional[ast.expr]]] = None,
193+
keys: Optional[List[Optional[ast.expr]]] = None,
194+
values: Optional[List[ast.expr]] = None,
187195
) -> ast.Dict:
188196
"""Generate dict object."""
189197
return ast.Dict(keys=keys if keys else [], values=values if values else [])
@@ -201,7 +209,9 @@ def generate_call(
201209
) -> ast.Call:
202210
"""Generate call object."""
203211
return ast.Call(
204-
func=func, args=args if args else [], keywords=keywords if keywords else []
212+
func=func,
213+
args=args if args else [], # type:ignore
214+
keywords=keywords if keywords else [],
205215
)
206216

207217

@@ -240,7 +250,10 @@ def parse_field_type(
240250
return generate_annotation_name('"' + type_.name + '"', nullable)
241251

242252
if isinstance(type_, GraphQLUnionType):
243-
subtypes = [parse_field_type(subtype, False) for subtype in type_.types]
253+
subtypes = cast(
254+
List[ast.expr],
255+
[parse_field_type(subtype, False) for subtype in type_.types],
256+
)
244257
return generate_union_annotation(subtypes, nullable)
245258

246259
if isinstance(type_, GraphQLList):
@@ -255,7 +268,7 @@ def parse_field_type(
255268

256269

257270
def generate_method_call(
258-
object_name: str, method_name: str, args: Optional[List[Optional[ast.expr]]] = None
271+
object_name: str, method_name: str, args: Optional[List[ast.expr]] = None
259272
) -> ast.Call:
260273
"""Generate object`s method call."""
261274
return ast.Call(
@@ -287,7 +300,7 @@ def generate_trivial_lambda(name: str, argument_name: str) -> ast.Assign:
287300
)
288301

289302

290-
def generate_list(elements: List[Optional[ast.expr]]) -> ast.List:
303+
def generate_list(elements: List[ast.expr]) -> ast.List:
291304
"""Generate list object."""
292305
return ast.List(elts=elements)
293306

@@ -343,16 +356,20 @@ def generate_method_definition(
343356
return_type: Union[ast.Name, ast.Subscript],
344357
body: Optional[List[ast.stmt]] = None,
345358
lineno: int = 1,
346-
decorator_list: Optional[List[ast.Name]] = None,
359+
decorator_list: Optional[List[ast.expr]] = None,
347360
) -> ast.FunctionDef:
348-
return ast.FunctionDef(
349-
name=name,
350-
args=arguments,
351-
body=body if body else [ast.Pass()],
352-
decorator_list=decorator_list if decorator_list else [],
353-
returns=return_type,
354-
lineno=lineno,
355-
)
361+
params: Dict[str, Any] = {
362+
"name": name,
363+
"args": arguments,
364+
"body": body if body else [ast.Pass()],
365+
"decorator_list": decorator_list if decorator_list else [],
366+
"returns": return_type,
367+
"lineno": lineno,
368+
}
369+
if sys.version_info >= (3, 12):
370+
params["type_params"] = []
371+
372+
return ast.FunctionDef(**params)
356373

357374

358375
def generate_async_for(

0 commit comments

Comments
 (0)