Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ Icons:

## [0.9.0] - 2025-10-23
- 🆕 Added class method for handling objects in LoadSnapshot fixture
- 🆕 Added support for multiple snapshots loaded in the same LoadSnapshot fixture, if multiple functions are added to the depends.
- 🐞 Fix snappylapy logo not showing correctly in documentation
- 🐞 Fix issue when snapshots is loaded from another folder and the path is added to the depending file name

## [0.8.0] - 2025-10-18
- 🆕 Used integration with toolit to make snappylapy commands available for AI coding assistants
Expand Down
30 changes: 23 additions & 7 deletions snappylapy/_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,20 @@
from snappylapy.constants import DEFAULT_SNAPSHOT_BASE_DIR
from snappylapy.exceptions import TestDirectoryNotParametrizedError
from snappylapy.fixtures import Settings
from snappylapy.models import DependingSettings
from snappylapy.session import SnapshotSession
from typing import Any


def _extract_module_name(module_path: str) -> str:
"""
Extract the module name from a dotted module path, returning only the last component.

This is used to strip package paths and keep only the module's filename for snapshot tracking.
"""
return module_path.split(".", maxsplit=1)[-1]


def _get_kwargs_from_depend_function(
depends_function: Callable,
marker_name: str,
Expand Down Expand Up @@ -70,18 +80,24 @@ def snappylapy_settings(request: pytest.FixtureRequest) -> Settings:
# TODO: Add a better error message
msg = "Path output directory cannot be None"
raise ValueError(msg)
settings.depending_snapshots_base_dir = pathlib.Path(path_output_dir)
# settings.depending_snapshots_base_dir = pathlib.Path(path_output_dir)
settings.snapshots_base_dir = pathlib.Path(path_output_dir)
settings.custom_name = path_output_dir.name
# If not parametrized, get the depends from the marker
depends: list = marker.kwargs.get("depends", []) if marker else []
if depends:
input_dir_from_depends = _get_kwargs_from_depend_function(depends[0], "snappylapy", "output_dir")
if input_dir_from_depends:
path_output_dir = pathlib.Path(input_dir_from_depends)
settings.depending_test_filename = depends[0].__module__
settings.depending_test_function = depends[0].__name__
settings.depending_snapshots_base_dir = path_output_dir or DEFAULT_SNAPSHOT_BASE_DIR
for depend in depends:
input_dir_from_depends = _get_kwargs_from_depend_function(depend, "snappylapy", "output_dir")
if input_dir_from_depends:
path_output_dir = pathlib.Path(input_dir_from_depends)
dependency_setting = DependingSettings(
test_filename=_extract_module_name(depend.__module__),
test_function=depend.__name__,
snapshots_base_dir=path_output_dir or DEFAULT_SNAPSHOT_BASE_DIR,
custom_name=settings.custom_name,
)
settings.depending_tests.append(dependency_setting)

return settings


Expand Down
62 changes: 46 additions & 16 deletions snappylapy/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@
BytesSerializer,
JsonPickleSerializer,
PandasCsvSerializer,
Serializer,
StringSerializer,
)
from snappylapy.constants import DIRECTORY_NAMES
from snappylapy.session import SnapshotSession
from typing import Any, Protocol, overload
from typing import Any, Protocol, TypeVar, overload

T = TypeVar("T")


class _CallableExpectation(Protocol):
Expand Down Expand Up @@ -375,18 +378,33 @@ class LoadSnapshot:
def __init__(self, settings: Settings) -> None:
"""Do not initialize the LoadSnapshot class directly, should be used through the `load_snapshot` fixture in pytest.""" # noqa: E501
self.settings = settings
self._current_dependency_index = 0

def _read_snapshot(self) -> bytes:
"""Read the snapshot file."""
if not self.settings.depending_snapshots_base_dir:
if self._current_dependency_index >= len(self.settings.depending_tests):
msg = (
f"Attempted to load more dependencies ({self._current_dependency_index + 1}) "
f"than available ({len(self.settings.depending_tests)}). "
"Check your test's dependency configuration."
)
raise IndexError(msg)
if not self.settings.depending_tests[self._current_dependency_index].snapshots_base_dir:
msg = "Depending snapshots base directory is not set."
raise ValueError(msg)
return (
self.settings.depending_snapshots_base_dir
self.settings.depending_tests[self._current_dependency_index].snapshots_base_dir
/ DIRECTORY_NAMES.snapshot_dir_name
/ self.settings.depending_filename
/ self.settings.depending_tests[self._current_dependency_index].filename
).read_bytes()

def _load_and_deserialize(self, filename_extension: str, deserializer: Serializer[T]) -> T:
"""Set filename extension, read, deserialize, and increment dependency index."""
self.settings.depending_tests[self._current_dependency_index].filename_extension = filename_extension
deserialized_data = deserializer.deserialize(self._read_snapshot())
self._current_dependency_index += 1
return deserialized_data

def dict(self) -> dict[Any, Any]:
"""
Load dictionary snapshot.
Expand Down Expand Up @@ -415,8 +433,10 @@ def test_load_snapshot_dict(load_snapshot: LoadSnapshot) -> None:
assert data["bananas"] == 5
```
"""
self.settings.depending_filename_extension = "dict.json"
return JsonPickleSerializer[dict]().deserialize(self._read_snapshot())
return self._load_and_deserialize(
"dict.json",
JsonPickleSerializer[dict](),
)

