Skip to content

Commit fdc5fb8

Browse files
authored
Merge pull request #278 from mirumee/model_rebuild_calls
add model_rebuild_calls for result_types.py
2 parents b5a67e1 + 6cf3077 commit fdc5fb8

File tree

36 files changed

+142
-2
lines changed

36 files changed

+142
-2
lines changed

ariadne_codegen/client_generators/result_types.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
generate_ann_assign,
3535
generate_class_def,
3636
generate_constant,
37+
generate_expr,
3738
generate_import_from,
39+
generate_method_call,
3840
generate_module,
3941
generate_pass,
4042
generate_pydantic_field,
@@ -55,6 +57,7 @@
5557
MIXIN_FROM_NAME,
5658
MIXIN_IMPORT_NAME,
5759
MIXIN_NAME,
60+
MODEL_REBUILD_METHOD,
5861
OPTIONAL,
5962
PYDANTIC_MODULE,
6063
TYPENAME_ALIAS,
@@ -152,8 +155,16 @@ def _get_operation_type_name(self, definition: ExecutableDefinitionNode) -> str:
152155
raise NotSupported(f"Not supported operation type: {definition}")
153156

154157
def generate(self) -> ast.Module:
155-
module_body = cast(List[ast.stmt], self._imports) + cast(
156-
List[ast.stmt], self._class_defs
158+
model_rebuild_calls = [
159+
generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD))
160+
for class_def in self._class_defs
161+
if self.include_model_rebuild(class_def)
162+
]
163+
164+
module_body = (
165+
cast(List[ast.stmt], self._imports)
166+
+ cast(List[ast.stmt], self._class_defs)
167+
+ cast(List[ast.stmt], model_rebuild_calls)
157168
)
158169

159170
module = generate_module(module_body)
@@ -163,6 +174,11 @@ def generate(self) -> ast.Module:
163174
)
164175
return module
165176

177+
def include_model_rebuild(self, class_def: ast.ClassDef) -> bool:
178+
visitor = ClassDefNamesVisitor()
179+
visitor.visit(class_def)
180+
return visitor.found_name_with_quote
181+
166182
def get_imports(self) -> List[ast.ImportFrom]:
167183
return self._imports
168184

@@ -560,3 +576,19 @@ def enter_field(node: FieldNode, *_args: Any) -> FieldNode:
560576
copied_node = deepcopy(node)
561577
visit(copied_node, RemoveMixinVisitor())
562578
return copied_node
579+
580+
581+
class ClassDefNamesVisitor(ast.NodeVisitor):
582+
def __init__(self):
583+
self.found_name_with_quote = False
584+
585+
def visit_Name(self, node): # pylint: disable=C0103
586+
if '"' in node.id:
587+
self.found_name_with_quote = True
588+
self.generic_visit(node)
589+
590+
def visit_Subscript(self, node): # pylint: disable=C0103
591+
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
592+
return
593+
594+
self.generic_visit(node)

tests/main/clients/custom_base_client/expected_client/get_query_a.py

+3
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ class GetQueryA(BaseModel):
99

1010
class GetQueryAQueryA(BaseModel):
1111
field_a: int = Field(alias="fieldA")
12+
13+
14+
GetQueryA.model_rebuild()

tests/main/clients/custom_files_names/expected_client/get_query_a.py

+3
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ class GetQueryA(BaseModel):
99

1010
class GetQueryAQueryA(BaseModel):
1111
field_a: int = Field(alias="fieldA")
12+
13+
14+
GetQueryA.model_rebuild()

tests/main/clients/custom_scalars/expected_client/get_a.py

+3
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ class GetATestQuery(BaseModel):
1616
code: Annotated[Code, BeforeValidator(parse_code)]
1717
id: int
1818
other: Any
19+
20+
21+
GetA.model_rebuild()

tests/main/clients/example/expected_client/create_user.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ class CreateUser(BaseModel):
1111

1212
class CreateUserUserCreate(BaseModel):
1313
id: str
14+
15+
16+
CreateUser.model_rebuild()

tests/main/clients/example/expected_client/list_all_users.py

+4
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,7 @@ class ListAllUsersUsers(BaseModel):
1919

2020
class ListAllUsersUsersLocation(BaseModel):
2121
country: Optional[str]
22+
23+
24+
ListAllUsers.model_rebuild()
25+
ListAllUsersUsers.model_rebuild()

tests/main/clients/example/expected_client/list_users_by_country.py

+3
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ class ListUsersByCountry(BaseModel):
1313

1414
class ListUsersByCountryUsers(BasicUser, UserPersonalData):
1515
favourite_color: Optional[Color] = Field(alias="favouriteColor")
16+
17+
18+
ListUsersByCountry.model_rebuild()

