Skip to content

Commit b2e9ff6

Browse files
committed
Add tests that check names are as expected
1 parent 5abedd5 commit b2e9ff6

File tree

5 files changed

+136
-5
lines changed

5 files changed

+136
-5
lines changed

Diff for: tests/client_generators/input_types_generator/test_names.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
FIELD_CLASS,
99
)
1010
from ariadne_codegen.client_generators.input_types import InputTypesGenerator
11+
from ariadne_codegen.utils import ast_to_str
1112

12-
from ...utils import compare_ast, get_class_def
13+
from ...utils import compare_ast, get_assignment_target_names, get_class_def
1314

1415

1516
@pytest.mark.parametrize(
@@ -137,3 +138,26 @@ def test_generate_returns_module_with_fields_names_converted_to_snake_case(
137138

138139
class_def = get_class_def(module)
139140
assert compare_ast(class_def, expected_class_def)
141+
142+
143+
def test_generate_returns_module_with_valid_field_names():
144+
schema = """
145+
input KeywordInput {
146+
in: String!
147+
from: String!
148+
and: String!
149+
}
150+
"""
151+
152+
generator = InputTypesGenerator(
153+
schema=build_ast_schema(parse(schema)), enums_module="enums"
154+
)
155+
156+
module = generator.generate()
157+
158+
parsed = ast.parse(
159+
ast_to_str(module)
160+
) # Round trip because invalid identifiers get picked up in parse
161+
class_def = get_class_def(parsed)
162+
field_names = get_assignment_target_names(class_def)
163+
assert field_names == {"in_", "from_", "and_"}

Diff for: tests/client_generators/result_types_generator/test_names.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from graphql import OperationDefinitionNode, build_ast_schema, parse
66

77
from ariadne_codegen.client_generators.result_types import ResultTypesGenerator
8+
from ariadne_codegen.utils import ast_to_str
89

9-
from ...utils import compare_ast, get_class_def
10+
from ...utils import compare_ast, get_assignment_target_names, get_class_def
1011
from .schema import SCHEMA_STR
1112

1213

@@ -104,3 +105,30 @@ def test_generate_returns_module_with_handled_graphql_alias(
104105
assert len(class_def.body) == 1
105106
field_implementation = class_def.body[0]
106107
assert compare_ast(field_implementation, expected_field_implementation)
108+
109+
110+
def test_generate_returns_module_with_valid_field_names():
111+
query_str = """
112+
query CustomQuery {
113+
camelCaseQuery {
114+
in: id
115+
}
116+
}
117+
"""
118+
generator = ResultTypesGenerator(
119+
schema=build_ast_schema(parse(SCHEMA_STR)),
120+
operation_definition=cast(
121+
OperationDefinitionNode, parse(query_str).definitions[0]
122+
),
123+
enums_module_name="enums",
124+
convert_to_snake_case=True,
125+
)
126+
127+
module = generator.generate()
128+
129+
parsed = ast.parse(
130+
ast_to_str(module)
131+
) # Round trip because invalid identifiers get picked up in parse
132+
class_def = get_class_def(parsed, name_filter="CustomQueryCamelCaseQuery")
133+
field_names = get_assignment_target_names(class_def)
134+
assert field_names == {"in_"}

Diff for: tests/client_generators/test_arguments_generator.py

+32
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,38 @@ def test_generate_returns_arguments_and_dictionary_with_snake_case_names():
179179
assert compare_ast(arguments_dict, expected_arguments_dict)
180180

181181

182+
def test_generate_returns_arguments_and_dictionary_with_valid_names():
183+
generator = ArgumentsGenerator(schema=GraphQLSchema(), convert_to_snake_case=True)
184+
query = "query q($from: String!, $and: String!, $in: String!) {r}"
185+
variable_definitions = _get_variable_definitions_from_query_str(query)
186+
187+
arguments, arguments_dict = generator.generate(variable_definitions)
188+
189+
expected_arguments = ast.arguments(
190+
posonlyargs=[],
191+
args=[
192+
ast.arg(arg="self"),
193+
ast.arg(arg="from_", annotation=ast.Name(id="str")),
194+
ast.arg(arg="and_", annotation=ast.Name(id="str")),
195+
ast.arg(arg="in_", annotation=ast.Name(id="str")),
196+
],
197+
kwonlyargs=[],
198+
kw_defaults=[],
199+
defaults=[],
200+
)
201+
expected_arguments_dict = ast.Dict(
202+
keys=[
203+
ast.Constant(value="from"),
204+
ast.Constant(value="and"),
205+
ast.Constant(value="in"),
206+
],
207+
values=[ast.Name(id="from_"), ast.Name(id="and_"), ast.Name(id="in_")],
208+
)
209+
210+
assert compare_ast(arguments, expected_arguments)
211+
assert compare_ast(arguments_dict, expected_arguments_dict)
212+
213+
182214
def test_generate_returns_arguments_with_not_mapped_custom_scalar():
183215
schema_str = """
184216
schema { query: Query }

Diff for: tests/client_generators/test_package_generator.py

+35
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
from datetime import datetime
23
from textwrap import dedent, indent
34

@@ -14,6 +15,8 @@
1415
from ariadne_codegen.client_generators.scalars import ScalarData
1516
from ariadne_codegen.exceptions import ParsingError
1617

18+
from ..utils import get_class_def
19+
1720
SCHEMA_STR = """
1821
schema {
1922
query: Query
@@ -308,6 +311,38 @@ async def custom_query(self, id: str, param: Optional[str] = None) -> CustomQuer
308311
assert "from .custom_query import CustomQuery" in client_content
309312

310313

314+
def test_generate_creates_client_with_valid_method_names(tmp_path):
315+
package_name = "test_graphql_client"
316+
generator = PackageGenerator(
317+
package_name,
318+
tmp_path.as_posix(),
319+
build_ast_schema(parse(SCHEMA_STR)),
320+
async_client=False,
321+
)
322+
query_str = """
323+
query From($id: ID!, $param: String) {
324+
query1(id: $id) {
325+
field1
326+
field2 {
327+
fieldb
328+
}
329+
field3
330+
}
331+
}
332+
"""
333+
334+
generator.add_operation(parse(query_str).definitions[0])
335+
generator.generate()
336+
337+
client_file_path = tmp_path / package_name / "client.py"
338+
with client_file_path.open() as client_file:
339+
client_content = client_file.read()
340+
parsed = ast.parse(client_content)
341+
class_def = get_class_def(parsed)
342+
function = [x for x in class_def.body if isinstance(x, ast.FunctionDef)][0]
343+
assert function.name == "from_"
344+
345+
311346
def test_generate_with_conflicting_query_name_raises_parsing_error(tmp_path):
312347
generator = PackageGenerator(
313348
"test_graphql_client",

Diff for: tests/utils.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22
from itertools import zip_longest
33
from textwrap import dedent
4-
from typing import List, Optional, Union, cast
4+
from typing import List, Optional, Set, Union, cast
55

66

77
def compare_ast(
@@ -24,16 +24,28 @@ def compare_ast(
2424
return node1 == node2
2525

2626

27-
def get_class_def(module: ast.Module, class_index=0) -> Optional[ast.ClassDef]:
27+
def get_class_def(
28+
module: ast.Module, class_index=0, name_filter=None
29+
) -> Optional[ast.ClassDef]:
2830
found = 0
2931
for expr in module.body:
30-
if isinstance(expr, ast.ClassDef):
32+
if isinstance(expr, ast.ClassDef) and (
33+
not name_filter or expr.name == name_filter
34+
):
3135
found += 1
3236
if found - 1 == class_index:
3337
return expr
3438
return None
3539

3640

41+
def get_assignment_target_names(class_def: ast.ClassDef) -> Set[str]:
42+
return {
43+
x.target.id
44+
for x in class_def.body
45+
if isinstance(x, ast.AnnAssign) and isinstance(x.target, ast.Name)
46+
}
47+
48+
3749
def filter_ast_objects(module: ast.Module, ast_class) -> List[ast.AST]:
3850
return [expr for expr in module.body if isinstance(expr, ast_class)]
3951

0 commit comments

Comments
 (0)