Skip to content

Commit 0f6bb55

Browse files
authored
fix: check current session's pending-write queue when recalling snapshots (e.g. diffing) (#927)
* fix: check current session's pending-write queue when recalling snapshots (e.g. diffing) * Make PyTestLocation hashable * Explicitly set methodname to None for doctests ----------------------------------------------------------------------------------- benchmark: 3 tests ----------------------------------------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ test_1000x_reads 666.9710 (1.0) 748.6652 (1.0) 705.2418 (1.0) 37.2862 (1.0) 703.0552 (1.0) 70.1912 (1.07) 2;0 1.4180 (1.0) 5 1 test_standard 669.7840 (1.00) 843.3747 (1.13) 733.8905 (1.04) 68.2257 (1.83) 705.8282 (1.00) 85.6269 (1.30) 1;0 1.3626 (0.96) 5 1 test_1000x_writes 793.8229 (1.19) 937.1953 (1.25) 850.9716 (1.21) 54.4067 (1.46) 847.3260 (1.21) 65.9041 (1.0) 2;0 1.1751 (0.83) 5 1 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ * Queue writes with a dict for O(1) look-ups Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_1000x_reads 625.5781 (1.0) 887.4346 (1.0) 694.6221 (1.0) 109.0048 (1.0) 658.3128 (1.0) 87.7517 (1.0) 1;1 1.4396 (1.0) 5 1 test_1000x_writes 637.3099 (1.02) 1,021.0924 (1.15) 812.9789 (1.17) 150.2342 (1.38) 757.7635 (1.15) 215.9572 (2.46) 2;0 1.2300 (0.85) 5 1 test_standard 694.1814 (1.11) 1,037.9224 (1.17) 845.1463 (1.22) 136.2068 (1.25) 785.6973 (1.19) 194.9636 (2.22) 2;0 1.1832 (0.82) 5 1 ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- * Use type aliases * return both keys from _snapshot_write_queue_key * Use a defaultdict * Update comments
1 parent ef8189c commit 0f6bb55

File tree

4 files changed

+137
-27
lines changed

4 files changed

+137
-27
lines changed

src/syrupy/assertion.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,7 @@ def _recall_data(
377377
) -> Tuple[Optional["SerializableData"], bool]:
378378
try:
379379
return (
380-
self.extension.read_snapshot(
381-
test_location=self.test_location,
382-
index=index,
383-
session_id=str(id(self.session)),
384-
),
380+
self.session.recall_snapshot(self.extension, self.test_location, index),
385381
False,
386382
)
387383
except SnapshotDoesNotExist:

src/syrupy/location.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from syrupy.constants import PYTEST_NODE_SEP
1414

1515

16-
@dataclass
16+
@dataclass(frozen=True)
1717
class PyTestLocation:
1818
item: "pytest.Item"
1919
nodename: Optional[str] = field(init=False)
@@ -23,27 +23,42 @@ class PyTestLocation:
2323
filepath: str = field(init=False)
2424

2525
def __post_init__(self) -> None:
26+
# NB. we're in a frozen dataclass, but need to transform the values that the caller
27+
# supplied... we do so by (ab)using object.__setattr__ to forcibly set the attributes. (See
28+
# rejected PEP-0712 for an example of a better way to handle this.)
29+
#
30+
# This is safe because this all happens during initialization: `self` hasn't been hashed
31+
# (or, e.g., stored in a dict), so the mutation won't be noticed.
2632
if self.is_doctest:
2733
return self.__attrs_post_init_doc__()
2834
self.__attrs_post_init_def__()
2935

