11import ast
22
3- from graphql import OperationDefinitionNode , parse
3+ from graphql import GraphQLSchema , OperationDefinitionNode , build_schema , parse
44
55from ariadne_codegen .generators .arguments import ArgumentsGenerator
6- from ariadne_codegen .generators .constants import OPTIONAL
6+ from ariadne_codegen .generators .constants import ANY , OPTIONAL
77
88from ..utils import compare_ast
99
@@ -16,7 +16,19 @@ def _get_variable_definitions_from_query_str(query: str):
1616
1717
1818def test_generate_returns_arguments_with_correct_non_optional_names_and_annotations ():
19- generator = ArgumentsGenerator ()
19+ schema_str = """
20+ schema { query: Query }
21+ type Query { _skip: ID! }
22+
23+ input CustomInputType {
24+ fieldA: Int!
25+ fieldB: Float!
26+ fieldC: String!
27+ fieldD: Boolean!
28+ }
29+ """
30+ schema = build_schema (schema_str )
31+ generator = ArgumentsGenerator (schema = schema )
2032 query = (
2133 "query q($id: ID!, $name: String!, $amount: Int!, $val: Float!, "
2234 "$flag: Boolean!, $custom_input: CustomInputType!) {r}"
@@ -43,7 +55,12 @@ def test_generate_returns_arguments_with_correct_non_optional_names_and_annotati
4355
4456
4557def test_generate_returns_arguments_with_correct_optional_annotation ():
46- generator = ArgumentsGenerator ()
58+ schema_str = """
59+ schema { query: Query }
60+ type Query { _skip: ID! }
61+ """
62+ schema = build_schema (schema_str )
63+ generator = ArgumentsGenerator (schema = schema )
4764 query = "query q($id: ID) {r}"
4865 variable_definitions = _get_variable_definitions_from_query_str (query )
4966
@@ -62,7 +79,7 @@ def test_generate_returns_arguments_with_correct_optional_annotation():
6279
6380
6481def test_generate_returns_arguments_with_only_self_argument_without_annotation ():
65- generator = ArgumentsGenerator ()
82+ generator = ArgumentsGenerator (schema = GraphQLSchema () )
6683 query = "query q {r}"
6784 variable_definitions = _get_variable_definitions_from_query_str (query )
6885
@@ -76,18 +93,27 @@ def test_generate_returns_arguments_with_only_self_argument_without_annotation()
7693
7794
7895def test_generate_saves_used_non_scalar_types ():
79- generator = ArgumentsGenerator ()
96+ schema_str = """
97+ schema { query: Query }
98+ type Query { _skip: String! }
99+
100+ input Type1 { fieldA: Int! }
101+ input Type2 { fieldB: Int! }
102+ """
103+ schema = build_schema (schema_str )
104+ generator = ArgumentsGenerator (schema = schema )
80105 query = "query q($a1: String!, $a2: String, $a3: Type1!, $a4: Type2) {r}"
81106 variable_definitions = _get_variable_definitions_from_query_str (query )
82107
83108 generator .generate (variable_definitions )
84109
85- assert len (generator .used_types ) == 2
86- assert generator .used_types == ["Type1" , "Type2" ]
110+ used_inputs = generator .get_used_inputs ()
111+ assert len (used_inputs ) == 2
112+ assert used_inputs == ["Type1" , "Type2" ]
87113
88114
89115def test_generate_returns_arguments_and_dictionary_with_snake_case_names ():
90- generator = ArgumentsGenerator (convert_to_snake_case = True )
116+ generator = ArgumentsGenerator (schema = GraphQLSchema (), convert_to_snake_case = True )
91117 query = "query q($camelCase: String!, $snake_case: String!) {r}"
92118 variable_definitions = _get_variable_definitions_from_query_str (query )
93119
@@ -111,3 +137,34 @@ def test_generate_returns_arguments_and_dictionary_with_snake_case_names():
111137
112138 assert compare_ast (arguments , expected_arguments )
113139 assert compare_ast (arguments_dict , expected_arguments_dict )
140+
141+
142+ def test_generate_returns_arguments_with_used_custom_scalar ():
143+ schema_str = """
144+ schema { query: Query }
145+ type Query { _skip: String! }
146+ scalar CustomScalar
147+ """
148+ generator = ArgumentsGenerator (schema = build_schema (schema_str ))
149+ query_str = "query q($arg: CustomScalar!) {r}"
150+
151+ expected_arguments = ast .arguments (
152+ posonlyargs = [],
153+ args = [
154+ ast .arg (arg = "self" ),
155+ ast .arg (arg = "arg" , annotation = ast .Name (id = ANY )),
156+ ],
157+ kwonlyargs = [],
158+ kw_defaults = [],
159+ defaults = [],
160+ )
161+ expected_arguments_dict = ast .Dict (
162+ keys = [ast .Constant (value = "arg" )], values = [ast .Name (id = "arg" )]
163+ )
164+
165+ arguments , arguments_dict = generator .generate (
166+ _get_variable_definitions_from_query_str (query_str )
167+ )
168+
169+ assert compare_ast (arguments , expected_arguments )
170+ assert compare_ast (arguments_dict , expected_arguments_dict )
0 commit comments