Skip to content

Commit 9fdc6eb

Browse files
Refactor in code structure
1 parent 7b63767 commit 9fdc6eb

18 files changed

+1150
-871
lines changed

Diff for: ariadne_codegen/client_generators/constants.py

+6
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,9 @@
125125
SCALARS_SERIALIZE_DICT_NAME = "SCALARS_SERIALIZE_FUNCTIONS"
126126

127127
OPERATION_TYPES = ("Query", "Mutation", "Subscription")
128+
129+
GRAPHQL_OBJECT_SUFFIX = "Fields"
130+
GRAPHQL_INTERFACE_SUFFIX = "Interface"
131+
GRAPHQL_FIELD_SUFFIX = "GraphQLField"
132+
GRAPHQL_UNION_SUFFIX = "Union"
133+
GRAPHQL_BASE_FIELD_CLASS = "GraphQLField"
+277
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
import ast
2+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
3+
4+
from graphql import (
5+
GraphQLEnumType,
6+
GraphQLInputObjectType,
7+
GraphQLInterfaceType,
8+
GraphQLNonNull,
9+
GraphQLObjectType,
10+
GraphQLScalarType,
11+
GraphQLUnionType,
12+
)
13+
14+
from ..codegen import (
15+
generate_ann_assign,
16+
generate_annotation_name,
17+
generate_arg,
18+
generate_arguments,
19+
generate_assign,
20+
generate_call,
21+
generate_comp,
22+
generate_constant,
23+
generate_dict,
24+
generate_import_from,
25+
generate_keyword,
26+
generate_name,
27+
generate_subscript,
28+
generate_tuple,
29+
)
30+
from ..exceptions import ParsingError
31+
from ..plugins.manager import PluginManager
32+
from ..utils import process_name
33+
from .constants import (
34+
ANY,
35+
BASE_MODEL_FILE_PATH,
36+
DICT,
37+
INPUT_SCALARS_MAP,
38+
UPLOAD_CLASS_NAME,
39+
)
40+
from .custom_generator_utils import get_final_type
41+
from .scalars import ScalarData, generate_scalar_imports
42+
43+
44+
class ArgumentGenerator:
45+
"""Generates method arguments for GraphQL fields."""
46+
47+
def __init__(
48+
self,
49+
custom_scalars: Dict[str, ScalarData],
50+
convert_to_snake_case: bool,
51+
plugin_manager: Optional[PluginManager] = None,
52+
) -> None:
53+
self.custom_scalars = custom_scalars
54+
self.convert_to_snake_case = convert_to_snake_case
55+
self.plugin_manager = plugin_manager
56+
self.imports: List[ast.ImportFrom] = []
57+
self._used_custom_scalars: List[str] = []
58+
59+
def _add_import(self, import_: Optional[ast.ImportFrom] = None) -> None:
60+
"""Adds an import statement to the list of imports."""
61+
if import_:
62+
if self.plugin_manager:
63+
import_ = self.plugin_manager.generate_client_import(import_)
64+
if import_.names:
65+
self.imports.append(import_)
66+
67+
def generate_arguments(
68+
self, operation_args: Dict[str, Any]
69+
) -> Tuple[ast.arguments, List[ast.expr], List[ast.expr]]:
70+
"""Generates method arguments from operation arguments."""
71+
cls_arg = generate_arg(name="cls")
72+
args: List[ast.arg] = []
73+
kw_only_args: List[ast.arg] = []
74+
kw_defaults: List[ast.expr] = []
75+
return_arguments_keys: List[ast.expr] = []
76+
return_arguments_values: List[ast.expr] = []
77+
78+
for arg_name, arg_value in operation_args.items():
79+
final_type = get_final_type(arg_value)
80+
is_required = isinstance(arg_value.type, GraphQLNonNull)
81+
name = process_name(
82+
arg_name, convert_to_snake_case=self.convert_to_snake_case
83+
)
84+
annotation, used_custom_scalar = self._parse_graphql_type_name(
85+
final_type, not is_required
86+
)
87+
88+
self._accumulate_method_arguments(
89+
args, kw_only_args, kw_defaults, name, annotation, is_required
90+
)
91+
self._accumulate_return_arguments(
92+
return_arguments_keys,
93+
return_arguments_values,
94+
arg_name,
95+
name,
96+
final_type,
97+
is_required,
98+
used_custom_scalar,
99+
)
100+
101+
method_arguments = self._assemble_method_arguments(
102+
cls_arg, args, kw_only_args, kw_defaults
103+
)
104+
return method_arguments, return_arguments_keys, return_arguments_values
105+
106+
def _accumulate_method_arguments(
107+
self,
108+
args: List[ast.arg],
109+
kw_only_args: List[ast.arg],
110+
kw_defaults: List[ast.expr],
111+
name: str,
112+
annotation: Optional[Union[ast.Name, ast.Subscript]],
113+
is_required: bool,
114+
) -> None:
115+
"""Accumulates method arguments."""
116+
if is_required:
117+
args.append(generate_arg(name=name, annotation=annotation))
118+
else:
119+
kw_only_args.append(generate_arg(name=name, annotation=annotation))
120+
kw_defaults.append(generate_constant(value=None))
121+
122+
def _accumulate_return_arguments(
123+
self,
124+
return_arguments_keys: List[ast.expr],
125+
return_arguments_values: List[ast.expr],
126+
arg_name: str,
127+
name: str,
128+
final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType],
129+
is_required: bool,
130+
used_custom_scalar: Optional[str],
131+
) -> None:
132+
"""Accumulates return arguments."""
133+
constant_value = f"{final_type.name}!" if is_required else final_type.name
134+
return_arg_dict_value = self._generate_return_arg_value(
135+
name, used_custom_scalar
136+
)
137+
138+
return_arguments_keys.append(generate_constant(arg_name))
139+
return_arguments_values.append(
140+
generate_dict(
141+
keys=[generate_constant("type"), generate_constant("value")],
142+
values=[generate_constant(constant_value), return_arg_dict_value],
143+
)
144+
)
145+
146+
def _generate_return_arg_value(
147+
self, name: str, used_custom_scalar: Optional[str]
148+
) -> Union[ast.Call, ast.Name]:
149+
"""Generates the return argument value."""
150+
return_arg_dict_value: Union[ast.Call, ast.Name] = generate_name(name)
151+
152+
if used_custom_scalar:
153+
self._used_custom_scalars.append(used_custom_scalar)
154+
scalar_data = self.custom_scalars[used_custom_scalar]
155+
if scalar_data.serialize_name:
156+
return_arg_dict_value = generate_call(
157+
func=generate_name(scalar_data.serialize_name),
158+
args=[generate_name(name)],
159+
)
160+
161+
return return_arg_dict_value
162+
163+
def _assemble_method_arguments(
164+
self,
165+
cls_arg: ast.arg,
166+
args: List[ast.arg],
167+
kw_only_args: List[ast.arg],
168+
kw_defaults: List[ast.expr],
169+
) -> ast.arguments:
170+
"""Assembles method arguments."""
171+
return generate_arguments(
172+
args=[cls_arg, *args],
173+
kwonlyargs=kw_only_args,
174+
kw_defaults=kw_defaults, # type: ignore
175+
)
176+
177+
def _parse_graphql_type_name(
178+
self,
179+
type_: Union[GraphQLScalarType, GraphQLInputObjectType, GraphQLEnumType],
180+
nullable: bool = True,
181+
) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]:
182+
"""Parses the GraphQL type name and determines if it is a custom scalar."""
183+
name = type_.name
184+
used_custom_scalar = None
185+
if isinstance(type_, GraphQLInputObjectType):
186+
self._add_import(
187+
generate_import_from(names=[name], from_="input_types", level=1)
188+
)
189+
elif isinstance(type_, GraphQLEnumType):
190+
self._add_import(generate_import_from(names=[name], level=1))
191+
elif isinstance(type_, GraphQLScalarType):
192+
if name not in self.custom_scalars:
193+
name = INPUT_SCALARS_MAP.get(name, ANY)
194+
if name == UPLOAD_CLASS_NAME:
195+
self._add_import(
196+
generate_import_from(
197+
names=[UPLOAD_CLASS_NAME],
198+
from_=BASE_MODEL_FILE_PATH.stem,
199+
level=1,
200+
)
201+
)
202+
else:
203+
used_custom_scalar = name
204+
name = self.custom_scalars[name].type_name
205+
self._used_custom_scalars.append(used_custom_scalar)
206+
else:
207+
raise ParsingError(f"Incorrect argument type {name}")
208+
return generate_annotation_name(name, nullable), used_custom_scalar
209+
210+
def add_custom_scalar_imports(self) -> None:
211+
"""Adds imports for custom scalars used in the schema."""
212+
for custom_scalar_name in self._used_custom_scalars:
213+
scalar_data = self.custom_scalars[custom_scalar_name]
214+
for import_ in generate_scalar_imports(scalar_data):
215+
self._add_import(import_)
216+
217+
def generate_clear_arguments_section(
218+
self,
219+
return_arguments_keys: List[ast.expr],
220+
return_arguments_values: List[ast.expr],
221+
) -> Tuple[List[ast.stmt], List[ast.keyword]]:
222+
arguments_body = [
223+
generate_ann_assign(
224+
"arguments",
225+
generate_subscript(
226+
generate_name(DICT),
227+
generate_tuple(
228+
[
229+
generate_name("str"),
230+
generate_subscript(
231+
generate_name(DICT),
232+
generate_tuple(
233+
[
234+
generate_name("str"),
235+
generate_name(ANY),
236+
]
237+
),
238+
),
239+
]
240+
),
241+
),
242+
generate_dict(
243+
return_arguments_keys,
244+
return_arguments_values, # type: ignore
245+
),
246+
),
247+
generate_assign(
248+
["cleared_arguments"],
249+
ast.DictComp(
250+
key=generate_name("key"),
251+
value=generate_name("value"),
252+
generators=[
253+
generate_comp(
254+
target="key, value",
255+
iter_="arguments.items()",
256+
ifs=cast(
257+
List[ast.expr],
258+
[
259+
ast.Compare(
260+
left=generate_subscript(
261+
value=generate_name("value"),
262+
slice_=generate_constant("value"),
263+
),
264+
ops=[ast.IsNot()],
265+
comparators=[generate_constant(None)],
266+
)
267+
],
268+
),
269+
)
270+
],
271+
),
272+
),
273+
]
274+
arguments_keyword = [
275+
generate_keyword(arg="arguments", value=generate_name("cleared_arguments"))
276+
]
277+
return arguments_body, arguments_keyword

0 commit comments

Comments
 (0)