diff --git a/ariadne_codegen/client_generators/custom_arguments.py b/ariadne_codegen/client_generators/custom_arguments.py index 0841a02..8dcb1de 100644 --- a/ariadne_codegen/client_generators/custom_arguments.py +++ b/ariadne_codegen/client_generators/custom_arguments.py @@ -5,6 +5,7 @@ GraphQLEnumType, GraphQLInputObjectType, GraphQLInterfaceType, + GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, @@ -23,6 +24,7 @@ generate_dict, generate_import_from, generate_keyword, + generate_list_annotation, generate_name, generate_subscript, generate_tuple, @@ -35,6 +37,8 @@ BASE_MODEL_FILE_PATH, DICT, INPUT_SCALARS_MAP, + LIST, + TYPING_MODULE, UPLOAD_CLASS_NAME, ) from .custom_generator_utils import get_final_type @@ -77,12 +81,13 @@ def generate_arguments( for arg_name, arg_value in operation_args.items(): final_type = get_final_type(arg_value) + is_list = isinstance(arg_value.type, GraphQLList) is_required = isinstance(arg_value.type, GraphQLNonNull) name = process_name( arg_name, convert_to_snake_case=self.convert_to_snake_case ) annotation, used_custom_scalar = self._parse_graphql_type_name( - final_type, not is_required + final_type, not is_required, is_list ) self._accumulate_method_arguments( @@ -93,8 +98,7 @@ def generate_arguments( return_arguments_values, arg_name, name, - final_type, - is_required, + arg_value.type, used_custom_scalar, ) @@ -125,12 +129,18 @@ def _accumulate_return_arguments( return_arguments_values: List[ast.expr], arg_name: str, name: str, - final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType], - is_required: bool, + complete_type: Union[ + GraphQLObjectType, + GraphQLInterfaceType, + GraphQLUnionType, + GraphQLNonNull, + GraphQLList, + ], used_custom_scalar: Optional[str], ) -> None: """Accumulates return arguments.""" - constant_value = f"{final_type.name}!" if is_required else final_type.name + constant_value = self._generate_complete_type_name(complete_type) + return_arg_dict_value = self._generate_return_arg_value( name, used_custom_scalar ) @@ -143,6 +153,28 @@ def _accumulate_return_arguments( ) ) + def _generate_complete_type_name( + self, + complete_type: Union[ + GraphQLObjectType, + GraphQLInterfaceType, + GraphQLUnionType, + GraphQLNonNull, + GraphQLList, + ], + ) -> str: + if isinstance(complete_type, GraphQLNonNull): + if hasattr(complete_type, "of_type"): + return f"{self._generate_complete_type_name(complete_type.of_type)}!" + else: + return f"{self._generate_complete_type_name(complete_type.type)}" + if isinstance(complete_type, GraphQLList): + if hasattr(complete_type, "of_type"): + return f"[{self._generate_complete_type_name(complete_type.of_type)}]" + else: + return f"[{self._generate_complete_type_name(complete_type.type)}]" + return complete_type.name + def _generate_return_arg_value( self, name: str, used_custom_scalar: Optional[str] ) -> Union[ast.Call, ast.Name]: @@ -178,6 +210,7 @@ def _parse_graphql_type_name( self, type_: Union[GraphQLScalarType, GraphQLInputObjectType, GraphQLEnumType], nullable: bool = True, + is_list: bool = False, ) -> Tuple[Union[ast.Name, ast.Subscript], Optional[str]]: """Parses the GraphQL type name and determines if it is a custom scalar.""" name = type_.name @@ -205,7 +238,18 @@ def _parse_graphql_type_name( self._used_custom_scalars.append(used_custom_scalar) else: raise ParsingError(f"Incorrect argument type {name}") - return generate_annotation_name(name, nullable), used_custom_scalar + + if is_list: + self._add_import(generate_import_from(names=[LIST], from_=TYPING_MODULE)) + + return ( + generate_annotation_name(name, nullable) + if not is_list + else generate_list_annotation( + generate_annotation_name(name, nullable=False), nullable + ), + used_custom_scalar, + ) def add_custom_scalar_imports(self) -> None: """Adds imports for custom scalars used in the schema."""