tests/main/clients/extended_models/expected_client/fragments_with_mixins.py

+3
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ class FragmentsWithMixinsQueryA(FragmentA, CommonMixin):
1616

1717
class FragmentsWithMixinsQueryB(FragmentB, CommonMixin):
1818
pass
19+
20+
21+
FragmentsWithMixins.model_rebuild()

tests/main/clients/extended_models/expected_client/get_query_a.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ class GetQueryA(BaseModel):
1111

1212
class GetQueryAQueryA(BaseModel, MixinA, CommonMixin):
1313
field_a: int = Field(alias="fieldA")
14+
15+
16+
GetQueryA.model_rebuild()

tests/main/clients/extended_models/expected_client/get_query_b.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ class GetQueryB(BaseModel):
1111

1212
class GetQueryBQueryB(BaseModel, MixinB, CommonMixin):
1313
field_b: str = Field(alias="fieldB")
14+
15+
16+
GetQueryB.model_rebuild()

tests/main/clients/fragments_on_abstract_types/expected_client/query_with_fragment_on_sub_interface.py

+3
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ class QueryWithFragmentOnSubInterfaceQueryInterfaceBaseInterface(BaseModel):
1919

2020
class QueryWithFragmentOnSubInterfaceQueryInterfaceInterfaceA(FragmentA):
2121
typename__: Literal["InterfaceA"] = Field(alias="__typename")
22+
23+
24+
QueryWithFragmentOnSubInterface.model_rebuild()

tests/main/clients/fragments_on_abstract_types/expected_client/query_with_fragment_on_sub_interface_with_inline_fragment.py

+3
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,6 @@ class QueryWithFragmentOnSubInterfaceWithInlineFragmentQueryInterfaceTypeA(BaseM
3232
id: str
3333
value_a: str = Field(alias="valueA")
3434
another: str
35+
36+
37+
QueryWithFragmentOnSubInterfaceWithInlineFragment.model_rebuild()

tests/main/clients/fragments_on_abstract_types/expected_client/query_with_fragment_on_union_member.py

+3
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ class QueryWithFragmentOnUnionMemberQueryUnionTypeA(BaseModel):
1919

2020
class QueryWithFragmentOnUnionMemberQueryUnionTypeB(FragmentB):
2121
typename__: Literal["TypeB"] = Field(alias="__typename")
22+
23+
24+
QueryWithFragmentOnUnionMember.model_rebuild()

tests/main/clients/inline_fragments/expected_client/interface_a.py

+3
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,6 @@ class InterfaceAQueryITypeB(BaseModel):
2626
typename__: Literal["TypeB"] = Field(alias="__typename")
2727
id: str
2828
field_b: str = Field(alias="fieldB")
29+
30+
31+
InterfaceA.model_rebuild()

tests/main/clients/inline_fragments/expected_client/interface_b.py

+3
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ class InterfaceBQueryITypeA(BaseModel):
2020
typename__: Literal["TypeA"] = Field(alias="__typename")
2121
id: str
2222
field_a: str = Field(alias="fieldA")
23+
24+
25+
InterfaceB.model_rebuild()

tests/main/clients/inline_fragments/expected_client/interface_c.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ class InterfaceCQueryI(BaseModel):
1414
alias="__typename"
1515
)
1616
id: str
17+
18+
19+
InterfaceC.model_rebuild()

tests/main/clients/inline_fragments/expected_client/interface_with_typename.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ class InterfaceWithTypenameQueryI(BaseModel):
1414
alias="__typename"
1515
)
1616
id: str
17+
18+
19+
InterfaceWithTypename.model_rebuild()

tests/main/clients/inline_fragments/expected_client/list_interface.py

+3
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@ class ListInterfaceQueryListITypeB(BaseModel):
3333
typename__: Literal["TypeB"] = Field(alias="__typename")
3434
id: str
3535
field_b: str = Field(alias="fieldB")
36+
37+
38+
ListInterface.model_rebuild()

tests/main/clients/inline_fragments/expected_client/list_union.py

+3
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,6 @@ class ListUnionQueryListUTypeB(BaseModel):
3232

3333
class ListUnionQueryListUTypeC(BaseModel):
3434
typename__: Literal["TypeC"] = Field(alias="__typename")
35+
36+
37+
ListUnion.model_rebuild()

tests/main/clients/inline_fragments/expected_client/query_with_fragment_on_interface.py

+3
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ class QueryWithFragmentOnInterfaceQueryITypeB(BaseModel):
2828
typename__: Literal["TypeB"] = Field(alias="__typename")
2929
id: str
3030
field_b: str = Field(alias="fieldB")
31+
32+
33+
QueryWithFragmentOnInterface.model_rebuild()