3036
def __attrs_post_init_def__(self) -> None:
3137
node_path: Path = getattr(self.item, "path") # noqa: B009
32-
self.filepath = str(node_path.absolute())
38+
# See __post_init__ for discussion of object.__setattr__
39+
object.__setattr__(self, "filepath", str(node_path.absolute()))
3340
obj = getattr(self.item, "obj") # noqa: B009
34-
self.modulename = obj.__module__
35-
self.methodname = obj.__name__
36-
self.nodename = getattr(self.item, "name", None)
37-
self.testname = self.nodename or self.methodname
41+
object.__setattr__(self, "modulename", obj.__module__)
42+
object.__setattr__(self, "methodname", obj.__name__)
43+
object.__setattr__(self, "nodename", getattr(self.item, "name", None))
44+
object.__setattr__(self, "testname", self.nodename or self.methodname)
3845

3946
def __attrs_post_init_doc__(self) -> None:
4047
doctest = getattr(self.item, "dtest") # noqa: B009
41-
self.filepath = doctest.filename
48+
# See __post_init__ for discussion of object.__setattr__
49+
object.__setattr__(self, "filepath", doctest.filename)
4250
test_relfile, test_node = self.nodeid.split(PYTEST_NODE_SEP)
4351
test_relpath = Path(test_relfile)
44-
self.modulename = ".".join([*test_relpath.parent.parts, test_relpath.stem])
45-
self.nodename = test_node.replace(f"{self.modulename}.", "")
46-
self.testname = self.nodename or self.methodname
52+
object.__setattr__(
53+
self,
54+
"modulename",
55+
".".join([*test_relpath.parent.parts, test_relpath.stem]),
56+
)
57+
object.__setattr__(self, "methodname", None)
58+
object.__setattr__(
59+
self, "nodename", test_node.replace(f"{self.modulename}.", "")
60+
)
61+
object.__setattr__(self, "testname", self.nodename or self.methodname)
4762

4863
@property
4964
def classname(self) -> Optional[str]:

src/syrupy/session.py

+55-12
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ class ItemStatus(Enum):
4646
SKIPPED = "skipped"
4747

4848

49+
_QueuedWriteExtensionKey = Tuple[Type["AbstractSyrupyExtension"], str]
50+
_QueuedWriteTestLocationKey = Tuple["PyTestLocation", "SnapshotIndex"]
51+
52+
4953
@dataclass
5054
class SnapshotSession:
5155
pytest_session: "pytest.Session"
@@ -62,10 +66,28 @@ class SnapshotSession:
6266
default_factory=lambda: defaultdict(set)
6367
)
6468

65-
_queued_snapshot_writes: Dict[
66-
Tuple[Type["AbstractSyrupyExtension"], str],
67-
List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
68-
] = field(default_factory=dict)
69+
# For performance, we buffer snapshot writes in memory before flushing them to disk. In
70+
# particular, we want to be able to write to a file on disk only once, rather than having to
71+
# repeatedly rewrite it.
72+
#
73+
# That batching leads to using two layers of dicts here: the outer layer represents the
74+
# extension/file-location pair that will be written, and the inner layer represents the
75+
# snapshots within that, "indexed" to allow efficient recall.
76+
_queued_snapshot_writes: DefaultDict[
77+
_QueuedWriteExtensionKey,
78+
Dict[_QueuedWriteTestLocationKey, "SerializedData"],
79+
] = field(default_factory=lambda: defaultdict(dict))
80+
81+
def _snapshot_write_queue_keys(
82+
self,
83+
extension: "AbstractSyrupyExtension",
84+
test_location: "PyTestLocation",
85+
index: "SnapshotIndex",
86+
) -> Tuple[_QueuedWriteExtensionKey, _QueuedWriteTestLocationKey]:
87+
snapshot_location = extension.get_location(
88+
test_location=test_location, index=index
89+
)
90+
return (extension.__class__, snapshot_location), (test_location, index)
6991

