Skip to content

Commit 9bf6fb4

Browse files
committed
fix: fix tests with ai
1 parent 182fb48 commit 9bf6fb4

8 files changed

Lines changed: 111 additions & 59 deletions

File tree

src/inline_snapshot/_customize/_builder.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from __future__ import annotations
22

3+
import ast
34
from dataclasses import dataclass
45
from typing import Any
56
from typing import Callable
67

78
from inline_snapshot._adapter_context import AdapterContext
9+
from inline_snapshot._code_repr import HasRepr
810
from inline_snapshot._code_repr import mock_repr
11+
from inline_snapshot._code_repr import value_code_repr
912
from inline_snapshot._compare_context import compare_context
1013
from inline_snapshot._customize._custom_sequence import CustomSequence
1114
from inline_snapshot._exceptions import UsageError
15+
from inline_snapshot._utils import clone
1216

1317
from ._custom import Custom
1418
from ._custom_call import CustomCall
@@ -51,6 +55,14 @@ def _get_handler(self, v, snapshot_value=None) -> Custom:
5155

5256
from inline_snapshot._global_state import state
5357

58+
if isinstance(v, Custom):
59+
original_value = v._eval()
60+
else:
61+
try:
62+
original_value = clone(v)
63+
except UsageError:
64+
original_value = v
65+
5466
result = v
5567

5668
while not isinstance(result, Custom):
@@ -63,11 +75,22 @@ def _get_handler(self, v, snapshot_value=None) -> Custom:
6375
snapshot_value=snapshot_value,
6476
)
6577
if r is None:
66-
result = CustomCode(result)
78+
79+
with mock_repr(self._snapshot_context):
80+
repr_str = value_code_repr(result)
81+
82+
try:
83+
ast.parse(repr_str)
84+
except SyntaxError:
85+
result = self.create_call(HasRepr, [type(result), repr_str])
86+
# self.repr_str = HasRepr(type(value), self.repr_str).__repr__()
87+
# self._imports.append(ImportFrom("inline_snapshot", "HasRepr"))
88+
else:
89+
result = CustomCode(result, repr_str)
6790
else:
6891
result = r
6992

70-
result.__dict__["original_value"] = v._eval() if isinstance(v, Custom) else v
93+
result.__dict__["original_value"] = original_value
7194

7295
if not isinstance(v, Custom) and self._build_new_value:
7396
is_same = False
@@ -79,14 +102,14 @@ def _get_handler(self, v, snapshot_value=None) -> Custom:
79102
):
80103
is_same = True
81104

82-
if not is_same and v_eval == v:
105+
if not is_same and v_eval == original_value:
83106
is_same = True
84107

