Skip to content

Commit c53c734

Browse files
authored
Merge pull request #75 from mirumee/fix_custom_scalar_as_argument_type
Fix custom scalar as argument type
2 parents 647c072 + 459a3a9 commit c53c734

File tree

4 files changed

+129
-56
lines changed

4 files changed

+129
-56
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# CHANGELOG
22

3+
## UNRELEASED
4+
5+
- Fixed incorrectly raised exception when using custom scalar as query argument type.
6+
7+
38
## 0.2.0 (2023-02-02)
49

510
- Added `remote_schema_url` and `remote_schema_headers` settings to support reading remote schemas.

ariadne_codegen/generators/arguments.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from typing import List, Tuple, Union
33

44
from graphql import (
5+
GraphQLEnumType,
6+
GraphQLInputObjectType,
7+
GraphQLScalarType,
8+
GraphQLSchema,
59
ListTypeNode,
610
NamedTypeNode,
711
NonNullTypeNode,
@@ -19,14 +23,46 @@
1923
generate_list_annotation,
2024
generate_name,
2125
)
22-
from .constants import SIMPLE_TYPE_MAP
26+
from .constants import ANY, SIMPLE_TYPE_MAP
2327
from .utils import str_to_snake_case
2428

2529

2630
class ArgumentsGenerator:
27-
def __init__(self, convert_to_snake_case: bool = True) -> None:
31+
def __init__(
32+
self, schema: GraphQLSchema, convert_to_snake_case: bool = True
33+
) -> None:
34+
self.schema = schema
2835
self.convert_to_snake_case = convert_to_snake_case
2936
self.used_types: List[str] = []
37+
self._used_enums: List[str] = []
38+
self._used_inputs: List[str] = []
39+
40+
def generate(
41+
self, variable_definitions: Tuple[VariableDefinitionNode, ...]
42+
) -> Tuple[ast.arguments, ast.Dict]:
43+
"""Generate arguments from given variable definitions."""
44+
arguments = generate_arguments([generate_arg("self")])
45+
dict_ = generate_dict()
46+
for variable_definition in variable_definitions:
47+
org_name = variable_definition.variable.name.value
48+
name = self._process_name(org_name)
49+
annotation = self._parse_type_node(variable_definition.type)
50+
51+
arguments.args.append(generate_arg(name, annotation))
52+
dict_.keys.append(generate_constant(org_name))
53+
dict_.values.append(generate_name(name))
54+
return arguments, dict_
55+
56+
def get_used_enums(self) -> List[str]:
57+
return self._used_enums
58+
59+
def get_used_inputs(self) -> List[str]:
60+
return self._used_inputs
61+
62+
def _process_name(self, name: str) -> str:
63+
if self.convert_to_snake_case:
64+
return str_to_snake_case(name)
65+
return name
3066

