@@ -46,6 +46,10 @@ class ItemStatus(Enum):
46
46
SKIPPED = "skipped"
47
47
48
48
49
+ _QueuedWriteExtensionKey = Tuple [Type ["AbstractSyrupyExtension" ], str ]
50
+ _QueuedWriteTestLocationKey = Tuple ["PyTestLocation" , "SnapshotIndex" ]
51
+
52
+
49
53
@dataclass
50
54
class SnapshotSession :
51
55
pytest_session : "pytest.Session"
@@ -62,10 +66,28 @@ class SnapshotSession:
62
66
default_factory = lambda : defaultdict (set )
63
67
)
64
68
65
- _queued_snapshot_writes : Dict [
66
- Tuple [Type ["AbstractSyrupyExtension" ], str ],
67
- List [Tuple ["SerializedData" , "PyTestLocation" , "SnapshotIndex" ]],
68
- ] = field (default_factory = dict )
69
+ # For performance, we buffer snapshot writes in memory before flushing them to disk. In
70
+ # particular, we want to be able to write to a file on disk only once, rather than having to
71
+ # repeatedly rewrite it.
72
+ #
73
+ # That batching leads to using two layers of dicts here: the outer layer represents the
74
+ # extension/file-location pair that will be written, and the inner layer represents the
75
+ # snapshots within that, "indexed" to allow efficient recall.
76
+ _queued_snapshot_writes : DefaultDict [
77
+ _QueuedWriteExtensionKey ,
78
+ Dict [_QueuedWriteTestLocationKey , "SerializedData" ],
79
+ ] = field (default_factory = lambda : defaultdict (dict ))
80
+
81
+ def _snapshot_write_queue_keys (
82
+ self ,
83
+ extension : "AbstractSyrupyExtension" ,
84
+ test_location : "PyTestLocation" ,
85
+ index : "SnapshotIndex" ,
86
+ ) -> Tuple [_QueuedWriteExtensionKey , _QueuedWriteTestLocationKey ]:
87
+ snapshot_location = extension .get_location (
88
+ test_location = test_location , index = index
89
+ )
90
+ return (extension .__class__ , snapshot_location ), (test_location , index )
69
91
70
92
def queue_snapshot_write (
71
93
self ,
@@ -74,13 +96,10 @@ def queue_snapshot_write(
74
96
data : "SerializedData" ,
75
97
index : "SnapshotIndex" ,
76
98
) -> None :
77
- snapshot_location = extension . get_location (
78
- test_location = test_location , index = index
99
+ ext_key , loc_key = self . _snapshot_write_queue_keys (
100
+ extension , test_location , index
79
101
)
80
- key = (extension .__class__ , snapshot_location )
81
- queue = self ._queued_snapshot_writes .get (key , [])
82
- queue .append ((data , test_location , index ))
83
- self ._queued_snapshot_writes [key ] = queue
102
+ self ._queued_snapshot_writes [ext_key ][loc_key ] = data
84
103
85
104
def flush_snapshot_write_queue (self ) -> None :
86
105
for (
@@ -89,9 +108,33 @@ def flush_snapshot_write_queue(self) -> None:
89
108
), queued_write in self ._queued_snapshot_writes .items ():
90
109
if queued_write :
91
110
extension_class .write_snapshot (
92
- snapshot_location = snapshot_location , snapshots = queued_write
111
+ snapshot_location = snapshot_location ,
112
+ snapshots = [
113
+ (data , loc , index )
114
+ for (loc , index ), data in queued_write .items ()
115
+ ],
93
116
)
94
- self ._queued_snapshot_writes = {}
117
+ self ._queued_snapshot_writes .clear ()
118
+
119
+ def recall_snapshot (
120
+ self ,
121
+ extension : "AbstractSyrupyExtension" ,
122
+ test_location : "PyTestLocation" ,
123
+ index : "SnapshotIndex" ,
124
+ ) -> Optional ["SerializedData" ]:
125
+ """Find the current value of the snapshot, for this session, either a pending write or the actual snapshot."""
126
+
127
+ ext_key , loc_key = self ._snapshot_write_queue_keys (
128
+ extension , test_location , index
129
+ )
130
+ data = self ._queued_snapshot_writes [ext_key ].get (loc_key )
131
+ if data is not None :
132
+ return data
133
+
134
+ # No matching write queued, so just read the snapshot directly:
135
+ return extension .read_snapshot (
136
+ test_location = test_location , index = index , session_id = str (id (self ))
137
+ )
95
138
96
139
@property
97
140
def update_snapshots (self ) -> bool :
0 commit comments