Skip to content

Commit c14dd92

Browse files
authored
Merge pull request #283 from bombsimon/fix/input-type-model-rebuild
fix: Include `model_rebuild` for input types as well
2 parents ba664b5 + d1f89c1 commit c14dd92

File tree

6 files changed

+52
-24
lines changed

6 files changed

+52
-24
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# CHANGELOG
22

3+
## 0.14.0 (Unreleased)
4+
5+
- Re-added `model_rebuild` calls for input types with forward references.
6+
7+
38
## 0.13.0 (2024-03-4)
49

510
- Fixed `str_to_snake_case` utility to capture fully capitalized words followed by an underscore.

ariadne_codegen/client_generators/input_types.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
generate_ann_assign,
1414
generate_class_def,
1515
generate_constant,
16+
generate_expr,
1617
generate_import_from,
1718
generate_keyword,
19+
generate_method_call,
1820
generate_module,
1921
generate_pydantic_field,
22+
model_has_forward_refs,
2023
)
2124
from ..plugins.manager import PluginManager
2225
from ..utils import process_name
@@ -28,6 +31,7 @@
2831
BASE_MODEL_IMPORT,
2932
FIELD_CLASS,
3033
LIST,
34+
MODEL_REBUILD_METHOD,
3135
OPTIONAL,
3236
PLAIN_SERIALIZER,
3337
PYDANTIC_MODULE,
@@ -85,8 +89,16 @@ def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module:
8589
scalar_data = self.custom_scalars[scalar_name]
8690
self._imports.extend(generate_scalar_imports(scalar_data))
8791

88-
module_body = cast(List[ast.stmt], self._imports) + cast(
89-
List[ast.stmt], class_defs
92+
model_rebuild_calls = [
93+
generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD))
94+
for class_def in class_defs
95+
if model_has_forward_refs(class_def)
96+
]
97+
98+
module_body = (
99+
cast(List[ast.stmt], self._imports)
100+
+ cast(List[ast.stmt], class_defs)
101+
+ cast(List[ast.stmt], model_rebuild_calls)
90102
)
91103
module = generate_module(body=module_body)
92104

ariadne_codegen/client_generators/result_types.py

+2-22
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
generate_module,
4141
generate_pass,
4242
generate_pydantic_field,
43+
model_has_forward_refs,
4344
)
4445
from ..exceptions import NotSupported, ParsingError
4546
from ..plugins.manager import PluginManager
@@ -158,7 +159,7 @@ def generate(self) -> ast.Module:
158159
model_rebuild_calls = [
159160
generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD))
160161
for class_def in self._class_defs
161-
if self.include_model_rebuild(class_def)
162+
if model_has_forward_refs(class_def)
162163
]
163164

164165
module_body = (
@@ -174,11 +175,6 @@ def generate(self) -> ast.Module:
174175
)
175176
return module
176177

177-
def include_model_rebuild(self, class_def: ast.ClassDef) -> bool:
178-
visitor = ClassDefNamesVisitor()
179-
visitor.visit(class_def)
180-
return visitor.found_name_with_quote
181-
182178
def get_imports(self) -> List[ast.ImportFrom]:
183179
return self._imports
184180

@@ -576,19 +572,3 @@ def enter_field(node: FieldNode, *_args: Any) -> FieldNode:
576572
copied_node = deepcopy(node)
577573
visit(copied_node, RemoveMixinVisitor())
578574
return copied_node
579-
580-
581-
class ClassDefNamesVisitor(ast.NodeVisitor):
582-
def __init__(self):
583-
self.found_name_with_quote = False
584-
585-
def visit_Name(self, node): # pylint: disable=C0103
586-
if '"' in node.id:
587-
self.found_name_with_quote = True
588-
self.generic_visit(node)
589-
590-
def visit_Subscript(self, node): # pylint: disable=C0103
591-
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
592-
return
593-
594-
self.generic_visit(node)

ariadne_codegen/codegen.py

+22
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,25 @@ def generate_yield(value: Optional[ast.expr] = None) -> ast.Yield:
332332

333333
def generate_pass() -> ast.Pass:
334334
return ast.Pass()
335+
336+
337+
def model_has_forward_refs(class_def: ast.ClassDef) -> bool:
338+
visitor = ClassDefNamesVisitor()
339+
visitor.visit(class_def)
340+
return visitor.found_name_with_quote
341+
342+
343+
class ClassDefNamesVisitor(ast.NodeVisitor):
344+
def __init__(self):
345+
self.found_name_with_quote = False
346+
347+
def visit_Name(self, node): # pylint: disable=C0103
348+
if '"' in node.id:
349+
self.found_name_with_quote = True
350+
self.generic_visit(node)
351+
352+
def visit_Subscript(self, node): # pylint: disable=C0103
353+
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
354+
return
355+
356+
self.generic_visit(node)

tests/main/clients/example/expected_client/input_types.py

+4
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,7 @@ class NotificationsPreferencesInput(BaseModel):
4646
receive_push_notifications: bool = Field(alias="receivePushNotifications")
4747
receive_sms: bool = Field(alias="receiveSms")
4848
title: str
49+
50+
51+
UserCreateInput.model_rebuild()
52+
UserPreferencesInput.model_rebuild()

tests/main/clients/only_used_inputs_and_enums/expected_client/input_types.py

+5
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,8 @@ class InputAB(BaseModel):
2626

2727
class InputE(BaseModel):
2828
val: EnumE
29+
30+
31+
InputA.model_rebuild()
32+
InputAA.model_rebuild()
33+
InputAB.model_rebuild()

0 commit comments

Comments
 (0)