Skip to content

Commit 65b3410

Browse files
authored
Merge pull request #119 from mirumee/default_values_cycles
Default values cycles
2 parents 24b9968 + 08cba1e commit 65b3410

File tree

9 files changed

+224
-119
lines changed

9 files changed

+224
-119
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Fixed support of custom operation types names.
88
- Unlocked versions of black, isort, autoflake and dev dependencies
99
- Added `remote_schema_verify_ssl` option.
10+
- Changed how default values for inputs are generated to handle potential cycles.
1011

1112

1213
## 0.4.0 (2023-03-20)

EXAMPLE.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,6 @@ from .base_model import BaseModel
259259
from .enums import Color
260260

261261

262-
class LocationInput(BaseModel):
263-
city: Optional[str]
264-
country: Optional[str]
265-
266-
267262
class UserCreateInput(BaseModel):
268263
first_name: Optional[str] = Field(alias="firstName")
269264
last_name: Optional[str] = Field(alias="lastName")
@@ -272,11 +267,9 @@ class UserCreateInput(BaseModel):
272267
location: Optional["LocationInput"]
273268

274269

275-
class NotificationsPreferencesInput(BaseModel):
276-
receive_mails: bool = Field(alias="receiveMails")
277-
receive_push_notifications: bool = Field(alias="receivePushNotifications")
278-
receive_sms: bool = Field(alias="receiveSms")
279-
title: str
270+
class LocationInput(BaseModel):
271+
city: Optional[str]
272+
country: Optional[str]
280273

281274

