Skip to content

Commit b24d10c

Browse files
authored
Merge pull request #196 from mirumee/pydantic_reserved_field_names
Handle field names reserved by pydantic
2 parents 05cc3a5 + 0e025c1 commit b24d10c

File tree

7 files changed

+25
-1
lines changed

7 files changed

+25
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- Changed generated client and models to use pydantic v2.
1212
- Changed custom scalars implementation to utilize pydantic's `BeforeValidator` and `PlainSerializer`. Added `scalars_module_name` option. Replaced `generate_scalars_parse_dict` and `generate_scalars_serialize_dict` with `generate_scalar_annotation` and `generate_scalar_imports` plugin hooks.
1313
- Fixed generating default values of input types from remote schemas.
14+
- Changed generating of input and result field names to add `_` to names reserved by pydantic.
1415

1516

1617
## 0.7.1 (2023-06-06)

ariadne_codegen/client_generators/input_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _parse_input_definition(
127127
plugin_manager=self.plugin_manager,
128128
node=field,
129129
trim_leading_underscore=True,
130+
handle_pydantic_resrved_field_names=True,
130131
)
131132
annotation, field_type = parse_input_field_type(
132133
field.type, custom_scalars=self.custom_scalars

ariadne_codegen/client_generators/result_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def _process_field_name(self, name: str, field: FieldNode) -> str:
379379
plugin_manager=self.plugin_manager,
380380
node=field,
381381
trim_leading_underscore=True,
382+
handle_pydantic_resrved_field_names=True,
382383
)
383384

384385
def _get_field_from_schema(self, type_name: str, field_name: str) -> GraphQLField:

ariadne_codegen/utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88
from autoflake import fix_code # type: ignore
99
from black import Mode, format_str
1010
from graphql import Node
11+
from pydantic import BaseModel
1112

1213
from .plugins.manager import PluginManager
1314

15+
PYDANTIC_RESERVED_FIELD_NAMES = [
16+
name for name in dir(BaseModel) if not name.startswith("_")
17+
]
18+
1419

1520
def ast_to_str(
1621
ast_obj: ast.AST,
@@ -83,6 +88,7 @@ def process_name(
8388
plugin_manager: Optional[PluginManager] = None,
8489
node: Optional[Node] = None,
8590
trim_leading_underscore: bool = False,
91+
handle_pydantic_resrved_field_names: bool = False,
8692
) -> str:
8793
"""Processes the GraphQL name to remove keywords
8894
and optionally convert to snake_case."""
@@ -91,6 +97,11 @@ def process_name(
9197
processed_name = str_to_snake_case(processed_name)
9298
if iskeyword(processed_name):
9399
processed_name += "_"
100+
if (
101+
handle_pydantic_resrved_field_names
102+
and processed_name in PYDANTIC_RESERVED_FIELD_NAMES
103+
):
104+
processed_name += "_"
94105
if trim_leading_underscore:
95106
processed_name = processed_name.lstrip("_")
96107
if plugin_manager:

tests/client_generators/input_types_generator/test_names.py

+2
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def test_generate_returns_module_with_valid_field_names(
160160
_Bar: String!
161161
____baz_: String!
162162
_: String!
163+
schema: String!
163164
}
164165
"""
165166

@@ -186,4 +187,5 @@ def test_generate_returns_module_with_valid_field_names(
186187
"bar",
187188
"baz_",
188189
"underscore_named_field_",
190+
"schema_",
189191
}

tests/client_generators/result_types_generator/schema.py

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_Field5: String!
3434
scalarField: SCALARA
3535
_: String!
36+
schema: String!
3637
}
3738
3839
type CustomType1 {

tests/client_generators/result_types_generator/test_names.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def test_generate_returns_module_with_valid_field_names():
120120
_field4
121121
_Field5
122122
_
123+
schema
123124
}
124125
}
125126
"""
@@ -140,4 +141,10 @@ def test_generate_returns_module_with_valid_field_names():
140141
) # Round trip because invalid identifiers get picked up in parse
141142
class_def = get_class_def(parsed, name_filter="CustomQueryCamelCaseQuery")
142143
field_names = get_assignment_target_names(class_def)
143-
assert field_names == {"in_", "field4", "field5", "underscore_named_field_"}
144+
assert field_names == {
145+
"in_",
146+
"field4",
147+
"field5",
148+
"underscore_named_field_",
149+
"schema_",
150+
}

0 commit comments

Comments
 (0)