Skip to content

Commit b728881

Browse files
authored
Merge pull request #129 from mirumee/process_name_hook
Add process_name hook
2 parents 5adba4f + 10c60f8 commit b728881

20 files changed

+361
-178
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Added escaping of GraphQL names which are Python keywords by appending `_` to them.
77
- Fixed parsing of list variables.
88
- Changed base clients to remove unset arguments and input fields from variables payload.
9+
- Added `process_name` plugin hook.
910

1011

1112
## 0.5.0 (2023-04-05)

PLUGINS.md

+9-1
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,14 @@ def generate_init_code(self, generated_code: str) -> str:
284284

285285
Hook executed on generation of init code. Result is used as content of `__init__.py`.
286286

287+
### process_name
288+
289+
```py
290+
def process_name(self, name: str, node: Optional[Node] = None) -> str:
291+
```
292+
293+
Hook executed on processing of GraphQL field, argument or operation name.
294+
287295

288296
## Example
289297

@@ -298,7 +306,7 @@ from ariadne_codegen.plugins.base import Plugin
298306
class VersionPlugin(Plugin):
299307
def generate_init_module(self, module: ast.Module) -> ast.Module:
300308
version = (
301-
self.config_dict.get("tools", {})
309+
self.config_dict.get("tool", {})
302310
.get("version_plugin", {})
303311
.get("version", "0.1")
304312
)

