Skip to content

Commit d7c764c

Browse files
authored
Merge pull request #113 from strue36/escape-keywords
Escape keywords.
2 parents 5052660 + 0cd5666 commit d7c764c

File tree

11 files changed

+159
-27
lines changed

11 files changed

+159
-27
lines changed

Diff for: CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## UNRELEASED
44

55
- Changed logic how custom scalar imports are generated. Deprecated `import_` key.
6+
- Added escaping of GraphQL names which are Python keywords by appending `_` to them.
67

78

89
## 0.5.0 (2023-04-05)

Diff for: ariadne_codegen/client_generators/arguments.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from ..exceptions import ParsingError
2727
from ..plugins.manager import PluginManager
28-
from ..utils import str_to_snake_case
28+
from ..utils import process_name
2929
from .constants import ANY, OPTIONAL, SIMPLE_TYPE_MAP
3030
from .scalars import ScalarData
3131

@@ -57,7 +57,7 @@ def generate(
5757
dict_ = generate_dict()
5858
for variable_definition in variable_definitions:
5959
org_name = variable_definition.variable.name.value
60-
name = self._process_name(org_name)
60+
name = process_name(org_name, self.convert_to_snake_case)
6161
annotation, used_custom_scalar = self._parse_type_node(
6262
variable_definition.type
6363
)
@@ -94,11 +94,6 @@ def get_used_inputs(self) -> List[str]:
9494
def get_used_custom_scalars(self) -> List[str]:
9595
return self._used_custom_scalars
9696

97-
def _process_name(self, name: str) -> str:
98-
if self.convert_to_snake_case:
99-
return str_to_snake_case(name)
100-
return name
101-
10297
def _parse_type_node(
10398
self,
10499
node: Union[NamedTypeNode, ListTypeNode, NonNullTypeNode, TypeNode],

Diff for: ariadne_codegen/client_generators/input_types.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
generate_module,
2121
)
2222
from ..plugins.manager import PluginManager
23-
from ..utils import str_to_snake_case
23+
from ..utils import process_name
2424
from .constants import (
2525
ANY,
2626
BASE_MODEL_CLASS_NAME,
@@ -109,7 +109,7 @@ def _parse_input_definition(
109109
)
110110

111111
for lineno, (org_name, field) in enumerate(definition.fields.items(), start=1):
112-
name = self._process_field_name(org_name)
112+
name = process_name(org_name, self.convert_to_snake_case)
113113
annotation, field_type = parse_input_field_type(
114114
field.type, custom_scalars=self.custom_scalars
115115
)
@@ -140,11 +140,6 @@ def _parse_input_definition(
140140

141141
return class_def
142142

143-
def _process_field_name(self, name: str) -> str:
144-
if self.convert_to_snake_case:
145-
return str_to_snake_case(name)
146-
return name
147-
148143
def _process_field_value(
149144
self,
150145
field_implementation: ast.AnnAssign,

Diff for: ariadne_codegen/client_generators/package.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..codegen import generate_import_from
99
from ..exceptions import ParsingError
1010
from ..plugins.manager import PluginManager
11-
from ..utils import ast_to_str, str_to_pascal_case, str_to_snake_case
11+
from ..utils import ast_to_str, process_name, str_to_pascal_case
1212
from .arguments import ArgumentsGenerator
1313
from .client import ClientGenerator
1414
from .constants import (
@@ -171,7 +171,7 @@ def add_operation(self, definition: OperationDefinitionNode):
171171
raise ParsingError("Query without name.")
172172

173173
return_type_name = str_to_pascal_case(name.value)
174-
method_name = str_to_snake_case(name.value)
174+
method_name = process_name(name.value, convert_to_snake_case=True)
175175
module_name = method_name
176176
file_name = f"{module_name}.py"
177177

Diff for: ariadne_codegen/client_generators/result_types.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from ..exceptions import NotSupported, ParsingError
3535
from ..plugins.manager import PluginManager
36-
from ..utils import str_to_pascal_case, str_to_snake_case
36+
from ..utils import process_name, str_to_pascal_case
3737
from .constants import (
3838
ANY,
3939
BASE_MODEL_CLASS_NAME,
@@ -270,11 +270,9 @@ def _get_field_name(self, field: FieldNode) -> str:
270270
return field.name.value
271271

272272
def _process_field_name(self, name: str) -> str:
273-
if self.convert_to_snake_case:
274-
if name == TYPENAME_FIELD_NAME:
275-
return "__typename__"
276-
return str_to_snake_case(name)
277-
return name
273+
if self.convert_to_snake_case and name == TYPENAME_FIELD_NAME:
274+
return "__typename__"
275+
return process_name(name, self.convert_to_snake_case)
278276

279277
def _get_field_from_schema(self, type_name: str, field_name: str) -> GraphQLField:
280278
try:

Diff for: ariadne_codegen/utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ast
22
import re
3+
from keyword import iskeyword
34
from textwrap import indent
45

56
import isort
@@ -70,3 +71,14 @@ def format_multiline_strings(source: str) -> str:
7071
formatted = convert_to_multiline_string(orginal_str, variable_indent_size)
7172
formatted_source = formatted_source.replace(orginal_str, formatted)
7273
return formatted_source
74+
75+
76+
def process_name(name: str, convert_to_snake_case: bool) -> str:
77+
"""Processes the GraphQL name to remove keywords
78+
and optionally convert to snake_case."""
79+
processed_name = name
80+
if convert_to_snake_case:
81+
processed_name = str_to_snake_case(processed_name)
82+
if iskeyword(processed_name):
83+
processed_name += "_"
84+
return processed_name

Diff for: tests/client_generators/input_types_generator/test_names.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
FIELD_CLASS,
99
)
1010
from ariadne_codegen.client_generators.input_types import InputTypesGenerator
11+
from ariadne_codegen.utils import ast_to_str
1112

12-
from ...utils import compare_ast, get_class_def
13+
from ...utils import compare_ast, get_assignment_target_names, get_class_def
1314

1415

1516
@pytest.mark.parametrize(
@@ -137,3 +138,26 @@ def test_generate_returns_module_with_fields_names_converted_to_snake_case(
137138

138139
class_def = get_class_def(module)
139140
assert compare_ast(class_def, expected_class_def)
141+
142+
143+
def test_generate_returns_module_with_valid_field_names():
144+
schema = """
145+
input KeywordInput {
146+
in: String!
147+
from: String!
148+
and: String!
149+
}
150+
"""
151+
152+
generator = InputTypesGenerator(
153+
schema=build_ast_schema(parse(schema)), enums_module="enums"
154+
)
155+
156+
module = generator.generate()
157+
158+
parsed = ast.parse(
159+
ast_to_str(module)
160+
) # Round trip because invalid identifiers get picked up in parse
161+
class_def = get_class_def(parsed)
162+
field_names = get_assignment_target_names(class_def)
163+
assert field_names == {"in_", "from_", "and_"}

Diff for: tests/client_generators/result_types_generator/test_names.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from graphql import OperationDefinitionNode, build_ast_schema, parse
66

77
from ariadne_codegen.client_generators.result_types import ResultTypesGenerator
8+
from ariadne_codegen.utils import ast_to_str
89

9-
from ...utils import compare_ast, get_class_def
10+
from ...utils import compare_ast, get_assignment_target_names, get_class_def
1011
from .schema import SCHEMA_STR
1112

1213

@@ -104,3 +105,30 @@ def test_generate_returns_module_with_handled_graphql_alias(
104105
assert len(class_def.body) == 1
105106
field_implementation = class_def.body[0]
106107
assert compare_ast(field_implementation, expected_field_implementation)
108+
109+
110+
def test_generate_returns_module_with_valid_field_names():
111+
query_str = """
112+
query CustomQuery {
113+
camelCaseQuery {
114+
in: id
115+
}
116+
}
117+
"""
118+
generator = ResultTypesGenerator(
119+
schema=build_ast_schema(parse(SCHEMA_STR)),
120+
operation_definition=cast(
121+
OperationDefinitionNode, parse(query_str).definitions[0]
122+
),
123+
enums_module_name="enums",
124+
convert_to_snake_case=True,
125+
)
126+
127+
module = generator.generate()
128+
129+
parsed = ast.parse(
130+
ast_to_str(module)
131+
) # Round trip because invalid identifiers get picked up in parse
132+
class_def = get_class_def(parsed, name_filter="CustomQueryCamelCaseQuery")
133+
field_names = get_assignment_target_names(class_def)
134+
assert field_names == {"in_"}

Diff for: tests/client_generators/test_arguments_generator.py

+32
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,38 @@ def test_generate_returns_arguments_and_dictionary_with_snake_case_names():
179179
assert compare_ast(arguments_dict, expected_arguments_dict)
180180

181181

182+
def test_generate_returns_arguments_and_dictionary_with_valid_names():
183+
generator = ArgumentsGenerator(schema=GraphQLSchema(), convert_to_snake_case=True)
184+
query = "query q($from: String!, $and: String!, $in: String!) {r}"
185+
variable_definitions = _get_variable_definitions_from_query_str(query)
186+
187+
arguments, arguments_dict = generator.generate(variable_definitions)
188+
189+
expected_arguments = ast.arguments(
190+
posonlyargs=[],
191+
args=[
192+
ast.arg(arg="self"),
193+
ast.arg(arg="from_", annotation=ast.Name(id="str")),
194+
ast.arg(arg="and_", annotation=ast.Name(id="str")),
195+
ast.arg(arg="in_", annotation=ast.Name(id="str")),
196+
],
197+
kwonlyargs=[],
198+
kw_defaults=[],
199+
defaults=[],
200+
)
201+
expected_arguments_dict = ast.Dict(
202+
keys=[
203+
ast.Constant(value="from"),
204+
ast.Constant(value="and"),
205+
ast.Constant(value="in"),
206+
],
207+
values=[ast.Name(id="from_"), ast.Name(id="and_"), ast.Name(id="in_")],
208+
)
209+
210+
assert compare_ast(arguments, expected_arguments)
211+
assert compare_ast(arguments_dict, expected_arguments_dict)
212+
213+
182214
def test_generate_returns_arguments_with_not_mapped_custom_scalar():
183215
schema_str = """
184216
schema { query: Query }

Diff for: tests/client_generators/test_package_generator.py

+35
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
from datetime import datetime
23
from textwrap import dedent, indent
34

@@ -14,6 +15,8 @@
1415
from ariadne_codegen.client_generators.scalars import ScalarData
1516
from ariadne_codegen.exceptions import ParsingError
1617

18+
from ..utils import get_class_def
19+
1720
SCHEMA_STR = """
1821
schema {
1922
query: Query
@@ -308,6 +311,38 @@ async def custom_query(self, id: str, param: Optional[str] = None) -> CustomQuer
308311
assert "from .custom_query import CustomQuery" in client_content
309312

310313

314+
def test_generate_creates_client_with_valid_method_names(tmp_path):
315+
package_name = "test_graphql_client"
316+
generator = PackageGenerator(
317+
package_name,
318+
tmp_path.as_posix(),
319+
build_ast_schema(parse(SCHEMA_STR)),
320+
async_client=False,
321+
)
322+
query_str = """
323+
query From($id: ID!, $param: String) {
324+
query1(id: $id) {
325+
field1
326+
field2 {
327+
fieldb
328+
}
329+
field3
330+
}
331+
}
332+
"""
333+
334+
generator.add_operation(parse(query_str).definitions[0])
335+
generator.generate()
336+
337+
client_file_path = tmp_path / package_name / "client.py"
338+
with client_file_path.open() as client_file:
339+
client_content = client_file.read()
340+
parsed = ast.parse(client_content)
341+
class_def = get_class_def(parsed)
342+
function = [x for x in class_def.body if isinstance(x, ast.FunctionDef)][0]
343+
assert function.name == "from_"
344+
345+
311346
def test_generate_with_conflicting_query_name_raises_parsing_error(tmp_path):
312347
generator = PackageGenerator(
313348
"test_graphql_client",

Diff for: tests/utils.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22
from itertools import zip_longest
33
from textwrap import dedent
4-
from typing import List, Optional, Union, cast
4+
from typing import List, Optional, Set, Union, cast
55

66

77
def compare_ast(
@@ -24,16 +24,28 @@ def compare_ast(
2424
return node1 == node2
2525

2626

27-
def get_class_def(module: ast.Module, class_index=0) -> Optional[ast.ClassDef]:
27+
def get_class_def(
28+
module: ast.Module, class_index=0, name_filter=None
29+
) -> Optional[ast.ClassDef]:
2830
found = 0
2931
for expr in module.body:
30-
if isinstance(expr, ast.ClassDef):
32+
if isinstance(expr, ast.ClassDef) and (
33+
not name_filter or expr.name == name_filter
34+
):
3135
found += 1
3236
if found - 1 == class_index:
3337
return expr
3438
return None
3539

3640

41+
def get_assignment_target_names(class_def: ast.ClassDef) -> Set[str]:
42+
return {
43+
x.target.id
44+
for x in class_def.body
45+
if isinstance(x, ast.AnnAssign) and isinstance(x.target, ast.Name)
46+
}
47+
48+
3749
def filter_ast_objects(module: ast.Module, ast_class) -> List[ast.AST]:
3850
return [expr for expr in module.body if isinstance(expr, ast_class)]
3951

0 commit comments

Comments
 (0)