3167
def _parse_type_node(
3268
self,
@@ -50,31 +86,17 @@ def _parse_named_type_node(
5086
self, node: NamedTypeNode, nullable: bool = True
5187
) -> Union[ast.Name, ast.Subscript]:
5288
name = node.name.value
89+
type_ = self.schema.type_map.get(name)
90+
if not type_:
91+
raise ParsingError(f"Argument type {name} not found in schema.")
5392

54-
if name in SIMPLE_TYPE_MAP:
55-
name = SIMPLE_TYPE_MAP[name]
93+
if isinstance(type_, GraphQLInputObjectType):
94+
self._used_inputs.append(name)
95+
elif isinstance(type_, GraphQLEnumType):
96+
self._used_enums.append(name)
97+
elif isinstance(type_, GraphQLScalarType):
98+
name = SIMPLE_TYPE_MAP.get(name, ANY)
5699
else:
57-
self.used_types.append(name)
100+
raise ParsingError(f"Incorrect argument type {name}")
58101

59102
return generate_annotation_name(name, nullable)
60-
61-
def _process_name(self, name: str) -> str:
62-
if self.convert_to_snake_case:
63-
return str_to_snake_case(name)
64-
return name
65-
66-
def generate(
67-
self, variable_definitions: Tuple[VariableDefinitionNode, ...]
68-
) -> Tuple[ast.arguments, ast.Dict]:
69-
"""Generate arguments from given variable definitions."""
70-
arguments = generate_arguments([generate_arg("self")])
71-
dict_ = generate_dict()
72-
for variable_definition in variable_definitions:
73-
org_name = variable_definition.variable.name.value
74-
name = self._process_name(org_name)
75-
annotation = self._parse_type_node(variable_definition.type)
76-
77-
arguments.args.append(generate_arg(name, annotation))
78-
dict_.keys.append(generate_constant(org_name))
79-
dict_.values.append(generate_name(name))
80-
return arguments, dict_

ariadne_codegen/generators/package.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@
33
from pathlib import Path
44
from typing import Dict, List, Optional
55

6-
from graphql import (
7-
FragmentDefinitionNode,
8-
GraphQLEnumType,
9-
GraphQLInputObjectType,
10-
GraphQLSchema,
11-
OperationDefinitionNode,
12-
)
6+
from graphql import FragmentDefinitionNode, GraphQLSchema, OperationDefinitionNode
137

148
from ..exceptions import ParsingError
159
from .arguments import ArgumentsGenerator
@@ -98,7 +92,9 @@ def __init__(
9892
self.arguments_generator = (
9993
arguments_generator
10094
if arguments_generator
101-
else ArgumentsGenerator(convert_to_snake_case=self.convert_to_snake_case)
95+
else ArgumentsGenerator(
96+
schema=self.schema, convert_to_snake_case=self.convert_to_snake_case
97+
)
10298
)
10399
self.input_types_generator = (
104100
input_types_generator
@@ -208,23 +204,16 @@ def _validate_unique_file_names(self):
208204
def _generate_client(self):
209205
client_file_path = self.package_path / f"{self.client_file_name}.py"
210206

211-
input_types = []
212-
enums = []
213-
for type_ in self.arguments_generator.used_types:
214-
if isinstance(self.schema.type_map[type_], GraphQLInputObjectType):
215-
input_types.append(type_)
216-
elif isinstance(self.schema.type_map[type_], GraphQLEnumType):
217-
enums.append(type_)
218-
else:
219-
raise ParsingError(f"Argument type {type_} not found in schema.")
220-
221207
self.client_generator.add_import(
222-
names=input_types, from_=self.input_types_module_name, level=1
208+
names=self.arguments_generator.get_used_inputs(),
209+
from_=self.input_types_module_name,
210+
level=1,
223211
)
224212
self.client_generator.add_import(
225-
names=enums, from_=self.enums_module_name, level=1
213+
names=self.arguments_generator.get_used_enums(),
214+
from_=self.enums_module_name,
215+
level=1,
226216
)
227-
228217
self.client_generator.add_import(
229218
names=[self.base_client_name],
230219
from_=self.base_client_file_path.stem,

tests/generators/test_arguments_generator.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import ast
22

3-
from graphql import OperationDefinitionNode, parse
3+
from graphql import GraphQLSchema, OperationDefinitionNode, build_schema, parse
44

55
from ariadne_codegen.generators.arguments import ArgumentsGenerator
6-
from ariadne_codegen.generators.constants import OPTIONAL
6+
from ariadne_codegen.generators.constants import ANY, OPTIONAL
77

88
from ..utils import compare_ast
99

@@ -16,7 +16,19 @@ def _get_variable_definitions_from_query_str(query: str):
1616

1717

1818
def test_generate_returns_arguments_with_correct_non_optional_names_and_annotations():
19-
generator = ArgumentsGenerator()
19+
schema_str = """
20+
schema { query: Query }
21+
type Query { _skip: ID! }
22+
23+
input CustomInputType {
24+
fieldA: Int!
25+
fieldB: Float!
26+
fieldC: String!
27+
fieldD: Boolean!
28+
}
29+
"""
30+
schema = build_schema(schema_str)
31+
generator = ArgumentsGenerator(schema=schema)
2032
query = (
2133
"query q($id: ID!, $name: String!, $amount: Int!, $val: Float!, "
2234
"$flag: Boolean!, $custom_input: CustomInputType!) {r}"
@@ -43,7 +55,12 @@ def test_generate_returns_arguments_with_correct_non_optional_names_and_annotati
4355

4456

4557
def test_generate_returns_arguments_with_correct_optional_annotation():
46-
generator = ArgumentsGenerator()
58+
schema_str = """
59+
schema { query: Query }
60+
type Query { _skip: ID! }
61+
"""
62+
schema = build_schema(schema_str)
63+
generator = ArgumentsGenerator(schema=schema)
4764
query = "query q($id: ID) {r}"
4865
variable_definitions = _get_variable_definitions_from_query_str(query)
4966

@@ -62,7 +79,7 @@ def test_generate_returns_arguments_with_correct_optional_annotation():
6279

6380

6481
def test_generate_returns_arguments_with_only_self_argument_without_annotation():
65-
generator = ArgumentsGenerator()
82+
generator = ArgumentsGenerator(schema=GraphQLSchema())
6683
query = "query q {r}"
6784
variable_definitions = _get_variable_definitions_from_query_str(query)
6885

@@ -76,18 +93,27 @@ def test_generate_returns_arguments_with_only_self_argument_without_annotation()
7693

7794

7895
def test_generate_saves_used_non_scalar_types():
79-
generator = ArgumentsGenerator()
96+
schema_str = """
97+
schema { query: Query }
98+
type Query { _skip: String! }
99+
100+
input Type1 { fieldA: Int! }
101+
input Type2 { fieldB: Int! }
102+
"""
103+
schema = build_schema(schema_str)
104+
generator = ArgumentsGenerator(schema=schema)
80105
query = "query q($a1: String!, $a2: String, $a3: Type1!, $a4: Type2) {r}"
81106
variable_definitions = _get_variable_definitions_from_query_str(query)
82107

83108
generator.generate(variable_definitions)
84109

85-
assert len(generator.used_types) == 2
86-
assert generator.used_types == ["Type1", "Type2"]
110+
used_inputs = generator.get_used_inputs()
111+
assert len(used_inputs) == 2
112+
assert used_inputs == ["Type1", "Type2"]
87113

88114

89115
def test_generate_returns_arguments_and_dictionary_with_snake_case_names():
90-
generator = ArgumentsGenerator(convert_to_snake_case=True)
116+
generator = ArgumentsGenerator(schema=GraphQLSchema(), convert_to_snake_case=True)
91117
query = "query q($camelCase: String!, $snake_case: String!) {r}"
92118
variable_definitions = _get_variable_definitions_from_query_str(query)
93119

@@ -111,3 +137,34 @@ def test_generate_returns_arguments_and_dictionary_with_snake_case_names():
111137

112138
assert compare_ast(arguments, expected_arguments)
113139
assert compare_ast(arguments_dict, expected_arguments_dict)
140+
141+
142+
def test_generate_returns_arguments_with_used_custom_scalar():
143+
schema_str = """
144+
schema { query: Query }
145+
type Query { _skip: String! }
146+
scalar CustomScalar
147+
"""
148+
generator = ArgumentsGenerator(schema=build_schema(schema_str))
149+
query_str = "query q($arg: CustomScalar!) {r}"
150+
151+
expected_arguments = ast.arguments(
152+
posonlyargs=[],
153+
args=[
154+
ast.arg(arg="self"),
155+
ast.arg(arg="arg", annotation=ast.Name(id=ANY)),
156+
],
157+
kwonlyargs=[],
158+
kw_defaults=[],
159+
defaults=[],
160+
)
161+
expected_arguments_dict = ast.Dict(
162+
keys=[ast.Constant(value="arg")], values=[ast.Name(id="arg")]
163+
)
164+
165+
arguments, arguments_dict = generator.generate(
166+
_get_variable_definitions_from_query_str(query_str)
167+
)
168+
169+
assert compare_ast(arguments, expected_arguments)
170+
assert compare_ast(arguments_dict, expected_arguments_dict)

0 commit comments

Comments
 (0)