282275
class UserPreferencesInput(BaseModel):
@@ -288,7 +281,7 @@ class UserPreferencesInput(BaseModel):
288281
)
289282
notifications_preferences: "NotificationsPreferencesInput" = Field(
290283
alias="notificationsPreferences",
291-
default=NotificationsPreferencesInput.parse_obj(
284+
default_factory=lambda: globals()["NotificationsPreferencesInput"].parse_obj(
292285
{
293286
"receiveMails": True,
294287
"receivePushNotifications": True,
@@ -299,10 +292,17 @@ class UserPreferencesInput(BaseModel):
299292
)
300293

301294

302-
LocationInput.update_forward_refs()
295+
class NotificationsPreferencesInput(BaseModel):
296+
receive_mails: bool = Field(alias="receiveMails")
297+
receive_push_notifications: bool = Field(alias="receivePushNotifications")
298+
receive_sms: bool = Field(alias="receiveSms")
299+
title: str
300+
301+
303302
UserCreateInput.update_forward_refs()
304-
NotificationsPreferencesInput.update_forward_refs()
303+
LocationInput.update_forward_refs()
305304
UserPreferencesInput.update_forward_refs()
305+
NotificationsPreferencesInput.update_forward_refs()
306306
```
307307

308308
### Enums

ariadne_codegen/client_generators/input_fields.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@
2121

2222
from ..codegen import (
2323
generate_annotation_name,
24+
generate_attribute,
2425
generate_call,
2526
generate_constant,
2627
generate_dict,
2728
generate_keyword,
2829
generate_lambda,
2930
generate_list,
3031
generate_list_annotation,
31-
generate_method_call,
3232
generate_name,
33+
generate_subscript,
3334
)
3435
from ..exceptions import ParsingError
3536
from .constants import ANY, FIELD_CLASS, SIMPLE_TYPE_MAP
@@ -161,7 +162,28 @@ def parse_input_const_value_node(
161162
],
162163
)
163164
if not nested_object:
164-
return generate_method_call(field_type, "parse_obj", [dict_])
165+
return generate_call(
166+
func=generate_name(FIELD_CLASS),
167+
keywords=[
168+
generate_keyword(
169+
arg="default_factory",
170+
value=generate_lambda(
171+
body=generate_call(
172+
func=generate_attribute(
173+
value=generate_subscript(
174+
value=generate_call(
175+
func=generate_name("globals")
176+
),
177+
slice_=generate_constant(field_type),
178+
),
179+
attr="parse_obj",
180+
),
181+
args=[dict_],
182+
)
183+
),
184+
)
185+
],
186+
)
165187
return dict_
166188

167189
return None

ariadne_codegen/client_generators/input_types.py

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,13 @@ def generate(self) -> ast.Module:
8080
names=scalar_data.names_to_import, from_=scalar_data.import_
8181
)
8282
)
83-
sorted_class_defs = self._get_sorted_class_defs()
8483
update_forward_refs_calls = [
8584
generate_expr(generate_method_call(c.name, UPDATE_FORWARD_REFS_METHOD))
86-
for c in sorted_class_defs
85+
for c in self._class_defs
8786
]
8887
module_body = (
8988
cast(List[ast.stmt], self._imports)
90-
+ cast(List[ast.stmt], sorted_class_defs)
89+
+ cast(List[ast.stmt], self._class_defs)
9190
+ cast(List[ast.stmt], update_forward_refs_calls)
9291
)
9392
module = generate_module(body=module_body)
@@ -136,9 +135,7 @@ def _parse_input_definition(
136135
field_implementation, input_field=field, field_name=org_name
137136
)
138137
class_def.body.append(field_implementation)
139-
self._save_used_enums_scalars_and_dependencies(
140-
class_name=class_def.name, field_type=field_type
141-
)
138+
self._save_used_enums_and_scalars(field_type=field_type)
142139

143140
if self.plugin_manager:
144141
class_def = self.plugin_manager.generate_input_class(
@@ -174,40 +171,10 @@ def _process_field_value(
174171
)
175172
return field_with_alias
176173

177-
def _save_used_enums_scalars_and_dependencies(
178-
self, class_name: str, field_type: str = ""
179-
) -> None:
174+
def _save_used_enums_and_scalars(self, field_type: str = "") -> None:
180175
if not field_type:
181176
return
182-
if isinstance(self.schema.type_map[field_type], GraphQLInputObjectType):
183-
self._dependencies[class_name].append(field_type)
184-
elif isinstance(self.schema.type_map[field_type], GraphQLEnumType):
177+
if isinstance(self.schema.type_map[field_type], GraphQLEnumType):
185178
self._used_enums.append(field_type)
186179
elif isinstance(self.schema.type_map[field_type], GraphQLScalarType):
187180
self._used_scalars.append(field_type)
188-
189-
def _get_sorted_class_defs(self) -> List[ast.ClassDef]:
190-
input_class_defs_dict_ = {c.name: c for c in self._class_defs}
191-
192-
processed_names = []
193-
for class_ in self._class_defs:
194-
if class_.name not in processed_names:
195-
processed_names.extend(self._get_dependant_names(class_.name))
196-
processed_names.append(class_.name)
197-
198-
names_without_duplicates = self._get_list_without_duplicates(processed_names)
199-
return [input_class_defs_dict_[n] for n in names_without_duplicates]
200-
201-
def _get_dependant_names(self, name: str) -> List[str]:
202-
result = []
203-
for dependency_name in self._dependencies[name]:
204-
result.extend(self._get_dependant_names(dependency_name))
205-
result.append(dependency_name)
206-
return result
207-
208-
def _get_list_without_duplicates(self, list_: list) -> list:
209-
result = []
210-
for element in list_:
211-
if not element in result:
212-
result.append(element)
213-
return result

tests/client_generators/input_types_generator/test_default_values.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,49 @@ def test_generate_returns_module_with_parsed_inputs_object_field_with_default_va
148148
}
149149
"""
150150
expected_field_value = ast.Call(
151-
func=ast.Attribute(value=ast.Name(id="SecondInput"), attr="parse_obj"),
152-
args=[
153-
ast.Dict(keys=[ast.Constant(value="val")], values=[ast.Constant(value=5)])
151+
func=ast.Name(id="Field"),
152+
args=[],
153+
keywords=[
154+
ast.keyword(
155+
arg="default_factory",
156+
value=ast.Lambda(
157+
args=ast.arguments(
158+
posonlyargs=[],
159+
args=[],
160+
kwonlyargs=[],
161+
kw_defaults=[],
162+
defaults=[],
163+
),
164+
body=ast.Call(
165+
func=ast.Attribute(
166+
value=ast.Subscript(
167+
value=ast.Call(
168+
func=ast.Name(id="globals"), args=[], keywords=[]
169+
),
170+
slice=ast.Constant(value="SecondInput"),
171+
),
172+
attr="parse_obj",
173+
),
174+
args=[
175+
ast.Dict(
176+
keys=[ast.Constant(value="val")],
177+
values=[ast.Constant(value=5)],
178+
)
179+
],
180+
keywords=[],
181+
),
182+
),
183+
)
154184
],
155-
keywords=[],
156185
)
186+
157187
generator = InputTypesGenerator(
158188
schema=build_ast_schema(parse(schema_str)), enums_module="enums"
159189
)
160190

161191
module = generator.generate()
162192

163-
class_def = get_class_def(module, 1)
193+
class_def = get_class_def(module, 0)
164194
assert isinstance(class_def, ast.ClassDef)
165195
assert class_def.name == "TestInput"
166196
assert len(class_def.body) == 1
@@ -185,27 +215,53 @@ def test_generate_returns_module_with_parsed_nested_object_as_default_value():
185215
}
186216
"""
187217
expected_field_value = ast.Call(
188-
func=ast.Attribute(value=ast.Name(id="SecondInput"), attr="parse_obj"),
189-
args=[
190-
ast.Dict(
191-
keys=[ast.Constant(value="nested")],
192-
values=[
193-
ast.Dict(
194-
keys=[ast.Constant(value="val")],
195-
values=[ast.Constant(value=1.5)],
196-
)
197-
],
218+
func=ast.Name(id="Field"),
219+
args=[],
220+
keywords=[
221+
ast.keyword(
222+
arg="default_factory",
223+
value=ast.Lambda(
224+
args=ast.arguments(
225+
posonlyargs=[],
226+
args=[],
227+
kwonlyargs=[],
228+
kw_defaults=[],
229+
defaults=[],
230+
),
231+
body=ast.Call(
232+
func=ast.Attribute(
233+
value=ast.Subscript(
234+
value=ast.Call(
235+
func=ast.Name(id="globals"), args=[], keywords=[]
236+
),
237+
slice=ast.Constant(value="SecondInput"),
238+
),
239+
attr="parse_obj",
240+
),
241+
args=[
242+
ast.Dict(
243+
keys=[ast.Constant(value="nested")],
244+
values=[
245+
ast.Dict(
246+
keys=[ast.Constant(value="val")],
247+
values=[ast.Constant(value=1.5)],
248+
)
249+
],
250+
)
251+
],
252+
keywords=[],
253+
),
254+
),
198255
)
199256
],
200-
keywords=[],
201257
)
202258
generator = InputTypesGenerator(
203259
schema=build_ast_schema(parse(schema_str)), enums_module="enums"
204260
)
205261

206262
module = generator.generate()
207263

208-
class_def = get_class_def(module, 2)
264+
class_def = get_class_def(module, 0)
209265
assert isinstance(class_def, ast.ClassDef)
210266
assert class_def.name == "TestInput"
211267
assert len(class_def.body) == 1

tests/client_generators/input_types_generator/test_method_calls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_generate_returns_modules_with_update_forward_refs_calls():
3535
ast.Expr(
3636
value=ast.Call(
3737
func=ast.Attribute(
38-
value=ast.Name(id="NestedInput"), attr=UPDATE_FORWARD_REFS_METHOD
38+
value=ast.Name(id="TestInput"), attr=UPDATE_FORWARD_REFS_METHOD
3939
),
4040
args=[],
4141
keywords=[],
@@ -44,7 +44,7 @@ def test_generate_returns_modules_with_update_forward_refs_calls():
4444
ast.Expr(
4545
value=ast.Call(
4646
func=ast.Attribute(
47-
value=ast.Name(id="TestInput"), attr=UPDATE_FORWARD_REFS_METHOD
47+
value=ast.Name(id="NestedInput"), attr=UPDATE_FORWARD_REFS_METHOD
4848
),
4949
args=[],
5050
keywords=[],

tests/client_generators/input_types_generator/test_parsing_inputs.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,34 @@
2525
""",
2626
[
2727
ast.ClassDef(
28-
name="CustomInput2",
28+
name="CustomInput",
2929
bases=[ast.Name(id=BASE_MODEL_CLASS_NAME)],
3030
keywords=[],
3131
decorator_list=[],
3232
body=[
3333
ast.AnnAssign(
34-
target=ast.Name(id="field"),
34+
target=ast.Name(id="field1"),
35+
annotation=ast.Name(id='"CustomInput2"'),
36+
simple=1,
37+
),
38+
ast.AnnAssign(
39+
target=ast.Name(id="field2"),
3540
annotation=ast.Name(id="int"),
3641
simple=1,
37-
)
42+
),
3843
],
3944
),
4045
ast.ClassDef(
41-
name="CustomInput",
46+
name="CustomInput2",
4247
bases=[ast.Name(id=BASE_MODEL_CLASS_NAME)],
4348
keywords=[],
4449
decorator_list=[],
4550
body=[
4651
ast.AnnAssign(
47-
target=ast.Name(id="field1"),
48-
annotation=ast.Name(id='"CustomInput2"'),
49-
simple=1,
50-
),
51-
ast.AnnAssign(
52-
target=ast.Name(id="field2"),
52+
target=ast.Name(id="field"),
5353
annotation=ast.Name(id="int"),
5454
simple=1,
55-
),
55+
)
5656
],
5757
),
5858
],
@@ -72,7 +72,7 @@ def test_generate_returns_module_with_parsed_input_types(
7272
assert compare_ast(class_defs, expected_class_defs)
7373

7474

75-
def test_generate_returns_module_with_correct_order_of_classes():
75+
def test_generate_returns_module_with_classes_in_the_same_order_as_declared():
7676
schema_str = """
7777
input BeforeInput {
7878
field: Boolean!
@@ -96,9 +96,9 @@ def test_generate_returns_module_with_correct_order_of_classes():
9696
"""
9797
expected_order = [
9898
"BeforeInput",
99-
"NestedInput",
100-
"SecondInput",
10199
"TestInput",
100+
"SecondInput",
101+
"NestedInput",
102102
"AfterInput",
103103
]
104104
generator = InputTypesGenerator(

0 commit comments

Comments
 (0)