Skip to content

Commit 5f35bde

Browse files
author
normanre
committed
Fixed list arg handling #321 / #347
1 parent 3d5605f commit 5f35bde

File tree

3 files changed

+32
-9
lines changed

3 files changed

+32
-9
lines changed

ariadne_codegen/client_generators/custom_arguments.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import enum
23
from typing import Any, Dict, List, Optional, Tuple, Union, cast
34

45
from graphql import (
@@ -9,6 +10,7 @@
910
GraphQLObjectType,
1011
GraphQLScalarType,
1112
GraphQLUnionType,
13+
GraphQLList,
1214
)
1315

1416
from ..codegen import (
@@ -26,6 +28,7 @@
2628
generate_name,
2729
generate_subscript,
2830
generate_tuple,
31+
generate_list_annotation,
2932
)
3033
from ..exceptions import ParsingError
3134
from ..plugins.manager import PluginManager
@@ -77,12 +80,13 @@ def generate_arguments(
7780

7881
for arg_name, arg_value in operation_args.items():
7982
final_type = get_final_type(arg_value)
83+
is_list = isinstance(arg_value.type, GraphQLList)
8084
is_required = isinstance(arg_value.type, GraphQLNonNull)
8185
name = process_name(
8286
arg_name, convert_to_snake_case=self.convert_to_snake_case
8387
)
8488
annotation, used_custom_scalar = self._parse_graphql_type_name(
85-
final_type, not is_required
89+
final_type, not is_required, is_list
8690
)
8791

8892
self._accumulate_method_arguments(
@@ -93,8 +97,7 @@ def generate_arguments(
9397
return_arguments_values,
9498
arg_name,
9599
name,
96-
final_type,
97-
is_required,
100+
arg_value.type,
98101
used_custom_scalar,
99102
)
100103

@@ -125,12 +128,12 @@ def _accumulate_return_arguments(
125128
return_arguments_values: List[ast.expr],
126129
arg_name: str,
127130
name: str,
128-
final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType],
129-
is_required: bool,
131+
complete_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType, GraphQLNonNull, GraphQLList],
130132
used_custom_scalar: Optional[str],
131133
) -> None:
132134
"""Accumulates return arguments."""
133-
constant_value = f"{final_type.name}!" if is_required else final_type.name
135+
constant_value = self._generate_complete_type_name(complete_type)
136+
134137
return_arg_dict_value = self._generate_return_arg_value(
135138
name, used_custom_scalar
136139
)
@@ -143,6 +146,22 @@ def _accumulate_return_arguments(
143146
)
144147
)
145148

149+
def _generate_complete_type_name(
150+
self,
151+
complete_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType, GraphQLNonNull, GraphQLList]
152+
) -> str:
153+
if isinstance(complete_type, GraphQLNonNull):
154+
if hasattr(complete_type, "of_type"):
155+
return f"{self._generate_complete_type_name(complete_type.of_type)}!"
156+
else:
157+
return f"{self._generate_complete_type_name(complete_type.type)}"
158+
if isinstance(complete_type, GraphQLList):
159+
if hasattr(complete_type, "of_type"):
160+
return f"[{self._generate_complete_type_name(complete_type.of_type)}]"
161+
else:
162+
return f"[{self._generate_complete_type_name(complete_type.type)}]"
163+
return complete_type.name
164+
146165
def _generate_return_arg_value(
147166
self, name: str, used_custom_scalar: Optional[str]
148167
) -> Union[ast.Call, ast.Name]:
@@ -178,6 +197,7 @@ def _parse_graphql_type_name(
178197
self,
179198
type_: Union[GraphQLScalarType, GraphQLInputObjectType, GraphQLEnumType],
180199
nullable: bool = True,
200+
is_list: bool = False,
181201
) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]:
182202
"""Parses the GraphQL type name and determines if it is a custom scalar."""
183203
name = type_.name
@@ -205,7 +225,8 @@ def _parse_graphql_type_name(
205225
self._used_custom_scalars.append(used_custom_scalar)
206226
else:
207227
raise ParsingError(f"Incorrect argument type {name}")
208-
return generate_annotation_name(name, nullable), used_custom_scalar
228+
return generate_annotation_name(name, nullable) if not is_list else generate_list_annotation(
229+
generate_annotation_name(name, nullable=False), nullable), used_custom_scalar
209230

210231
def add_custom_scalar_imports(self) -> None:
211232
"""Adds imports for custom scalars used in the schema."""

ariadne_codegen/client_generators/custom_fields.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
OPTIONAL,
4343
TYPING_MODULE,
4444
UNION,
45+
LIST,
4546
)
4647
from .custom_generator_utils import TypeCollector, get_final_type
4748
from .scalars import ScalarData
@@ -70,7 +71,7 @@ def __init__(
7071
]
7172
self._add_import(
7273
generate_import_from(
73-
[OPTIONAL, UNION, ANY, DICT],
74+
[OPTIONAL, UNION, ANY, DICT, LIST],
7475
TYPING_MODULE,
7576
)
7677
)

ariadne_codegen/client_generators/custom_operation.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
GRAPHQL_UNION_SUFFIX,
3636
OPTIONAL,
3737
TYPING_MODULE,
38+
LIST,
3839
)
3940
from .custom_generator_utils import get_final_type
4041
from .scalars import ScalarData
@@ -67,7 +68,7 @@ def __init__(
6768

6869
self._imports: List[ast.ImportFrom] = []
6970
self._type_imports: List[ast.ImportFrom] = []
70-
self._add_import(generate_import_from([OPTIONAL, ANY, DICT], TYPING_MODULE))
71+
self._add_import(generate_import_from([OPTIONAL, ANY, DICT, LIST], TYPING_MODULE))
7172
self.argument_generator = ArgumentGenerator(
7273
self.custom_scalars,
7374
self.convert_to_snake_case,

0 commit comments

Comments
 (0)