Skip to content

Commit 5334e00

Browse files
Add more tests and fix arguments creation
1 parent f387fb7 commit 5334e00

File tree

16 files changed

+346
-188
lines changed

16 files changed

+346
-188
lines changed

ariadne_codegen/client_generators/client.py

+64-26
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
OPERATION_TYPE,
5151
OPTIONAL,
5252
PRINT_AST,
53+
SELECTION_NODE,
5354
SELECTION_SET_NODE,
5455
TUPLE,
5556
TYPING_MODULE,
@@ -457,23 +458,7 @@ def create_build_operation_ast_method(self):
457458
keywords=[
458459
generate_keyword(
459460
arg="selections",
460-
value=generate_list_comp(
461-
elt=generate_call(
462-
func=generate_attribute(
463-
value=generate_name(
464-
"field",
465-
),
466-
attr="to_ast",
467-
),
468-
args=[generate_name("idx")],
469-
),
470-
generators=[
471-
generate_comp(
472-
target="idx, field",
473-
iter_="enumerate(fields)",
474-
)
475-
],
476-
),
461+
value=generate_name("selections"),
477462
)
478463
],
479464
),
@@ -498,15 +483,10 @@ def create_build_operation_ast_method(self):
498483
args=[
499484
generate_arg("self"),
500485
generate_arg(
501-
"fields",
486+
"selections",
502487
annotation=generate_subscript(
503-
generate_name(TUPLE),
504-
generate_tuple(
505-
[
506-
generate_name("GraphQLField"),
507-
generate_name("..."),
508-
]
509-
),
488+
generate_name(LIST),
489+
generate_name(SELECTION_NODE),
510490
),
511491
),
512492
generate_arg(
@@ -531,6 +511,15 @@ def create_execute_custom_operation_method(self):
531511
variables_types_combined = generate_name("variables_types_combined")
532512
processed_variables_combined = generate_name("processed_variables_combined")
533513
method_body = [
514+
generate_assign(
515+
targets=["selections"],
516+
value=generate_call(
517+
func=generate_attribute(
518+
value=generate_name("self"), attr="_build_selection_set"
519+
),
520+
args=[generate_name("fields")],
521+
),
522+
),
534523
ast.Assign(
535524
targets=[
536525
generate_tuple(
@@ -561,7 +550,7 @@ def create_execute_custom_operation_method(self):
561550
value=generate_name("self"), attr="_build_operation_ast"
562551
),
563552
args=[
564-
generate_name("fields"),
553+
generate_name("selections"),
565554
generate_name("operation_type"),
566555
generate_name("operation_name"),
567556
generate_name("variable_definitions"),
@@ -627,6 +616,53 @@ def create_execute_custom_operation_method(self):
627616
),
628617
)
629618

619+
def create_build_selection_set(self):
620+
return generate_method_definition(
621+
name="_build_selection_set",
622+
arguments=generate_arguments(
623+
args=[
624+
generate_arg("self"),
625+
generate_arg(
626+
"fields",
627+
annotation=generate_subscript(
628+
generate_name("Tuple"),
629+
generate_tuple(
630+
[
631+
generate_name("GraphQLField"),
632+
generate_name("..."),
633+
]
634+
),
635+
),
636+
),
637+
]
638+
),
639+
body=[
640+
generate_return(
641+
value=generate_list_comp(
642+
elt=generate_call(
643+
func=generate_attribute(
644+
value=generate_name(
645+
"field",
646+
),
647+
attr="to_ast",
648+
),
649+
args=[generate_name("idx")],
650+
),
651+
generators=[
652+
generate_comp(
653+
target="idx, field",
654+
iter_="enumerate(fields)",
655+
)
656+
],
657+
),
658+
),
659+
],
660+
return_type=generate_subscript(
661+
generate_name(LIST),
662+
generate_name(SELECTION_NODE),
663+
),
664+
)
665+
630666
def add_execute_custom_operation_method(self):
631667
self._add_import(
632668
generate_import_from(
@@ -639,6 +675,7 @@ def add_execute_custom_operation_method(self):
639675
VARIABLE_DEFINITION_NODE,
640676
VARIABLE_NODE,
641677
NAMED_TYPE_NODE,
678+
SELECTION_NODE,
642679
],
643680
GRAPHQL_MODULE,
644681
)
@@ -654,6 +691,7 @@ def add_execute_custom_operation_method(self):
654691
self._class_def.body.append(self.create_combine_variables_method())
655692
self._class_def.body.append(self.create_build_variable_definitions_method())
656693
self._class_def.body.append(self.create_build_operation_ast_method())
694+
self._class_def.body.append(self.create_build_selection_set())
657695

658696
def create_custom_operation_method(self, name, operation_type):
659697
self._add_import(

ariadne_codegen/client_generators/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
OPERATION_DEFINITION_NODE = "OperationDefinitionNode"
2929
NAME_NODE = "NameNode"
3030
SELECTION_SET_NODE = "SelectionSetNode"
31+
SELECTION_NODE = "SelectionNode"
3132
PRINT_AST = "print_ast"
3233
OPERATION_TYPE = "OperationType"
3334
VARIABLE_DEFINITION_NODE = "VariableDefinitionNode"

ariadne_codegen/client_generators/custom_operation.py

+64-17
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
)
1414

1515
from ..codegen import (
16+
generate_ann_assign,
1617
generate_annotation_name,
1718
generate_arg,
1819
generate_arguments,
20+
generate_assign,
1921
generate_call,
2022
generate_class_def,
23+
generate_comp,
2124
generate_constant,
2225
generate_dict,
2326
generate_import_from,
@@ -26,6 +29,8 @@
2629
generate_module,
2730
generate_name,
2831
generate_return,
32+
generate_subscript,
33+
generate_tuple,
2934
)
3035
from ..exceptions import ParsingError
3136
from ..plugins.manager import PluginManager
@@ -36,6 +41,7 @@
3641
BASE_MODEL_FILE_PATH,
3742
CUSTOM_FIELDS_FILE_PATH,
3843
CUSTOM_FIELDS_TYPING_FILE_PATH,
44+
DICT,
3945
INPUT_SCALARS_MAP,
4046
OPTIONAL,
4147
TYPING_MODULE,
@@ -69,7 +75,7 @@ def __init__(
6975

7076
self._imports: List[ast.ImportFrom] = []
7177
self._type_imports: List[ast.ImportFrom] = []
72-
self._add_import(generate_import_from([OPTIONAL, ANY], TYPING_MODULE))
78+
self._add_import(generate_import_from([OPTIONAL, ANY, DICT], TYPING_MODULE))
7379

7480
self._class_def = generate_class_def(name=name, base_names=[])
7581

@@ -114,14 +120,64 @@ def _generate_method(
114120
operation_args,
115121
final_type,
116122
) -> ast.FunctionDef:
117-
method_arguments, return_arguments = self._generate_arguments(operation_args)
123+
(
124+
method_arguments,
125+
return_arguments_keys,
126+
return_arguments_values,
127+
) = self._generate_arguments(operation_args)
118128
return_type_name = self._get_return_type_and_from(final_type)
119129

120130
return generate_method_definition(
121131
name=str_to_snake_case(operation_name),
122132
arguments=method_arguments,
123133
return_type=generate_name(return_type_name),
124134
body=[
135+
generate_ann_assign(
136+
"arguments",
137+
generate_subscript(
138+
generate_name(DICT),
139+
generate_tuple(
140+
[
141+
generate_name("str"),
142+
generate_subscript(
143+
generate_name(DICT),
144+
generate_tuple(
145+
[
146+
generate_name("str"),
147+
generate_name(ANY),
148+
]
149+
),
150+
),
151+
]
152+
),
153+
),
154+
generate_dict(return_arguments_keys, return_arguments_values),
155+
),
156+
generate_assign(
157+
["cleared_arguments"],
158+
ast.DictComp(
159+
key=generate_name("key"),
160+
value=generate_name("value"),
161+
generators=[
162+
generate_comp(
163+
target="key, value",
164+
iter_="arguments.items()",
165+
ifs=[
166+
ast.Compare(
167+
left=generate_subscript(
168+
value=generate_name("value"),
169+
slice_=ast.Index(
170+
value=generate_constant("value"),
171+
), # type: ignore
172+
),
173+
ops=[ast.IsNot()],
174+
comparators=[generate_constant(None)],
175+
)
176+
],
177+
)
178+
],
179+
),
180+
),
125181
generate_return(
126182
value=generate_call(
127183
func=generate_name(return_type_name),
@@ -131,10 +187,13 @@ def _generate_method(
131187
arg="field_name",
132188
value=generate_constant(value=operation_name),
133189
),
134-
return_arguments,
190+
generate_keyword(
191+
arg="arguments",
192+
value=generate_name("cleared_arguments"),
193+
),
135194
],
136195
)
137-
)
196+
),
138197
],
139198
decorator_list=[generate_name("classmethod")],
140199
)
@@ -171,11 +230,8 @@ def _generate_arguments(self, operation_args):
171230
method_arguments = self._assemble_method_arguments(
172231
cls_arg, args, kw_only_args, kw_defaults
173232
)
174-
return_arguments = self._assemble_return_arguments(
175-
return_arguments_keys, return_arguments_values
176-
)
177233

178-
return method_arguments, return_arguments
234+
return method_arguments, return_arguments_keys, return_arguments_values
179235

180236
def _accumulate_method_arguments(
181237
self, args, kw_only_args, kw_defaults, name, annotation, is_required
@@ -231,15 +287,6 @@ def _assemble_method_arguments(self, cls_arg, args, kw_only_args, kw_defaults):
231287
kw_defaults=kw_defaults,
232288
)
233289

234-
def _assemble_return_arguments(self, keys, values):
235-
return generate_keyword(
236-
arg="arguments",
237-
value=generate_dict(
238-
keys=keys,
239-
values=values,
240-
),
241-
)
242-
243290
def _parse_graphql_type_name(
244291
self, type_, nullable: bool = True
245292
) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]:

ariadne_codegen/client_generators/dependencies/base_operation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _build_selections(
6969
def _format_variable_name(
7070
self, idx: int, var_name: str, used_names: Set[str]
7171
) -> str:
72-
base_name = f"{idx}_{var_name}"
72+
base_name = f"{var_name}_{idx}"
7373
unique_name = base_name
7474
counter = 1
7575
while unique_name in used_names:

tests/main/clients/custom_query_builder/expected_client/base_operation.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929
) -> None:
3030
self._field_name = field_name
3131
self._variables = arguments or {}
32-
self._formatted_variables: Dict[str, Dict[str, Any]] = {}
32+
self.formatted_variables: Dict[str, Dict[str, Any]] = {}
3333
self._subfields: List[GraphQLField] = []
3434
self._alias: Optional[str] = None
3535
self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {}
@@ -69,7 +69,7 @@ def _build_selections(
6969
def _format_variable_name(
7070
self, idx: int, var_name: str, used_names: Set[str]
7171
) -> str:
72-
base_name = f"{idx}_{var_name}"
72+
base_name = f"{var_name}_{idx}"
7373
unique_name = base_name
7474
counter = 1
7575
while unique_name in used_names:
@@ -79,10 +79,10 @@ def _format_variable_name(
7979
return unique_name
8080

8181
def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None:
82-
self._formatted_variables = {}
82+
self.formatted_variables = {}
8383
for k, v in self._variables.items():
8484
unique_name = self._format_variable_name(idx, k, used_names)
85-
self._formatted_variables[unique_name] = {
85+
self.formatted_variables[unique_name] = {
8686
"name": k,
8787
"type": v["type"],
8888
"value": v["value"],
@@ -94,7 +94,7 @@ def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode:
9494
self._collect_all_variables(idx, used_names)
9595
formatted_args = [
9696
GraphQLArgument(v["name"], k).to_ast()
97-
for k, v in self._formatted_variables.items()
97+
for k, v in self.formatted_variables.items()
9898
]
9999
return FieldNode(
100100
name=NameNode(value=self._build_field_name()),
@@ -107,12 +107,12 @@ def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode:
107107
)
108108

109109
def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]:
110-
formatted_variables = self._formatted_variables
110+
formatted_variables = self.formatted_variables
111111
for subfield in self._subfields:
112112
subfield.get_formatted_variables()
113-
self._formatted_variables.update(subfield._formatted_variables)
113+
self.formatted_variables.update(subfield.formatted_variables)
114114
for subfields in self._inline_fragments.values():
115115
for subfield in subfields:
116116
subfield.get_formatted_variables()
117-
self._formatted_variables.update(subfield._formatted_variables)
117+
self.formatted_variables.update(subfield.formatted_variables)
118118
return formatted_variables

0 commit comments

Comments
 (0)