Skip to content

Commit f387fb7

Browse files
Add fix to argument duplication
1 parent 0d6f383 commit f387fb7

File tree

17 files changed

+569
-302
lines changed

17 files changed

+569
-302
lines changed

ariadne_codegen/client_generators/client.py

+26-22
Original file line numberDiff line numberDiff line change
@@ -223,30 +223,31 @@ def create_combine_variables_method(self):
223223
value=generate_dict(),
224224
),
225225
ast.For(
226-
target=generate_tuple(
227-
elts=[
228-
generate_name("idx"),
229-
generate_name("field"),
230-
],
231-
),
232-
iter=generate_call(
233-
func=generate_name("enumerate"),
234-
args=[generate_name("fields")],
235-
),
226+
target=generate_name("field"),
227+
iter=generate_name("fields"),
236228
body=[
229+
generate_assign(
230+
targets=["formatted_variables"],
231+
value=generate_call(
232+
func=generate_name("field.get_formatted_variables")
233+
),
234+
),
237235
generate_expr(
238236
value=generate_call(
239237
func=generate_attribute(
240238
value=generate_name("variables_types_combined"),
241239
attr="update",
242240
),
243241
args=[
244-
generate_call(
245-
func=generate_attribute(
246-
value=generate_name("field"),
247-
attr="get_variables_types",
248-
),
249-
args=[generate_name("idx")],
242+
ast.DictComp(
243+
key=generate_name("k"),
244+
value=generate_name('v["type"]'),
245+
generators=[
246+
generate_comp(
247+
target="k, v",
248+
iter_="formatted_variables.items()",
249+
)
250+
],
250251
)
251252
],
252253
)
@@ -258,12 +259,15 @@ def create_combine_variables_method(self):
258259
attr="update",
259260
),
260261
args=[
261-
generate_call(
262-
func=generate_attribute(
263-
value=generate_name("field"),
264-
attr="get_processed_variables",
265-
),
266-
args=[generate_name("idx")],
262+
ast.DictComp(
263+
key=generate_name("k"),
264+
value=generate_name('v["value"]'),
265+
generators=[
266+
generate_comp(
267+
target="k, v",
268+
iter_="formatted_variables.items()",
269+
)
270+
],
267271
)
268272
],
269273
)

ariadne_codegen/client_generators/custom_operation.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,16 @@ def _generate_arguments(self, operation_args):
147147
for arg_name, arg_value in operation_args.items():
148148
final_type = get_final_type(arg_value)
149149
is_required = isinstance(arg_value.type, GraphQLNonNull)
150-
name = self._process_argument_name(arg_name)
150+
name = process_name(
151+
arg_name,
152+
convert_to_snake_case=self.convert_to_snake_case,
153+
)
151154
annotation, used_custom_scalar = self._parse_graphql_type_name(
152155
final_type, not is_required
153156
)
154157

155158
self._accumulate_method_arguments(
156-
args, kw_only_args, kw_defaults, arg_name, annotation, is_required
159+
args, kw_only_args, kw_defaults, name, annotation, is_required
157160
)
158161
self._accumulate_return_arguments(
159162
return_arguments_keys,
@@ -174,16 +177,13 @@ def _generate_arguments(self, operation_args):
174177

175178
return method_arguments, return_arguments
176179

177-
def _process_argument_name(self, arg_name):
178-
return process_name(arg_name, convert_to_snake_case=self.convert_to_snake_case)
179-
180180
def _accumulate_method_arguments(
181-
self, args, kw_only_args, kw_defaults, arg_name, annotation, is_required
181+
self, args, kw_only_args, kw_defaults, name, annotation, is_required
182182
):
183183
if is_required:
184-
args.append(generate_arg(name=arg_name, annotation=annotation))
184+
args.append(generate_arg(name=name, annotation=annotation))
185185
else:
186-
kw_only_args.append(generate_arg(name=arg_name, annotation=annotation))
186+
kw_only_args.append(generate_arg(name=name, annotation=annotation))
187187
kw_defaults.append(generate_constant(value=None))
188188

189189
def _accumulate_return_arguments(
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Tuple, Union
1+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
22

33
from graphql import (
44
ArgumentNode,
@@ -12,34 +12,28 @@
1212

1313

1414
class GraphQLArgument:
15-
def __init__(self, argument_name: str):
15+
def __init__(self, argument_name: str, argument_value: Any):
1616
self._name = argument_name
17-
self._variable_name = argument_name
17+
self._value = argument_value
1818

19-
def to_ast(self, idx: int) -> ArgumentNode:
19+
def to_ast(self) -> ArgumentNode:
2020
return ArgumentNode(
2121
name=NameNode(value=self._name),
22-
value=VariableNode(name=NameNode(value=f"{idx}_{self._variable_name}")),
22+
value=VariableNode(name=NameNode(value=self._value)),
2323
)
2424

2525

2626
class GraphQLField:
2727
def __init__(
28-
self, field_name: str, arguments: Optional[Dict[str, Any]] = None
28+
self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None
2929
) -> None:
3030
self._field_name = field_name
3131
self._variables = arguments or {}
32-
self._arguments = [GraphQLArgument(k) for k in self._variables]
32+
self.formatted_variables: Dict[str, Dict[str, Any]] = {}
3333
self._subfields: List[GraphQLField] = []
3434
self._alias: Optional[str] = None
3535
self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {}
3636

37-
def get_variables_types(self, idx: int) -> Dict[str, Any]:
38-
return {f"{idx}_{k}": v["type"] for k, v in self._variables.items()}
39-
40-
def get_processed_variables(self, idx: int) -> Dict[str, Any]:
41-
return {f"{idx}_{k}": v["value"] for k, v in self._variables.items()}
42-
4337
def alias(self, alias: str) -> "GraphQLField":
4438
self._alias = alias
4539
return self
@@ -53,28 +47,72 @@ def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> Non
5347
def _build_field_name(self) -> str:
5448
return f"{self._alias}: {self._field_name}" if self._alias else self._field_name
5549

56-
def _build_selections(self, idx: int) -> List[Union[FieldNode, InlineFragmentNode]]:
50+
def _build_selections(
51+
self, idx: int, used_names: Set[str]
52+
) -> List[Union[FieldNode, InlineFragmentNode]]:
5753
selections: List[Union[FieldNode, InlineFragmentNode]] = [
58-
subfield.to_ast(idx) for subfield in self._subfields
54+
subfield.to_ast(idx, used_names) for subfield in self._subfields
5955
]
6056
for name, subfields in self._inline_fragments.items():
6157
selections.append(
6258
InlineFragmentNode(
6359
type_condition=NamedTypeNode(name=NameNode(value=name)),
6460
selection_set=SelectionSetNode(
65-
selections=[subfield.to_ast(idx) for subfield in subfields]
61+
selections=[
62+
subfield.to_ast(idx, used_names) for subfield in subfields
63+
]
6664
),
6765
)
6866
)
6967
return selections
7068

71-
def to_ast(self, idx: int) -> FieldNode:
69+
def _format_variable_name(
70+
self, idx: int, var_name: str, used_names: Set[str]
71+
) -> str:
72+
base_name = f"{idx}_{var_name}"
73+
unique_name = base_name
74+
counter = 1
75+
while unique_name in used_names:
76+
unique_name = f"{base_name}_{counter}"
77+
counter += 1
78+
used_names.add(unique_name)
79+
return unique_name
80+
81+
def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None:
82+
self.formatted_variables = {}
83+
for k, v in self._variables.items():
84+
unique_name = self._format_variable_name(idx, k, used_names)
85+
self.formatted_variables[unique_name] = {
86+
"name": k,
87+
"type": v["type"],
88+
"value": v["value"],
89+
}
90+
91+
def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode:
92+
if used_names is None:
93+
used_names = set()
94+
self._collect_all_variables(idx, used_names)
95+
formatted_args = [
96+
GraphQLArgument(v["name"], k).to_ast()
97+
for k, v in self.formatted_variables.items()
98+
]
7299
return FieldNode(
73100
name=NameNode(value=self._build_field_name()),
74-
arguments=[arg.to_ast(idx) for arg in self._arguments],
101+
arguments=formatted_args,
75102
selection_set=(
76-
SelectionSetNode(selections=self._build_selections(idx))
103+
SelectionSetNode(selections=self._build_selections(idx, used_names))
77104
if self._subfields or self._inline_fragments
78105
else None
79106
),
80107
)
108+
109+
def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]:
110+
formatted_variables = self.formatted_variables
111+
for subfield in self._subfields:
112+
subfield.get_formatted_variables()
113+
self.formatted_variables.update(subfield.formatted_variables)
114+
for subfields in self._inline_fragments.values():
115+
for subfield in subfields:
116+
subfield.get_formatted_variables()
117+
self.formatted_variables.update(subfield.formatted_variables)
118+
return formatted_variables

ariadne_codegen/client_generators/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def _collect_dependent_types(self, graphql_type: GraphQLObjectType) -> None:
4646
stack.extend(subfield_type.types)
4747
for interface in current_type.interfaces:
4848
stack.append(interface)
49+
elif isinstance(current_type, GraphQLUnionType):
50+
stack.extend(current_type.types)
4951

5052

5153
def get_final_type(type_):
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Tuple, Union
1+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
22

33
from graphql import (
44
ArgumentNode,
@@ -12,34 +12,28 @@
1212

1313

1414
class GraphQLArgument:
15-
def __init__(self, argument_name: str):
15+
def __init__(self, argument_name: str, argument_value: Any):
1616
self._name = argument_name
17-
self._variable_name = argument_name
17+
self._value = argument_value
1818

19-
def to_ast(self, idx: int) -> ArgumentNode:
19+
def to_ast(self) -> ArgumentNode:
2020
return ArgumentNode(
2121
name=NameNode(value=self._name),
22-
value=VariableNode(name=NameNode(value=f"{idx}_{self._variable_name}")),
22+
value=VariableNode(name=NameNode(value=self._value)),
2323
)
2424

2525

2626
class GraphQLField:
2727
def __init__(
28-
self, field_name: str, arguments: Optional[Dict[str, Any]] = None
28+
self, field_name: str, arguments: Optional[Dict[str, Dict[str, Any]]] = None
2929
) -> None:
3030
self._field_name = field_name
3131
self._variables = arguments or {}
32-
self._arguments = [GraphQLArgument(k) for k in self._variables]
32+
self._formatted_variables: Dict[str, Dict[str, Any]] = {}
3333
self._subfields: List[GraphQLField] = []
3434
self._alias: Optional[str] = None
3535
self._inline_fragments: Dict[str, Tuple[GraphQLField, ...]] = {}
3636

37-
def get_variables_types(self, idx: int) -> Dict[str, Any]:
38-
return {f"{idx}_{k}": v["type"] for k, v in self._variables.items()}
39-
40-
def get_processed_variables(self, idx: int) -> Dict[str, Any]:
41-
return {f"{idx}_{k}": v["value"] for k, v in self._variables.items()}
42-
4337
def alias(self, alias: str) -> "GraphQLField":
4438
self._alias = alias
4539
return self
@@ -53,28 +47,72 @@ def add_inline_fragment(self, type_name: str, *subfields: "GraphQLField") -> Non
5347
def _build_field_name(self) -> str:
5448
return f"{self._alias}: {self._field_name}" if self._alias else self._field_name
5549

56-
def _build_selections(self, idx: int) -> List[Union[FieldNode, InlineFragmentNode]]:
50+
def _build_selections(
51+
self, idx: int, used_names: Set[str]
52+
) -> List[Union[FieldNode, InlineFragmentNode]]:
5753
selections: List[Union[FieldNode, InlineFragmentNode]] = [
58-
subfield.to_ast(idx) for subfield in self._subfields
54+
subfield.to_ast(idx, used_names) for subfield in self._subfields
5955
]
6056
for name, subfields in self._inline_fragments.items():
6157
selections.append(
6258
InlineFragmentNode(
6359
type_condition=NamedTypeNode(name=NameNode(value=name)),
6460
selection_set=SelectionSetNode(
65-
selections=[subfield.to_ast(idx) for subfield in subfields]
61+
selections=[
62+
subfield.to_ast(idx, used_names) for subfield in subfields
63+
]
6664
),
6765
)
6866
)
6967
return selections
7068

71-
def to_ast(self, idx: int) -> FieldNode:
69+
def _format_variable_name(
70+
self, idx: int, var_name: str, used_names: Set[str]
71+
) -> str:
72+
base_name = f"{idx}_{var_name}"
73+
unique_name = base_name
74+
counter = 1
75+
while unique_name in used_names:
76+
unique_name = f"{base_name}_{counter}"
77+
counter += 1
78+
used_names.add(unique_name)
79+
return unique_name
80+
81+
def _collect_all_variables(self, idx: int, used_names: Set[str]) -> None:
82+
self._formatted_variables = {}
83+
for k, v in self._variables.items():
84+
unique_name = self._format_variable_name(idx, k, used_names)
85+
self._formatted_variables[unique_name] = {
86+
"name": k,
87+
"type": v["type"],
88+
"value": v["value"],
89+
}
90+
91+
def to_ast(self, idx: int, used_names: Optional[Set[str]] = None) -> FieldNode:
92+
if used_names is None:
93+
used_names = set()
94+
self._collect_all_variables(idx, used_names)
95+
formatted_args = [
96+
GraphQLArgument(v["name"], k).to_ast()
97+
for k, v in self._formatted_variables.items()
98+
]
7299
return FieldNode(
73100
name=NameNode(value=self._build_field_name()),
74-
arguments=[arg.to_ast(idx) for arg in self._arguments],
101+
arguments=formatted_args,
75102
selection_set=(
76-
SelectionSetNode(selections=self._build_selections(idx))
103+
SelectionSetNode(selections=self._build_selections(idx, used_names))
77104
if self._subfields or self._inline_fragments
78105
else None
79106
),
80107
)
108+
109+
def get_formatted_variables(self) -> Dict[str, Dict[str, Any]]:
110+
formatted_variables = self._formatted_variables
111+
for subfield in self._subfields:
112+
subfield.get_formatted_variables()
113+
self._formatted_variables.update(subfield._formatted_variables)
114+
for subfields in self._inline_fragments.values():
115+
for subfield in subfields:
116+
subfield.get_formatted_variables()
117+
self._formatted_variables.update(subfield._formatted_variables)
118+
return formatted_variables

tests/main/clients/custom_query_builder/expected_client/client.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,14 @@ def _combine_variables(
4545
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
4646
variables_types_combined = {}
4747
processed_variables_combined = {}
48-
for idx, field in enumerate(fields):
49-
variables_types_combined.update(field.get_variables_types(idx))
50-
processed_variables_combined.update(field.get_processed_variables(idx))
48+
for field in fields:
49+
formatted_variables = field.get_formatted_variables()
50+
variables_types_combined.update(
51+
{k: v["type"] for k, v in formatted_variables.items()}
52+
)
53+
processed_variables_combined.update(
54+
{k: v["value"] for k, v in formatted_variables.items()}
55+
)
5156
return (variables_types_combined, processed_variables_combined)
5257

5358
def _build_variable_definitions(

0 commit comments

Comments
 (0)