diff --git a/ariadne_codegen/client_generators/custom_arguments.py b/ariadne_codegen/client_generators/custom_arguments.py index 0841a024..27a61495 100644 --- a/ariadne_codegen/client_generators/custom_arguments.py +++ b/ariadne_codegen/client_generators/custom_arguments.py @@ -5,6 +5,7 @@ GraphQLEnumType, GraphQLInputObjectType, GraphQLInterfaceType, + GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, @@ -93,8 +94,7 @@ def generate_arguments( return_arguments_values, arg_name, name, - final_type, - is_required, + arg_value.type, used_custom_scalar, ) @@ -125,12 +125,11 @@ def _accumulate_return_arguments( return_arguments_values: List[ast.expr], arg_name: str, name: str, - final_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType], - is_required: bool, + complete_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType], used_custom_scalar: Optional[str], ) -> None: """Accumulates return arguments.""" - constant_value = f"{final_type.name}!" if is_required else final_type.name + constant_value = self._generate_complete_type_name(complete_type) return_arg_dict_value = self._generate_return_arg_value( name, used_custom_scalar ) @@ -143,6 +142,22 @@ def _accumulate_return_arguments( ) ) + def _generate_complete_type_name( + self, + complete_type: Union[GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType, GraphQLNonNull, GraphQLList] + ) -> str: + if isinstance(complete_type, GraphQLNonNull): + if hasattr(complete_type, "of_type"): + return f"{self._generate_complete_type_name(complete_type.of_type)}!" + else: + return f"{self._generate_complete_type_name(complete_type.type)}" + if isinstance(complete_type, GraphQLList): + if hasattr(complete_type, "of_type"): + return f"[{self._generate_complete_type_name(complete_type.of_type)}]" + else: + return f"[{self._generate_complete_type_name(complete_type.type)}]" + return complete_type.name + def _generate_return_arg_value( self, name: str, used_custom_scalar: Optional[str] ) -> Union[ast.Call, ast.Name]: