Skip to content

Commit 44f9740

Browse files
authored
Merge pull request #123 from mirumee/fix_decode_and_parse_application
Change BaseModel to apply parse and serialize methods on every list element
2 parents 65b3410 + 292a2d4 commit 44f9740

File tree

11 files changed

+503
-81
lines changed

11 files changed

+503
-81
lines changed

Diff for: CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Unlocked versions of black, isort, autoflake and dev dependencies
99
- Added `remote_schema_verify_ssl` option.
1010
- Changed how default values for inputs are generated to handle potential cycles.
11+
- Fixed `BaseModel` incorrectly calling `parse` and `serialize` methods on entire list instead of its items for `List[Scalar]`.
1112

1213

1314
## 0.4.0 (2023-03-20)
+29-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, Type, Union, get_args, get_origin
22

33
from pydantic import BaseModel as PydanticBaseModel
44
from pydantic.class_validators import validator
@@ -15,16 +15,36 @@ class Config:
1515

1616
# pylint: disable=no-self-argument
1717
@validator("*", pre=True)
18-
def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any:
19-
decode = SCALARS_PARSE_FUNCTIONS.get(field.type_)
20-
if decode and callable(decode):
18+
def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any:
19+
return cls._parse_custom_scalar_value(value, field.annotation)
20+
21+
@classmethod
22+
def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any:
23+
origin = get_origin(type_)
24+
args = get_args(type_)
25+
if origin is list and isinstance(value, list):
26+
return [cls._parse_custom_scalar_value(item, args[0]) for item in value]
27+
28+
if origin is Union and type(None) in args:
29+
sub_type: Any = list(filter(None, args))[0]
30+
return cls._parse_custom_scalar_value(value, sub_type)
31+
32+
decode = SCALARS_PARSE_FUNCTIONS.get(type_)
33+
if value and decode and callable(decode):
2134
return decode(value)
35+
2236
return value
2337

2438
def dict(self, **kwargs: Any) -> Dict[str, Any]:
2539
dict_ = super().dict(**kwargs)
26-
for key, value in dict_.items():
27-
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
28-
if serialize and callable(serialize):
29-
dict_[key] = serialize(value)
30-
return dict_
40+
return {key: self._serialize_value(value) for key, value in dict_.items()}
41+
42+
def _serialize_value(self, value: Any) -> Any:
43+
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
44+
if serialize and callable(serialize):
45+
return serialize(value)
46+
47+
if isinstance(value, list):
48+
return [self._serialize_value(item) for item in value]
49+
50+
return value
+241
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
from typing import List, Optional
2+
3+
import pytest
4+
5+
from ariadne_codegen.client_generators.dependencies.base_model import BaseModel
6+
7+
8+
@pytest.mark.parametrize(
9+
"annotation, value, expected_args",
10+
[
11+
(str, "a", {"a"}),
12+
(Optional[str], "a", {"a"}),
13+
(Optional[str], None, set()),
14+
(List[str], ["a", "b"], {"a", "b"}),
15+
(List[Optional[str]], ["a", None], {"a"}),
16+
(Optional[List[str]], ["a", "b"], {"a", "b"}),
17+
(Optional[List[str]], None, set()),
18+
(Optional[List[Optional[str]]], ["a", None], {"a"}),
19+
(Optional[List[Optional[str]]], None, set()),
20+
(List[List[str]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}),
21+
(Optional[List[List[str]]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}),
22+
(Optional[List[List[str]]], None, set()),
23+
(
24+
Optional[List[Optional[List[str]]]],
25+
[["a", "b"], ["c", "d"]],
26+
{"a", "b", "c", "d"},
27+
),
28+
(Optional[List[Optional[List[str]]]], None, set()),
29+
(Optional[List[Optional[List[str]]]], [["a", "b"], None], {"a", "b"}),
30+
(
31+
Optional[List[Optional[List[Optional[str]]]]],
32+
[["a", "b"], ["c", "d"]],
33+
{"a", "b", "c", "d"},
34+
),
35+
(Optional[List[Optional[List[Optional[str]]]]], None, set()),
36+
(Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], {"a", "b"}),
37+
(
38+
Optional[List[Optional[List[Optional[str]]]]],
39+
[["a", None], ["b", None]],
40+
{"a", "b"},
41+
),
42+
],
43+
)
44+
def test_parse_obj_applies_parse_on_every_list_element(
45+
annotation, value, expected_args, mocker
46+
):
47+
mocked_parse = mocker.MagicMock(side_effect=lambda s: s)
48+
mocker.patch(
49+
"ariadne_codegen.client_generators.dependencies.base_model."
50+
"SCALARS_PARSE_FUNCTIONS",
51+
{str: mocked_parse},
52+
)
53+
54+
class TestModel(BaseModel):
55+
field: annotation
56+
57+
TestModel.parse_obj({"field": value})
58+
59+
assert mocked_parse.call_count == len(expected_args)
60+
assert {c.args[0] for c in mocked_parse.call_args_list} == expected_args
61+
62+
63+
def test_parse_obj_doesnt_apply_parse_on_not_matching_type(mocker):
64+
mocked_parse = mocker.MagicMock(side_effect=lambda s: s)
65+
mocker.patch(
66+
"ariadne_codegen.client_generators.dependencies.base_model."
67+
"SCALARS_PARSE_FUNCTIONS",
68+
{str: mocked_parse},
69+
)
70+
71+
class TestModel(BaseModel):
72+
field_a: int
73+
field_b: Optional[int]
74+
field_c: Optional[int]
75+
field_d: List[int]
76+
field_e: Optional[List[int]]
77+
field_f: Optional[List[int]]
78+
field_g: Optional[List[Optional[int]]]
79+
field_h: Optional[List[Optional[int]]]
80+
field_i: Optional[List[Optional[int]]]
81+
82+
TestModel.parse_obj(
83+
{
84+
"field_a": 1,
85+
"field_b": 2,
86+
"field_c": None,
87+
"field_d": [3, 4],
88+
"field_e": [5, 6],
89+
"field_f": None,
90+
"field_g": [7, 8],
91+
"field_h": [9, None],
92+
"field_i": None,
93+
}
94+
)
95+
96+
assert not mocked_parse.called
97+
98+
99+
def test_parse_obj_applies_parse_only_once_for_every_element(mocker):
100+
mocked_parse = mocker.MagicMock(side_effect=lambda s: s)
101+
mocker.patch(
102+
"ariadne_codegen.client_generators.dependencies.base_model."
103+
"SCALARS_PARSE_FUNCTIONS",
104+
{str: mocked_parse},
105+
)
106+
107+
class TestModelC(BaseModel):
108+
value: str
109+
110+
class TestModelB(BaseModel):
111+
value: str
112+
field_c: TestModelC
113+
114+
class TestModelA(BaseModel):
115+
value: str
116+
field_b: TestModelB
117+
118+
TestModelA.parse_obj(
119+
{"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}}
120+
)
121+
122+
assert mocked_parse.call_count == 3
123+
assert {c.args[0] for c in mocked_parse.call_args_list} == {"a", "b", "c"}
124+
125+
126+
@pytest.mark.parametrize(
127+
"annotation, value, expected_args",
128+
[
129+
(str, "a", {"a"}),
130+
(Optional[str], "a", {"a"}),
131+
(Optional[str], None, set()),
132+
(List[str], ["a", "b"], {"a", "b"}),
133+
(List[Optional[str]], ["a", None], {"a"}),
134+
(Optional[List[str]], ["a", "b"], {"a", "b"}),
135+
(Optional[List[str]], None, set()),
136+
(Optional[List[Optional[str]]], ["a", None], {"a"}),
137+
(Optional[List[Optional[str]]], None, set()),
138+
(List[List[str]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}),
139+
(Optional[List[List[str]]], [["a", "b"], ["c", "d"]], {"a", "b", "c", "d"}),
140+
(Optional[List[List[str]]], None, set()),
141+
(
142+
Optional[List[Optional[List[str]]]],
143+
[["a", "b"], ["c", "d"]],
144+
{"a", "b", "c", "d"},
145+
),
146+
(Optional[List[Optional[List[str]]]], None, set()),
147+
(Optional[List[Optional[List[str]]]], [["a", "b"], None], {"a", "b"}),
148+
(
149+
Optional[List[Optional[List[Optional[str]]]]],
150+
[["a", "b"], ["c", "d"]],
151+
{"a", "b", "c", "d"},
152+
),
153+
(Optional[List[Optional[List[Optional[str]]]]], None, set()),
154+
(Optional[List[Optional[List[Optional[str]]]]], [["a", "b"], None], {"a", "b"}),
155+
(
156+
Optional[List[Optional[List[Optional[str]]]]],
157+
[["a", None], ["b", None]],
158+
{"a", "b"},
159+
),
160+
],
161+
)
162+
def test_dict_applies_serialize_on_every_list_element(
163+
annotation, value, expected_args, mocker
164+
):
165+
mocked_serialize = mocker.MagicMock(side_effect=lambda s: s)
166+
mocker.patch(
167+
"ariadne_codegen.client_generators.dependencies.base_model."
168+
"SCALARS_SERIALIZE_FUNCTIONS",
169+
{str: mocked_serialize},
170+
)
171+
172+
class TestModel(BaseModel):
173+
field: annotation
174+
175+
TestModel.parse_obj({"field": value}).dict()
176+
177+
assert mocked_serialize.call_count == len(expected_args)
178+
assert {c.args[0] for c in mocked_serialize.call_args_list} == expected_args
179+
180+
181+
def test_dict_doesnt_apply_serialize_on_not_matching_type(mocker):
182+
mocked_serialize = mocker.MagicMock(side_effect=lambda s: s)
183+
mocker.patch(
184+
"ariadne_codegen.client_generators.dependencies.base_model."
185+
"SCALARS_SERIALIZE_FUNCTIONS",
186+
{str: mocked_serialize},
187+
)
188+
189+
class TestModel(BaseModel):
190+
field_a: int
191+
field_b: Optional[int]
192+
field_c: Optional[int]
193+
field_d: List[int]
194+
field_e: Optional[List[int]]
195+
field_f: Optional[List[int]]
196+
field_g: Optional[List[Optional[int]]]
197+
field_h: Optional[List[Optional[int]]]
198+
field_i: Optional[List[Optional[int]]]
199+
200+
TestModel.parse_obj(
201+
{
202+
"field_a": 1,
203+
"field_b": 2,
204+
"field_c": None,
205+
"field_d": [3, 4],
206+
"field_e": [5, 6],
207+
"field_f": None,
208+
"field_g": [7, 8],
209+
"field_h": [9, None],
210+
"field_i": None,
211+
}
212+
).dict()
213+
214+
assert not mocked_serialize.called
215+
216+
217+
def test_dict_applies_serialize_only_once_for_every_element(mocker):
218+
mocked_serialize = mocker.MagicMock(side_effect=lambda s: s)
219+
mocker.patch(
220+
"ariadne_codegen.client_generators.dependencies.base_model."
221+
"SCALARS_SERIALIZE_FUNCTIONS",
222+
{str: mocked_serialize},
223+
)
224+
225+
class TestModelC(BaseModel):
226+
value: str
227+
228+
class TestModelB(BaseModel):
229+
value: str
230+
field_c: TestModelC
231+
232+
class TestModelA(BaseModel):
233+
value: str
234+
field_b: TestModelB
235+
236+
TestModelA.parse_obj(
237+
{"value": "a", "field_b": {"value": "b", "field_c": {"value": "c"}}}
238+
).dict()
239+
240+
assert mocked_serialize.call_count == 3
241+
assert {c.args[0] for c in mocked_serialize.call_args_list} == {"a", "b", "c"}
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, Type, Union, get_args, get_origin
22

33
from pydantic import BaseModel as PydanticBaseModel
44
from pydantic.class_validators import validator
@@ -15,16 +15,36 @@ class Config:
1515

1616
# pylint: disable=no-self-argument
1717
@validator("*", pre=True)
18-
def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any:
19-
decode = SCALARS_PARSE_FUNCTIONS.get(field.type_)
20-
if decode and callable(decode):
18+
def parse_custom_scalars(cls, value: Any, field: ModelField) -> Any:
19+
return cls._parse_custom_scalar_value(value, field.annotation)
20+
21+
@classmethod
22+
def _parse_custom_scalar_value(cls, value: Any, type_: Type[Any]) -> Any:
23+
origin = get_origin(type_)
24+
args = get_args(type_)
25+
if origin is list and isinstance(value, list):
26+
return [cls._parse_custom_scalar_value(item, args[0]) for item in value]
27+
28+
if origin is Union and type(None) in args:
29+
sub_type: Any = list(filter(None, args))[0]
30+
return cls._parse_custom_scalar_value(value, sub_type)
31+
32+
decode = SCALARS_PARSE_FUNCTIONS.get(type_)
33+
if value and decode and callable(decode):
2134
return decode(value)
35+
2236
return value
2337

2438
def dict(self, **kwargs: Any) -> Dict[str, Any]:
2539
dict_ = super().dict(**kwargs)
26-
for key, value in dict_.items():
27-
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
28-
if serialize and callable(serialize):
29-
dict_[key] = serialize(value)
30-
return dict_
40+
return {key: self._serialize_value(value) for key, value in dict_.items()}
41+
42+
def _serialize_value(self, value: Any) -> Any:
43+
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
44+
if serialize and callable(serialize):
45+
return serialize(value)
46+
47+
if isinstance(value, list):
48+
return [self._serialize_value(item) for item in value]
49+
50+
return value

0 commit comments

Comments
 (0)