ariadne_codegen/client_generators/arguments.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ def generate(
5858
dict_ = generate_dict()
5959
for variable_definition in variable_definitions:
6060
org_name = variable_definition.variable.name.value
61-
name = process_name(org_name, self.convert_to_snake_case)
61+
name = process_name(
62+
org_name,
63+
convert_to_snake_case=self.convert_to_snake_case,
64+
plugin_manager=self.plugin_manager,
65+
node=variable_definition,
66+
)
6267
annotation, used_custom_scalar = self._parse_type_node(
6368
variable_definition.type
6469
)

ariadne_codegen/client_generators/input_types.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,12 @@ def _parse_input_definition(
109109
)
110110

111111
for lineno, (org_name, field) in enumerate(definition.fields.items(), start=1):
112-
name = process_name(org_name, self.convert_to_snake_case)
112+
name = process_name(
113+
org_name,
114+
convert_to_snake_case=self.convert_to_snake_case,
115+
plugin_manager=self.plugin_manager,
116+
node=field,
117+
)
113118
annotation, field_type = parse_input_field_type(
114119
field.type, custom_scalars=self.custom_scalars
115120
)

ariadne_codegen/client_generators/package.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,12 @@ def add_operation(self, definition: OperationDefinitionNode):
180180
raise ParsingError("Query without name.")
181181

182182
return_type_name = str_to_pascal_case(name.value)
183-
method_name = process_name(name.value, convert_to_snake_case=True)
183+
method_name = process_name(
184+
name.value,
185+
convert_to_snake_case=True,
186+
plugin_manager=self.plugin_manager,
187+
node=definition,
188+
)
184189
module_name = method_name
185190
file_name = f"{module_name}.py"
186191

ariadne_codegen/client_generators/result_types.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _parse_type_definition(
185185
start=1,
186186
):
187187
field_name = self._get_field_name(field)
188-
name = self._process_field_name(field_name)
188+
name = self._process_field_name(field_name, field=field)
189189
field_definition = self._get_field_from_schema(type_name, field.name.value)
190190
annotation, field_types_names = parse_operation_field(
191191
field=field,
@@ -269,10 +269,15 @@ def _get_field_name(self, field: FieldNode) -> str:
269269
return field.alias.value
270270
return field.name.value
271271

272-
def _process_field_name(self, name: str) -> str:
272+
def _process_field_name(self, name: str, field: FieldNode) -> str:
273273
if self.convert_to_snake_case and name == TYPENAME_FIELD_NAME:
274274
return "__typename__"
275-
return process_name(name, self.convert_to_snake_case)
275+
return process_name(
276+
name,
277+
convert_to_snake_case=self.convert_to_snake_case,
278+
plugin_manager=self.plugin_manager,
279+
node=field,
280+
)
276281

277282
def _get_field_from_schema(self, type_name: str, field_name: str) -> GraphQLField:
278283
try:

ariadne_codegen/plugins/base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import ast
2-
from typing import Dict, Tuple, Union
2+
from typing import Dict, Optional, Tuple, Union
33

44
from graphql import (
55
FieldNode,
66
GraphQLEnumType,
77
GraphQLInputField,
88
GraphQLInputObjectType,
99
GraphQLSchema,
10+
Node,
1011
OperationDefinitionNode,
1112
SelectionSetNode,
1213
VariableDefinitionNode,
@@ -144,3 +145,7 @@ def generate_scalars_code(self, generated_code: str) -> str:
144145

145146
def generate_init_code(self, generated_code: str) -> str:
146147
return generated_code
148+
149+
# pylint: disable=unused-argument
150+
def process_name(self, name: str, node: Optional[Node] = None) -> str:
151+
return name

ariadne_codegen/plugins/manager.py

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
GraphQLInputField,
88
GraphQLInputObjectType,
99
GraphQLSchema,
10+
Node,
1011
OperationDefinitionNode,
1112
SelectionSetNode,
1213
VariableDefinitionNode,
@@ -186,3 +187,6 @@ def generate_scalars_code(self, generated_code: str) -> str:
186187

187188
def generate_init_code(self, generated_code: str) -> str:
188189
return self._apply_plugins_on_object("generate_init_code", generated_code)
190+
191+
def process_name(self, name: str, node: Optional[Node] = None) -> str:
192+
return self._apply_plugins_on_object("process_name", name, node=node)

ariadne_codegen/utils.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
import re
33
from keyword import iskeyword
44
from textwrap import indent
5+
from typing import Optional
56

67
import isort
78
from autoflake import fix_code # type: ignore
89
from black import Mode, format_str
10+
from graphql import Node
11+
12+
from .plugins.manager import PluginManager
913

1014

1115
def ast_to_str(
@@ -73,12 +77,19 @@ def format_multiline_strings(source: str) -> str:
7377
return formatted_source
7478

7579

76-
def process_name(name: str, convert_to_snake_case: bool) -> str:
80+
def process_name(
81+
name: str,
82+
convert_to_snake_case: bool,
83+
plugin_manager: Optional[PluginManager] = None,
84+
node: Optional[Node] = None,
85+
) -> str:
7786
"""Processes the GraphQL name to remove keywords
7887
and optionally convert to snake_case."""
7988
processed_name = name
8089
if convert_to_snake_case:
8190
processed_name = str_to_snake_case(processed_name)
8291
if iskeyword(processed_name):
8392
processed_name += "_"
93+
if plugin_manager:
94+
processed_name = plugin_manager.process_name(processed_name, node=node)
8495
return processed_name

tests/client_generators/input_types_generator/test_plugin_hooks.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from ariadne_codegen.client_generators.input_types import InputTypesGenerator
1010

1111

12-
def test_generator_triggers_generate_input_class_hook_for_every_input_type(mocker):
12+
def test_generator_triggers_generate_input_class_hook_for_every_input_type(
13+
mocked_plugin_manager,
14+
):
1315
schema_str = """
1416
input TestInputA {
1517
fieldA: String!
@@ -19,7 +21,6 @@ def test_generator_triggers_generate_input_class_hook_for_every_input_type(mocke
1921
fieldB: Int!
2022
}
2123
"""
22-
mocked_plugin_manager = mocker.MagicMock()
2324

2425
InputTypesGenerator(
2526
schema=build_ast_schema(parse(schema_str)),
@@ -37,7 +38,9 @@ def test_generator_triggers_generate_input_class_hook_for_every_input_type(mocke
3738
assert call1_input_type.name == "TestInputB"
3839

3940

40-
def test_generator_triggers_generate_input_field_hook_for_every_input_field(mocker):
41+
def test_generator_triggers_generate_input_field_hook_for_every_input_field(
42+
mocked_plugin_manager,
43+
):
4144
schema_str = """
4245
input TestInputAB {
4346
fieldA: String!
@@ -48,7 +51,6 @@ def test_generator_triggers_generate_input_field_hook_for_every_input_field(mock
4851
fieldC: Int!
4952
}
5053
"""
51-
mocked_plugin_manager = mocker.MagicMock()
5254

5355
InputTypesGenerator(
5456
schema=build_ast_schema(parse(schema_str)),
@@ -66,8 +68,7 @@ def test_generator_triggers_generate_input_field_hook_for_every_input_field(mock
6668
assert mock_calls[2].kwargs["field_name"] == "fieldC"
6769

6870

69-
def test_generate_triggers_generate_inputs_module_hook(mocker):
70-
mocked_plugin_manager = mocker.MagicMock()
71+
def test_generate_triggers_generate_inputs_module_hook(mocked_plugin_manager):
7172
generator = InputTypesGenerator(
7273
schema=GraphQLSchema(),
7374
enums_module="enums",
@@ -77,3 +78,27 @@ def test_generate_triggers_generate_inputs_module_hook(mocker):
7778
generator.generate()
7879

7980
assert mocked_plugin_manager.generate_inputs_module.called
81+
82+
83+
def test_generate_triggers_process_name_hook_for_every_field(mocked_plugin_manager):
84+
schema_str = """
85+
input TestInputAB {
86+
fieldA: String!
87+
fieldB: String!
88+
}
89+
90+
input TestInputC {
91+
fieldC: Int!
92+
}
93+
"""
94+
95+
InputTypesGenerator(
96+
schema=build_ast_schema(parse(schema_str)),
97+
enums_module="enums",
98+
convert_to_snake_case=False,
99+
plugin_manager=mocked_plugin_manager,
100+
)
101+
102+
assert mocked_plugin_manager.process_name.call_count == 3
103+
mock_calls = mocked_plugin_manager.process_name.mock_calls
104+
assert {call.args[0] for call in mock_calls} == {"fieldA", "fieldB", "fieldC"}

tests/client_generators/result_types_generator/test_plugin_hooks.py

+41-8
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from .schema import SCHEMA_STR
88

99

10-
def test_generate_triggers_generate_result_types_module_hook(mocker):
10+
def test_generate_triggers_generate_result_types_module_hook(mocked_plugin_manager):
1111
query_str = "query CustomQuery { camelCaseQuery { id } }"
12-
mocked_plugin_manager = mocker.MagicMock()
1312
generator = ResultTypesGenerator(
1413
schema=build_ast_schema(parse(SCHEMA_STR)),
1514
operation_definition=cast(
@@ -24,9 +23,10 @@ def test_generate_triggers_generate_result_types_module_hook(mocker):
2423
assert mocked_plugin_manager.generate_result_types_module.called
2524

2625

27-
def test_get_operation_as_str_triggers_generate_operation_str_hook(mocker):
26+
def test_get_operation_as_str_triggers_generate_operation_str_hook(
27+
mocked_plugin_manager,
28+
):
2829
query_str = "query CustomQuery { camelCaseQuery { id } }"
29-
mocked_plugin_manager = mocker.MagicMock()
3030
generator = ResultTypesGenerator(
3131
schema=build_ast_schema(parse(SCHEMA_STR)),
3232
operation_definition=cast(
@@ -41,9 +41,10 @@ def test_get_operation_as_str_triggers_generate_operation_str_hook(mocker):
4141
assert mocked_plugin_manager.generate_operation_str.called
4242

4343

44-
def test_generator_triggers_generate_result_class_hook_for_every_class(mocker):
44+
def test_generator_triggers_generate_result_class_hook_for_every_class(
45+
mocked_plugin_manager,
46+
):
4547
query_str = "query CustomQuery { camelCaseQuery { id } }"
46-
mocked_plugin_manager = mocker.MagicMock()
4748

4849
ResultTypesGenerator(
4950
schema=build_ast_schema(parse(SCHEMA_STR)),
@@ -60,7 +61,9 @@ def test_generator_triggers_generate_result_class_hook_for_every_class(mocker):
6061
} == {"CustomQuery", "CustomQueryCamelCaseQuery"}
6162

6263

63-
def test_generator_triggers_generate_result_field_hook_for_every_field(mocker):
64+
def test_generator_triggers_generate_result_field_hook_for_every_field(
65+
mocked_plugin_manager,
66+
):
6467
query_str = """
6568
query CustomQuery {
6669
camelCaseQuery {
@@ -71,7 +74,6 @@ def test_generator_triggers_generate_result_field_hook_for_every_field(mocker):
7174
}
7275
}
7376
"""
74-
mocked_plugin_manager = mocker.MagicMock()
7577

7678
ResultTypesGenerator(
7779
schema=build_ast_schema(parse(SCHEMA_STR)),
@@ -87,3 +89,34 @@ def test_generator_triggers_generate_result_field_hook_for_every_field(mocker):
8789
c.kwargs["field"].name.value
8890
for c in mocked_plugin_manager.generate_result_field.mock_calls
8991
} == {"camelCaseQuery", "id", "field1", "fielda"}
92+
93+
94+
def test_generator_triggers_process_name_hook_for_every_field(mocked_plugin_manager):
95+
query_str = """
96+
query CustomQuery {
97+
camelCaseQuery {
98+
id
99+
field1 {
100+
fielda
101+
}
102+
}
103+
}
104+
"""
105+
106+
ResultTypesGenerator(
107+
schema=build_ast_schema(parse(SCHEMA_STR)),
108+
operation_definition=cast(
109+
OperationDefinitionNode, parse(query_str).definitions[0]
110+
),
111+
enums_module_name="enums",
112+
convert_to_snake_case=False,
113+
plugin_manager=mocked_plugin_manager,
114+
)
115+
116+
assert mocked_plugin_manager.process_name.call_count == 4
117+
assert {c.args[0] for c in mocked_plugin_manager.process_name.mock_calls} == {
118+
"camelCaseQuery",
119+
"id",
120+
"field1",
121+
"fielda",
122+
}

tests/client_generators/test_arguments_generator.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,7 @@ def test_generate_returns_arguments_with_custom_scalar_and_used_serialize_method
320320
assert compare_ast(arguments_dict, expected_arguments_dict)
321321

322322

323-
def test_generate_triggers_generate_arguments_hook(mocker):
324-
mocked_plugin_manager = mocker.MagicMock()
323+
def test_generate_triggers_generate_arguments_hook(mocked_plugin_manager):
325324
schema_str = """
326325
schema { query: Query }
327326
type Query { _skip: String! }
@@ -337,8 +336,7 @@ def test_generate_triggers_generate_arguments_hook(mocker):
337336
assert mocked_plugin_manager.generate_arguments.called
338337

339338

340-
def test_generate_triggers_generate_arguments_dict_hook(mocker):
341-
mocked_plugin_manager = mocker.MagicMock()
339+
def test_generate_triggers_generate_arguments_dict_hook(mocked_plugin_manager):
342340
schema_str = """
343341
schema { query: Query }
344342
type Query { _skip: String! }
@@ -352,3 +350,25 @@ def test_generate_triggers_generate_arguments_dict_hook(mocker):
352350
)
353351

354352
assert mocked_plugin_manager.generate_arguments_dict.called
353+
354+
355+
def test_generate_triggers_process_name_hook_for_every_arg(mocked_plugin_manager):
356+
schema_str = """
357+
schema { query: Query }
358+
type Query { _skip: String! }
359+
"""
360+
generator = ArgumentsGenerator(
361+
schema=build_schema(schema_str), plugin_manager=mocked_plugin_manager
362+
)
363+
364+
generator.generate(
365+
_get_variable_definitions_from_query_str(
366+
"query q($arg1: String!, $arg2: String) { _skip }"
367+
)
368+
)
369+
370+
assert mocked_plugin_manager.process_name.call_count == 2
371+
name1 = mocked_plugin_manager.process_name.mock_calls[0].args[0]
372+
name2 = mocked_plugin_manager.process_name.mock_calls[1].args[0]
373+
assert name1 == "arg1"
374+
assert name2 == "arg2"

0 commit comments

Comments
 (0)