Skip to content

Commit b12e9ce

Browse files
Support loading multiple dependencies in LoadSnapshot fixture (#14)
Enhance the LoadSnapshot fixture to accommodate multiple dependencies, allowing for more flexible snapshot loading. Refactor code for clarity and fix linting issues. Update the changelog to reflect these changes.
2 parents 981edc0 + 45368d3 commit b12e9ce

File tree

4 files changed

+107
-48
lines changed

4 files changed

+107
-48
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ Icons:
1212

1313
## [0.9.0] - 2025-10-23
1414
- 🆕 Added class method for handling objects in LoadSnapshot fixture
15+
- 🆕 Added support for multiple snapshots loaded in the same LoadSnapshot fixture, if multiple functions are added to the depends.
1516
- 🐞 Fix snappylapy logo not showing correctly in documentation
17+
- 🐞 Fix issue when snapshots is loaded from another folder and the path is added to the depending file name
1618

1719
## [0.8.0] - 2025-10-18
1820
- 🆕 Used integration with toolit to make snappylapy commands available for AI coding assistants

snappylapy/_plugin.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,20 @@
1313
from snappylapy.constants import DEFAULT_SNAPSHOT_BASE_DIR
1414
from snappylapy.exceptions import TestDirectoryNotParametrizedError
1515
from snappylapy.fixtures import Settings
16+
from snappylapy.models import DependingSettings
1617
from snappylapy.session import SnapshotSession
1718
from typing import Any
1819

1920

21+
def _extract_module_name(module_path: str) -> str:
22+
"""
23+
Extract the module name from a dotted module path, returning only the last component.
24+
25+
This is used to strip package paths and keep only the module's filename for snapshot tracking.
26+
"""
27+
return module_path.split(".", maxsplit=1)[-1]
28+
29+
2030
def _get_kwargs_from_depend_function(
2131
depends_function: Callable,
2232
marker_name: str,
@@ -70,18 +80,24 @@ def snappylapy_settings(request: pytest.FixtureRequest) -> Settings:
7080
# TODO: Add a better error message
7181
msg = "Path output directory cannot be None"
7282
raise ValueError(msg)
73-
settings.depending_snapshots_base_dir = pathlib.Path(path_output_dir)
83+
# settings.depending_snapshots_base_dir = pathlib.Path(path_output_dir)
7484
settings.snapshots_base_dir = pathlib.Path(path_output_dir)
7585
settings.custom_name = path_output_dir.name
7686
# If not parametrized, get the depends from the marker
7787
depends: list = marker.kwargs.get("depends", []) if marker else []
7888
if depends:
79-
input_dir_from_depends = _get_kwargs_from_depend_function(depends[0], "snappylapy", "output_dir")
80-
if input_dir_from_depends:
81-
path_output_dir = pathlib.Path(input_dir_from_depends)
82-
settings.depending_test_filename = depends[0].__module__
83-
settings.depending_test_function = depends[0].__name__
84-
settings.depending_snapshots_base_dir = path_output_dir or DEFAULT_SNAPSHOT_BASE_DIR
89+
for depend in depends:
90+
input_dir_from_depends = _get_kwargs_from_depend_function(depend, "snappylapy", "output_dir")
91+
if input_dir_from_depends:
92+
path_output_dir = pathlib.Path(input_dir_from_depends)
93+
dependency_setting = DependingSettings(
94+
test_filename=_extract_module_name(depend.__module__),
95+
test_function=depend.__name__,
96+
snapshots_base_dir=path_output_dir or DEFAULT_SNAPSHOT_BASE_DIR,
97+
custom_name=settings.custom_name,
98+
)
99+
settings.depending_tests.append(dependency_setting)
100+
85101
return settings
86102

87103

snappylapy/fixtures.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@
2424
BytesSerializer,
2525
JsonPickleSerializer,
2626
PandasCsvSerializer,
27+
Serializer,
2728
StringSerializer,
2829
)
2930
from snappylapy.constants import DIRECTORY_NAMES
3031
from snappylapy.session import SnapshotSession
31-
from typing import Any, Protocol, overload
32+
from typing import Any, Protocol, TypeVar, overload
33+
34+
T = TypeVar("T")
3235

3336

3437
class _CallableExpectation(Protocol):
@@ -375,18 +378,33 @@ class LoadSnapshot:
375378
def __init__(self, settings: Settings) -> None:
376379
"""Do not initialize the LoadSnapshot class directly, should be used through the `load_snapshot` fixture in pytest.""" # noqa: E501
377380
self.settings = settings
381+
self._current_dependency_index = 0
378382

379383
def _read_snapshot(self) -> bytes:
380384
"""Read the snapshot file."""
381-
if not self.settings.depending_snapshots_base_dir:
385+
if self._current_dependency_index >= len(self.settings.depending_tests):
386+
msg = (
387+
f"Attempted to load more dependencies ({self._current_dependency_index + 1}) "
388+
f"than available ({len(self.settings.depending_tests)}). "
389+
"Check your test's dependency configuration."
390+
)
391+
raise IndexError(msg)
392+
if not self.settings.depending_tests[self._current_dependency_index].snapshots_base_dir:
382393
msg = "Depending snapshots base directory is not set."
383394
raise ValueError(msg)
384395
return (
385-
self.settings.depending_snapshots_base_dir
396+
self.settings.depending_tests[self._current_dependency_index].snapshots_base_dir
386397
/ DIRECTORY_NAMES.snapshot_dir_name
387-
/ self.settings.depending_filename
398+
/ self.settings.depending_tests[self._current_dependency_index].filename
388399
).read_bytes()
389400

401+
def _load_and_deserialize(self, filename_extension: str, deserializer: Serializer[T]) -> T:
402+
"""Set filename extension, read, deserialize, and increment dependency index."""
403+
self.settings.depending_tests[self._current_dependency_index].filename_extension = filename_extension
404+
deserialized_data = deserializer.deserialize(self._read_snapshot())
405+
self._current_dependency_index += 1
406+
return deserialized_data
407+
390408
def dict(self) -> dict[Any, Any]:
391409
"""
392410
Load dictionary snapshot.
@@ -415,8 +433,10 @@ def test_load_snapshot_dict(load_snapshot: LoadSnapshot) -> None:
415433
assert data["bananas"] == 5
416434
```
417435
"""
418-
self.settings.depending_filename_extension = "dict.json"
419-
return JsonPickleSerializer[dict]().deserialize(self._read_snapshot())
436+
return self._load_and_deserialize(
437+
"dict.json",
438+
JsonPickleSerializer[dict](),
439+
)
420440

421441
def list(self) -> list[Any]:
422442
"""
@@ -451,8 +471,10 @@ def test_next_transformation(load_snapshot: LoadSnapshot, expect: Expect) -> Non
451471
expect(result).to_match_snapshot()
452472
```
453473
"""
454-
self.settings.depending_filename_extension = "list.json"
455-
return JsonPickleSerializer[list[Any]]().deserialize(self._read_snapshot())
474+
return self._load_and_deserialize(
475+
"list.json",
476+
JsonPickleSerializer[list[Any]](),
477+
)
456478

457479
def string(self) -> str:
458480
"""
@@ -478,8 +500,10 @@ def test_load_snapshot_string(load_snapshot: LoadSnapshot) -> None:
478500
assert data == "Hello, pytest!"
479501
```
480502
"""
481-
self.settings.depending_filename_extension = "string.txt"
482-
return StringSerializer().deserialize(self._read_snapshot())
503+
return self._load_and_deserialize(
504+
"string.txt",
505+
StringSerializer(),
506+
)
483507

484508
def bytes(self) -> bytes:
485509
r"""
@@ -505,8 +529,10 @@ def test_load_snapshot_bytes(load_snapshot: LoadSnapshot) -> None:
505529
assert data == b"\x01\x02\x03"
506530
```
507531
"""
508-
self.settings.depending_filename_extension = "bytes.txt"
509-
return BytesSerializer().deserialize(self._read_snapshot())
532+
return self._load_and_deserialize(
533+
"bytes.txt",
534+
BytesSerializer(),
535+
)
510536

511537
def dataframe(self) -> DataframeExpect.DataFrame:
512538
"""
@@ -533,8 +559,10 @@ def test_load_snapshot_dataframe(load_snapshot: LoadSnapshot) -> None:
533559
assert df["numbers"].sum() == 6
534560
```
535561
"""
536-
self.settings.depending_filename_extension = "dataframe.csv"
537-
return PandasCsvSerializer().deserialize(self._read_snapshot())
562+
return self._load_and_deserialize(
563+
"dataframe.csv",
564+
PandasCsvSerializer(),
565+
)
538566

539567
def object(self) -> object:
540568
"""
@@ -565,5 +593,7 @@ def test_load_snapshot_object(load_snapshot: LoadSnapshot) -> None:
565593
assert obj.value == 42
566594
```
567595
"""
568-
self.settings.depending_filename_extension = "object.json"
569-
return JsonPickleSerializer[object]().deserialize(self._read_snapshot())
596+
return self._load_and_deserialize(
597+
"object.json",
598+
JsonPickleSerializer[object](),
599+
)

snappylapy/models.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,40 @@
33
from __future__ import annotations
44

55
import pathlib
6-
from dataclasses import dataclass
6+
from dataclasses import dataclass, field
77
from snappylapy.constants import DIRECTORY_NAMES
88

99

10+
@dataclass
11+
class DependingSettings:
12+
"""Settings for depending on other snapshots. Used for loading snapshots."""
13+
14+
test_filename: str
15+
"""Filename of the test module where the depending test are defined."""
16+
17+
test_function: str
18+
"""Name of the depending test function."""
19+
20+
snapshots_base_dir: pathlib.Path
21+
"""Input base directory for loading snapshots."""
22+
23+
filename_extension: str | None = None
24+
"""Extension of the depending snapshot file."""
25+
26+
custom_name: str | None = None
27+
"""Custom name for the depending snapshot file."""
28+
29+
@property
30+
def filename(self) -> str:
31+
"""Get the depending snapshot filename."""
32+
if not self.filename_extension:
33+
msg = "Missing depending snapshot filename extension."
34+
raise ValueError(msg)
35+
if self.custom_name is not None:
36+
return f"[{self.test_filename}][{self.test_function}][{self.custom_name}].{self.filename_extension}"
37+
return f"[{self.test_filename}][{self.test_function}].{self.filename_extension}"
38+
39+
1040
@dataclass
1141
class Settings:
1242
"""Shared setting for all the strategies for doing snapshot testing."""
@@ -30,17 +60,12 @@ class Settings:
3060
"""Extension for the output of snapshot file."""
3161

3262
# Configurations for depending
33-
depending_test_filename: str | None = None
34-
"""Filename of the test module where the depending test are defined. Used for loading."""
35-
36-
depending_test_function: str | None = None
37-
"""Name of the depending test function. Used for loading."""
63+
depending_tests: list[DependingSettings] = field(default_factory=list)
64+
"""
65+
Depending tests are used for loading snapshots from other tests.
3866
39-
depending_filename_extension: str | None = None
40-
"""Extension of the depending snapshot file. Used for loading."""
41-
42-
depending_snapshots_base_dir: pathlib.Path | None = None
43-
"""Input base directory for loading snapshots."""
67+
Information about each test the users have specified in a test decorator will be stored here.
68+
"""
4469

4570
@property
4671
def snapshot_dir(self) -> pathlib.Path:
@@ -58,17 +83,3 @@ def filename(self) -> str:
5883
if self.custom_name is not None:
5984
return f"[{self.test_filename}][{self.test_function}][{self.custom_name}].{self.filename_extension}"
6085
return f"[{self.test_filename}][{self.test_function}].{self.filename_extension}"
61-
62-
@property
63-
def depending_filename(self) -> str:
64-
"""Get the depending snapshot filename."""
65-
if (
66-
not self.depending_test_filename
67-
or not self.depending_test_function
68-
or not self.depending_filename_extension
69-
):
70-
msg = "Missing depending test filename, function or extension."
71-
raise ValueError(msg)
72-
if self.custom_name is not None:
73-
return f"[{self.depending_test_filename}][{self.depending_test_function}][{self.custom_name}].{self.depending_filename_extension}" # noqa: E501
74-
return f"[{self.depending_test_filename}][{self.depending_test_function}].{self.depending_filename_extension}"

0 commit comments

Comments
 (0)