85108
if not is_same:
86109
raise UsageError(f"""\
87110
Customized value does not match original value:
88111
89-
original_value={v!r}
112+
original_value={original_value!r}
90113
91114
customized_value={result._eval()!r}
92115
customized_representation={result!r}
@@ -95,31 +118,43 @@ def _get_handler(self, v, snapshot_value=None) -> Custom:
95118
return result
96119

97120
def _customize(self, value, snapshot_value=missing):
98-
with mock_repr(self._snapshot_context):
99-
return self._get_handler(value, snapshot_value)
121+
return self._get_handler(value, snapshot_value)
100122

101123
def _customize_all(self, value):
102124
if not isinstance(value, Custom):
103125
value = self._customize(value)
104126

127+
def with_original(new_value: Custom, old_value: Custom) -> Custom:
128+
new_value.__dict__["original_value"] = getattr(
129+
old_value, "original_value", old_value._eval()
130+
)
131+
return new_value
132+
105133
if isinstance(value, CustomSequence):
106-
value.value = [self._customize_all(c) for c in value.value]
134+
return with_original(
135+
type(value)([self._customize_all(c) for c in value.value]), value
136+
)
107137
elif isinstance(value, CustomDict):
108-
value.value = {
109-
self._customize_all(k): self._customize_all(v)
110-
for k, v in value.value.items()
111-
}
138+
return with_original(
139+
CustomDict(
140+
{
141+
self._customize_all(k): self._customize_all(v)
142+
for k, v in value.value.items()
143+
}
144+
),
145+
value,
146+
)
112147
elif isinstance(value, CustomCall):
113-
value._function = self._customize_all(value._function)
114-
value._args = [self._customize_all(c) for c in value._args]
115-
value._kwargs = {
116-
k: self._customize_all(v) for k, v in value._kwargs.items()
117-
}
118-
value._kwonly = {
119-
k: self._customize_all(v) for k, v in value._kwonly.items()
120-
}
148+
return with_original(
149+
CustomCall(
150+
function=self._customize_all(value.function),
151+
args=[self._customize_all(c) for c in value.args],
152+
kwargs={k: self._customize_all(v) for k, v in value.kwargs.items()},
153+
),
154+
value,
155+
)
121156
elif isinstance(value, CustomDefault):
122-
value.value = self._customize_all(value.value)
157+
return with_original(CustomDefault(self._customize_all(value.value)), value)
123158

124159
return value
125160

src/inline_snapshot/_customize/_custom_code.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
from __future__ import annotations
22

3-
import ast
43
import importlib
54
from dataclasses import dataclass
65
from typing import Generator
76

87
from inline_snapshot._adapter_context import AdapterContext
98
from inline_snapshot._change import ChangeBase
109
from inline_snapshot._change import RequiredImport
11-
from inline_snapshot._code_repr import HasRepr
12-
from inline_snapshot._code_repr import value_code_repr
1310
from inline_snapshot._utils import clone
1411

1512
from ._custom import Custom
@@ -45,21 +42,12 @@ def _simplify_module_path(module: str, name: str) -> str:
4542
class CustomCode(Custom):
4643
_imports: list[Import | ImportFrom]
4744

48-
def __init__(self, value, repr_str=None, imports: list[Import | ImportFrom] = []):
45+
def __init__(self, value, repr_str, imports: list[Import | ImportFrom] = []):
4946
assert not isinstance(value, Custom)
5047
value = clone(value)
5148
self._imports = list(imports)
5249

53-
if repr_str is None:
54-
self.repr_str = value_code_repr(value)
55-
56-
try:
57-
ast.parse(self.repr_str)
58-
except SyntaxError:
59-
self.repr_str = HasRepr(type(value), self.repr_str).__repr__()
60-
self._imports.append(ImportFrom("inline_snapshot", "HasRepr"))
61-
else:
62-
self.repr_str = repr_str
50+
self.repr_str = repr_str
6351

6452
self.value = value
6553

src/inline_snapshot/_new_adapter.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,14 @@ def reeval_CustomCode(old_value: CustomCode, value: CustomCode):
122122
return value
123123

124124

125+
def reeval_CustomValue(old_value: CustomCode, value: CustomCode):
126+
if not old_value._eval() == value._eval():
127+
raise UsageError(
128+
"snapshot value should not change. Use Is(...) for dynamic snapshot parts."
129+
)
130+
return value
131+
132+
125133
def reeval_CustomCall(old_value: CustomCall, value: CustomCall):
126134
return CustomCall(
127135
reeval(old_value.function, value.function),
@@ -161,7 +169,7 @@ def compare(
161169
snapshot_value = (
162170
missing
163171
if isinstance(old_value, CustomUndefined)
164-
else old_value.original_value
172+
else getattr(old_value, "original_value", missing)
165173
)
166174

167175
custom_value = self.customize(new_value, snapshot_value)
@@ -190,7 +198,9 @@ def compare(
190198
old_value, old_node, custom_value
191199
)
192200
else:
193-
result = yield from self.compare_CustomCode(old_value, old_node, new_value)
201+
result = yield from self.compare_CustomCode(
202+
old_value, old_node, custom_value
203+
)
194204
return result
195205

196206
def compare_CustomCode(
@@ -272,8 +282,17 @@ def compare_CustomList(
272282
if c in "mx":
273283
old_value_element, old_node_element = next(old)
274284
new_value_element = next(new)
285+
comparison_value = (
286+
new_value_element
287+
if isinstance(new_value_element, CustomUnmanaged)
288+
else getattr(
289+
new_value_element, "original_value", new_value_element._eval()
290+
)
291+
)
275292
v = yield from self.compare(
276-
old_value_element, old_node_element, new_value_element
293+
old_value_element,
294+
old_node_element,
295+
comparison_value,
277296
)
278297
result.append(v)
279298
old_position += 1
@@ -316,7 +335,16 @@ def compare_CustomTuple(
316335

317336
# compare paired elements
318337
for old_elem, old_node_elem, new_elem in zip(old_elts, old_nodes, new_elts):
319-
v = yield from self.compare(old_elem, old_node_elem, new_elem)
338+
comparison_value = (
339+
new_elem
340+
if isinstance(new_elem, CustomUnmanaged)
341+
else getattr(new_elem, "original_value", new_elem._eval())
342+
)
343+
v = yield from self.compare(
344+
old_elem,
345+
old_node_elem,
346+
comparison_value,
347+
)
320348
result.append(v)
321349

322350
# delete surplus old elements

src/inline_snapshot/_snapshot/undecided_value.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Any
33
from typing import Iterator
44

5-
from inline_snapshot._code_repr import mock_repr
65
from inline_snapshot._compare_context import compare_only
76
from inline_snapshot._customize._builder import Builder
87
from inline_snapshot._customize._custom import Custom
@@ -100,8 +99,7 @@ def convert_generic(self, value: Any) -> Custom:
10099
if value is ...:
101100
return CustomUndefined()
102101
else:
103-
with mock_repr(self.context):
104-
result = Builder(self.context, _recursive=False)._get_handler(value)
102+
result = Builder(self.context, _recursive=False)._get_handler(value)
105103
if isinstance(result, CustomCall) and result.function == type(value):
106104
function = self.convert(result.function)
107105
posonly_args = [self.convert(arg) for arg in result.args]

src/inline_snapshot/plugin/_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def customize(
3333
builder: Builder,
3434
local_vars: Dict[str, Any],
3535
global_vars: Dict[str, Any],
36-
snapshot_value: Optional[Any] = None,
36+
snapshot_value: Optional[Any],
3737
) -> Any:
3838
"""
3939
The customize hook is called every time a snapshot value should be converted into code.

tests/test_pydantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,4 +269,4 @@ def test_a():
269269
270270
assert c == snapshot(C(a=5))
271271
"""}),
272-
).run_inline(reported_categories={"update"})
272+
).run_inline()

tests/test_pytest_plugin.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -961,19 +961,19 @@ def test_a():
961961
| - assert 1==snapshot(5) |
962962
| + assert 1==snapshot(1) |
963963
+------------------------------------------------------------------------------+
964-
Do you want to fix these snapshots? [y/n] (n):\
964+
These changes are not applied.
965+
Use --inline-snapshot=fix to apply them, or use the interactive mode with
966+
--inline-snapshot=review\
965967
"""),
966968
returncode=snapshot(1),
967969
stderr=snapshot(""),
968-
changed_files=snapshot({"tests/test_something.py": """\
969-
import pytest
970-
from inline_snapshot import snapshot
971-
972-
def test_a():
973-
assert 1==snapshot(1)
974-
"""}),
975-
stdin=b"y\n",
976-
outcomes={"passed": 1, "errors": 1},
970+
changed_files=snapshot({}),
971+
error="""\
972+
> assert 1==snapshot(5)
973+
E assert 1 == 5
974+
E + where 5 = snapshot(5)
975+
""",
976+
outcomes={"failed": 1, "errors": 1},
977977
)
978978