def list(self) -> list[Any]:
"""
Expand Down Expand Up @@ -451,8 +471,10 @@ def test_next_transformation(load_snapshot: LoadSnapshot, expect: Expect) -> Non
expect(result).to_match_snapshot()
```
"""
self.settings.depending_filename_extension = "list.json"
return JsonPickleSerializer[list[Any]]().deserialize(self._read_snapshot())
return self._load_and_deserialize(
"list.json",
JsonPickleSerializer[list[Any]](),
)

def string(self) -> str:
"""
Expand All @@ -478,8 +500,10 @@ def test_load_snapshot_string(load_snapshot: LoadSnapshot) -> None:
assert data == "Hello, pytest!"
```
"""
self.settings.depending_filename_extension = "string.txt"
return StringSerializer().deserialize(self._read_snapshot())
return self._load_and_deserialize(
"string.txt",
StringSerializer(),
)

def bytes(self) -> bytes:
r"""
Expand All @@ -505,8 +529,10 @@ def test_load_snapshot_bytes(load_snapshot: LoadSnapshot) -> None:
assert data == b"\x01\x02\x03"
```
"""
self.settings.depending_filename_extension = "bytes.txt"
return BytesSerializer().deserialize(self._read_snapshot())
return self._load_and_deserialize(
"bytes.txt",
BytesSerializer(),
)

def dataframe(self) -> DataframeExpect.DataFrame:
"""
Expand All @@ -533,8 +559,10 @@ def test_load_snapshot_dataframe(load_snapshot: LoadSnapshot) -> None:
assert df["numbers"].sum() == 6
```
"""
self.settings.depending_filename_extension = "dataframe.csv"
return PandasCsvSerializer().deserialize(self._read_snapshot())
return self._load_and_deserialize(
"dataframe.csv",
PandasCsvSerializer(),
)

def object(self) -> object:
"""
Expand Down Expand Up @@ -565,5 +593,7 @@ def test_load_snapshot_object(load_snapshot: LoadSnapshot) -> None:
assert obj.value == 42
```
"""
self.settings.depending_filename_extension = "object.json"
return JsonPickleSerializer[object]().deserialize(self._read_snapshot())
return self._load_and_deserialize(
"object.json",
JsonPickleSerializer[object](),
)
61 changes: 36 additions & 25 deletions snappylapy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,40 @@
from __future__ import annotations

import pathlib
from dataclasses import dataclass
from dataclasses import dataclass, field
from snappylapy.constants import DIRECTORY_NAMES


@dataclass
class DependingSettings:
"""Settings for depending on other snapshots. Used for loading snapshots."""

test_filename: str
"""Filename of the test module where the depending test are defined."""

test_function: str
"""Name of the depending test function."""

snapshots_base_dir: pathlib.Path
"""Input base directory for loading snapshots."""

filename_extension: str | None = None
"""Extension of the depending snapshot file."""

custom_name: str | None = None
"""Custom name for the depending snapshot file."""

@property
def filename(self) -> str:
"""Get the depending snapshot filename."""
if not self.filename_extension:
msg = "Missing depending snapshot filename extension."
raise ValueError(msg)
if self.custom_name is not None:
return f"[{self.test_filename}][{self.test_function}][{self.custom_name}].{self.filename_extension}"
return f"[{self.test_filename}][{self.test_function}].{self.filename_extension}"


@dataclass
class Settings:
"""Shared setting for all the strategies for doing snapshot testing."""
Expand All @@ -30,17 +60,12 @@ class Settings:
"""Extension for the output of snapshot file."""

# Configurations for depending
depending_test_filename: str | None = None
"""Filename of the test module where the depending test are defined. Used for loading."""

depending_test_function: str | None = None
"""Name of the depending test function. Used for loading."""
depending_tests: list[DependingSettings] = field(default_factory=list)
"""
Depending tests are used for loading snapshots from other tests.
depending_filename_extension: str | None = None
"""Extension of the depending snapshot file. Used for loading."""

depending_snapshots_base_dir: pathlib.Path | None = None
"""Input base directory for loading snapshots."""
Information about each test the users have specified in a test decorator will be stored here.
"""

@property
def snapshot_dir(self) -> pathlib.Path:
Expand All @@ -58,17 +83,3 @@ def filename(self) -> str:
if self.custom_name is not None:
return f"[{self.test_filename}][{self.test_function}][{self.custom_name}].{self.filename_extension}"
return f"[{self.test_filename}][{self.test_function}].{self.filename_extension}"

@property
def depending_filename(self) -> str:
"""Get the depending snapshot filename."""
if (
not self.depending_test_filename
or not self.depending_test_function
or not self.depending_filename_extension
):
msg = "Missing depending test filename, function or extension."
raise ValueError(msg)
if self.custom_name is not None:
return f"[{self.depending_test_filename}][{self.depending_test_function}][{self.custom_name}].{self.depending_filename_extension}" # noqa: E501
return f"[{self.depending_test_filename}][{self.depending_test_function}].{self.depending_filename_extension}"