2424 BytesSerializer ,
2525 JsonPickleSerializer ,
2626 PandasCsvSerializer ,
27+ Serializer ,
2728 StringSerializer ,
2829)
2930from snappylapy .constants import DIRECTORY_NAMES
3031from snappylapy .session import SnapshotSession
31- from typing import Any , Protocol , overload
32+ from typing import Any , Protocol , TypeVar , overload
33+
34+ T = TypeVar ("T" )
3235
3336
3437class _CallableExpectation (Protocol ):
@@ -375,18 +378,33 @@ class LoadSnapshot:
375378 def __init__ (self , settings : Settings ) -> None :
376379 """Do not initialize the LoadSnapshot class directly, should be used through the `load_snapshot` fixture in pytest.""" # noqa: E501
377380 self .settings = settings
381+ self ._current_dependency_index = 0
378382
379383 def _read_snapshot (self ) -> bytes :
380384 """Read the snapshot file."""
381- if not self .settings .depending_snapshots_base_dir :
385+ if self ._current_dependency_index >= len (self .settings .depending_tests ):
386+ msg = (
387+ f"Attempted to load more dependencies ({ self ._current_dependency_index + 1 } ) "
388+ f"than available ({ len (self .settings .depending_tests )} ). "
389+ "Check your test's dependency configuration."
390+ )
391+ raise IndexError (msg )
392+ if not self .settings .depending_tests [self ._current_dependency_index ].snapshots_base_dir :
382393 msg = "Depending snapshots base directory is not set."
383394 raise ValueError (msg )
384395 return (
385- self .settings .depending_snapshots_base_dir
396+ self .settings .depending_tests [ self . _current_dependency_index ]. snapshots_base_dir
386397 / DIRECTORY_NAMES .snapshot_dir_name
387- / self .settings .depending_filename
398+ / self .settings .depending_tests [ self . _current_dependency_index ]. filename
388399 ).read_bytes ()
389400
401+ def _load_and_deserialize (self , filename_extension : str , deserializer : Serializer [T ]) -> T :
402+ """Set filename extension, read, deserialize, and increment dependency index."""
403+ self .settings .depending_tests [self ._current_dependency_index ].filename_extension = filename_extension
404+ deserialized_data = deserializer .deserialize (self ._read_snapshot ())
405+ self ._current_dependency_index += 1
406+ return deserialized_data
407+
390408 def dict (self ) -> dict [Any , Any ]:
391409 """
392410 Load dictionary snapshot.
@@ -415,8 +433,10 @@ def test_load_snapshot_dict(load_snapshot: LoadSnapshot) -> None:
415433 assert data["bananas"] == 5
416434 ```
417435 """
418- self .settings .depending_filename_extension = "dict.json"
419- return JsonPickleSerializer [dict ]().deserialize (self ._read_snapshot ())
436+ return self ._load_and_deserialize (
437+ "dict.json" ,
438+ JsonPickleSerializer [dict ](),
439+ )
420440
421441 def list (self ) -> list [Any ]:
422442 """
@@ -451,8 +471,10 @@ def test_next_transformation(load_snapshot: LoadSnapshot, expect: Expect) -> Non
451471 expect(result).to_match_snapshot()
452472 ```
453473 """
454- self .settings .depending_filename_extension = "list.json"
455- return JsonPickleSerializer [list [Any ]]().deserialize (self ._read_snapshot ())
474+ return self ._load_and_deserialize (
475+ "list.json" ,
476+ JsonPickleSerializer [list [Any ]](),
477+ )
456478
457479 def string (self ) -> str :
458480 """
@@ -478,8 +500,10 @@ def test_load_snapshot_string(load_snapshot: LoadSnapshot) -> None:
478500 assert data == "Hello, pytest!"
479501 ```
480502 """
481- self .settings .depending_filename_extension = "string.txt"
482- return StringSerializer ().deserialize (self ._read_snapshot ())
503+ return self ._load_and_deserialize (
504+ "string.txt" ,
505+ StringSerializer (),
506+ )
483507
484508 def bytes (self ) -> bytes :
485509 r"""
@@ -505,8 +529,10 @@ def test_load_snapshot_bytes(load_snapshot: LoadSnapshot) -> None:
505529 assert data == b"\x01\x02\x03"
506530 ```
507531 """
508- self .settings .depending_filename_extension = "bytes.txt"
509- return BytesSerializer ().deserialize (self ._read_snapshot ())
532+ return self ._load_and_deserialize (
533+ "bytes.txt" ,
534+ BytesSerializer (),
535+ )
510536
511537 def dataframe (self ) -> DataframeExpect .DataFrame :
512538 """
@@ -533,8 +559,10 @@ def test_load_snapshot_dataframe(load_snapshot: LoadSnapshot) -> None:
533559 assert df["numbers"].sum() == 6
534560 ```
535561 """
536- self .settings .depending_filename_extension = "dataframe.csv"
537- return PandasCsvSerializer ().deserialize (self ._read_snapshot ())
562+ return self ._load_and_deserialize (
563+ "dataframe.csv" ,
564+ PandasCsvSerializer (),
565+ )
538566
539567 def object (self ) -> object :
540568 """
@@ -565,5 +593,7 @@ def test_load_snapshot_object(load_snapshot: LoadSnapshot) -> None:
565593 assert obj.value == 42
566594 ```
567595 """
568- self .settings .depending_filename_extension = "object.json"
569- return JsonPickleSerializer [object ]().deserialize (self ._read_snapshot ())
596+ return self ._load_and_deserialize (
597+ "object.json" ,
598+ JsonPickleSerializer [object ](),
599+ )
0 commit comments