979979

tests/test_snapshot_value.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@ def test_snapshot_value_in_list():
99
Example(
1010
{
1111
"tests/conftest.py": """\
12-
from inline_snapshot import customize, snapshot
12+
from inline_snapshot import snapshot
13+
from inline_snapshot.plugin import customize
1314
1415
old_new_mapping={}
1516
1617
@customize
1718
def double_if_old_exists(value, builder, snapshot_value):
18-
if isinstance(value, int):
19-
old_new_mapping[value]=str(snapshot_value)
20-
return builder.create_value(value)
19+
if isinstance(value,int):
20+
assert value in (1,2,3,4,5,6,8,0)
21+
old_new_mapping[value]=snapshot_value
22+
return builder.create_code(str(value))
2123
2224
""",
2325
"tests/test_something.py": """\
@@ -30,7 +32,7 @@ def test_it():
3032
assert snapshot([1, 8, 3,4,5,0]) == [1, 2, 3,4,5,6]
3133
from conftest import old_new_mapping
3234
33-
assert old_new_mapping==snapshot()
35+
assert dict(old_new_mapping)==snapshot()
3436
""",
3537
}
3638
).run_pytest(
@@ -49,6 +51,7 @@ def test_it():
4951
assert snapshot([1, 2, 3,4,5,6]) == [1, 2, 3,4,5,6]
5052
from conftest import old_new_mapping
5153
52-
assert old_new_mapping==snapshot({1: "1", 2: "8", 3: "3", 4: "4", 5: "5", 6: "0"})
54+
assert dict(old_new_mapping)==snapshot({1: 1, 2: 8, 3: 3, 4: 4, 5: 5, 6: 0})
5355
"""}),
56+
outcomes={"passed": 1, "errors": 1},
5457
)

0 commit comments

Comments
 (0)