tests/main/clients/inline_fragments/expected_client/query_with_fragment_on_union.py

+3
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@ class QueryWithFragmentOnUnionQueryUTypeB(BaseModel):
2727

2828
class QueryWithFragmentOnUnionQueryUTypeC(BaseModel):
2929
typename__: Literal["TypeC"] = Field(alias="__typename")
30+
31+
32+
QueryWithFragmentOnUnion.model_rebuild()

tests/main/clients/inline_fragments/expected_client/union_a.py

+3
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ class UnionAQueryUTypeB(BaseModel):
2525

2626
class UnionAQueryUTypeC(BaseModel):
2727
typename__: Literal["TypeC"] = Field(alias="__typename")
28+
29+
30+
UnionA.model_rebuild()

tests/main/clients/inline_fragments/expected_client/union_b.py

+3
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,6 @@ class UnionBQueryUTypeB(BaseModel):
2323

2424
class UnionBQueryUTypeC(BaseModel):
2525
typename__: Literal["TypeC"] = Field(alias="__typename")
26+
27+
28+
UnionB.model_rebuild()

tests/main/clients/multiple_fragments/expected_client/example_query_1.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ class ExampleQuery1(BaseModel):
1010

1111
class ExampleQuery1ExampleQuery(MinimalA):
1212
value: str
13+
14+
15+
ExampleQuery1.model_rebuild()

tests/main/clients/multiple_fragments/expected_client/example_query_2.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ class ExampleQuery2(BaseModel):
1010

1111
class ExampleQuery2ExampleQuery(FullA):
1212
pass
13+
14+
15+
ExampleQuery2.model_rebuild()

tests/main/clients/multiple_fragments/expected_client/example_query_3.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ class ExampleQuery3(BaseModel):
1010

1111
class ExampleQuery3ExampleQuery(CompleteA):
1212
pass
13+
14+
15+
ExampleQuery3.model_rebuild()

tests/main/clients/only_used_inputs_and_enums/expected_client/get_f.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ class GetF(BaseModel):
88

99
class GetFF(BaseModel):
1010
val: EnumF
11+
12+
13+
GetF.model_rebuild()

tests/main/clients/only_used_inputs_and_enums/expected_client/get_g.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ class GetG(BaseModel):
88

99
class GetGG(FragmentG):
1010
pass
11+
12+
13+
GetG.model_rebuild()

tests/main/clients/operations/expected_client/get_a.py

+4
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@ class GetAA(BaseModel):
1414

1515
class GetAAValueB(BaseModel):
1616
value: str
17+
18+
19+
GetA.model_rebuild()
20+
GetAA.model_rebuild()

tests/main/clients/operations/expected_client/get_a_with_fragment.py

+4
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ class GetAWithFragmentA(BaseModel):
1515

1616
class GetAWithFragmentAValueB(FragmentB):
1717
pass
18+
19+
20+
GetAWithFragment.model_rebuild()
21+
GetAWithFragmentA.model_rebuild()

tests/main/clients/operations/expected_client/get_s.py

+3
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ class GetS(BaseModel):
77

88
class GetSS(BaseModel):
99
id: int
10+
11+
12+
GetS.model_rebuild()

tests/main/clients/operations/expected_client/get_xyz.py

+3
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,6 @@ class GetXYZXyzTypeY(FragmentY):
2323

2424
class GetXYZXyzTypeZ(BaseModel):
2525
typename__: Literal["TypeZ"] = Field(alias="__typename")
26+
27+
28+
GetXYZ.model_rebuild()

tests/main/clients/shorter_results/expected_client/get_animal_by_name.py

+3
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ class GetAnimalByNameAnimalByNameDog(BaseModel):
2828
typename__: Literal["Dog"] = Field(alias="__typename")
2929
name: str
3030
puppies: int
31+
32+
33+
GetAnimalByName.model_rebuild()

tests/main/clients/shorter_results/expected_client/get_authenticated_user.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ class GetAuthenticatedUser(BaseModel):
88
class GetAuthenticatedUserMe(BaseModel):
99
id: str
1010
username: str
11+
12+
13+
GetAuthenticatedUser.model_rebuild()

tests/main/clients/shorter_results/expected_client/list_animals.py

+3
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@ class ListAnimalsListAnimalsDog(BaseModel):
3333
typename__: Literal["Dog"] = Field(alias="__typename")
3434
name: str
3535
puppies: int
36+
37+
38+
ListAnimals.model_rebuild()

tests/main/clients/shorter_results/expected_client/list_type_a.py

+3
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ class ListTypeA(BaseModel):
1313

1414
class ListTypeAListOptionalTypeA(BaseModel):
1515
id: int
16+
17+
18+
ListTypeA.model_rebuild()

0 commit comments

Comments
 (0)