7092
def queue_snapshot_write(
7193
self,
@@ -74,13 +96,10 @@ def queue_snapshot_write(
7496
data: "SerializedData",
7597
index: "SnapshotIndex",
7698
) -> None:
77-
snapshot_location = extension.get_location(
78-
test_location=test_location, index=index
99+
ext_key, loc_key = self._snapshot_write_queue_keys(
100+
extension, test_location, index
79101
)
80-
key = (extension.__class__, snapshot_location)
81-
queue = self._queued_snapshot_writes.get(key, [])
82-
queue.append((data, test_location, index))
83-
self._queued_snapshot_writes[key] = queue
102+
self._queued_snapshot_writes[ext_key][loc_key] = data
84103

85104
def flush_snapshot_write_queue(self) -> None:
86105
for (
@@ -89,9 +108,33 @@ def flush_snapshot_write_queue(self) -> None:
89108
), queued_write in self._queued_snapshot_writes.items():
90109
if queued_write:
91110
extension_class.write_snapshot(
92-
snapshot_location=snapshot_location, snapshots=queued_write
111+
snapshot_location=snapshot_location,
112+
snapshots=[
113+
(data, loc, index)
114+
for (loc, index), data in queued_write.items()
115+
],
93116
)
94-
self._queued_snapshot_writes = {}
117+
self._queued_snapshot_writes.clear()
118+
119+
def recall_snapshot(
120+
self,
121+
extension: "AbstractSyrupyExtension",
122+
test_location: "PyTestLocation",
123+
index: "SnapshotIndex",
124+
) -> Optional["SerializedData"]:
125+
"""Find the current value of the snapshot, for this session, either a pending write or the actual snapshot."""
126+
127+
ext_key, loc_key = self._snapshot_write_queue_keys(
128+
extension, test_location, index
129+
)
130+
data = self._queued_snapshot_writes[ext_key].get(loc_key)
131+
if data is not None:
132+
return data
133+
134+
# No matching write queued, so just read the snapshot directly:
135+
return extension.read_snapshot(
136+
test_location=test_location, index=index, session_id=str(id(self))
137+
)
95138

96139
@property
97140
def update_snapshots(self) -> bool:
+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
3+
_TEST = """
4+
def test_foo(snapshot):
5+
assert {**base} == snapshot(name="a")
6+
assert {**base, **extra} == snapshot(name="b", diff="a")
7+
"""
8+
9+
10+
def _make_file(testdir, base, extra):
11+
testdir.makepyfile(
12+
test_file="\n\n".join([f"base = {base!r}", f"extra = {extra!r}", _TEST])
13+
)
14+
15+
16+
def _run_test(testdir, base, extra, expected_update_lines):
17+
_make_file(testdir, base=base, extra=extra)
18+
19+
# Run with --snapshot-update, to generate/update snapshots:
20+
result = testdir.runpytest(
21+
"-v",
22+
"--snapshot-update",
23+
)
24+
result.stdout.re_match_lines((expected_update_lines,))
25+
assert result.ret == 0
26+
27+
# Run without --snapshot-update, to validate the snapshots are actually up-to-date
28+
result = testdir.runpytest("-v")
29+
result.stdout.re_match_lines((r"2 snapshots passed\.",))
30+
assert result.ret == 0
31+
32+
33+
def test_diff_lifecycle(testdir) -> pytest.Testdir:
34+
# first: create both snapshots completely from scratch
35+
_run_test(
36+
testdir,
37+
base={"A": 1},
38+
extra={"X": 10},
39+
expected_update_lines=r"2 snapshots generated\.",
40+
)
41+
42+
# second: edit the base data, to change the data for both snapshots (only changes the serialized output for the base snapshot `a`).
43+
_run_test(
44+
testdir,
45+
base={"A": 1, "B": 2},
46+
extra={"X": 10},
47+
expected_update_lines=r"1 snapshot passed. 1 snapshot updated\.",
48+
)
49+
50+
# third: edit just the extra data (only changes the serialized output for the diff snapshot `b`)
51+
_run_test(
52+
testdir,
53+
base={"A": 1, "B": 2},
54+
extra={"X": 10, "Y": 20},
55+
expected_update_lines=r"1 snapshot passed. 1 snapshot updated\.",
56+
)

0 commit comments

Comments
 (0)