Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions skore-local-project/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ convention = "numpy"

[tool.mypy]
exclude = ["hatch/", "tests/"]
strict = true

[[tool.mypy.overrides]]
follow_untyped_imports = true
Expand Down
21 changes: 13 additions & 8 deletions skore-local-project/src/skore_local_project/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,26 @@
from joblib import hash

if TYPE_CHECKING:
from typing import Any
from collections.abc import Generator
from typing import Any, Literal

from skore import CrossValidationReport, EstimatorReport


def cast_to_float(value: Any) -> float | None:
"""Cast value to float."""
with suppress(TypeError):
if isfinite(value := float(value)):
return value
float_value = float(value)

if isfinite(float_value):
return float_value

return None


def report_type(report: EstimatorReport | CrossValidationReport):
def report_type(
report: EstimatorReport | CrossValidationReport,
) -> Literal["cross-validation", "estimator"]:
"""Human readable type of a report."""
from skore import CrossValidationReport, EstimatorReport

Expand Down Expand Up @@ -76,12 +81,12 @@ class ReportMetadata(ABC):
report_type: str = field(init=False)
dataset: str = field(init=False)

def __iter__(self):
def __iter__(self) -> Generator[tuple[str, str], None, None]:
"""Iterate over the metadata."""
for field in fields(self): # noqa: F402
yield (field.name, getattr(self, field.name))

def __post_init__(self, report: EstimatorReport | CrossValidationReport):
def __post_init__(self, report: EstimatorReport | CrossValidationReport) -> None:
"""Initialize dynamic fields."""
self.date = datetime.now(timezone.utc).isoformat()
self.learner = report.estimator_name_
Expand All @@ -106,7 +111,7 @@ def metric(report: EstimatorReport, name: str) -> float | None:

return cast_to_float(getattr(report.metrics, name)(data_source="test"))

def __post_init__(self, report: EstimatorReport): # type: ignore[override]
def __post_init__(self, report: EstimatorReport) -> None: # type: ignore[override]
"""Initialize dynamic fields."""
super().__post_init__(report)

Expand Down Expand Up @@ -152,7 +157,7 @@ def timing(report: CrossValidationReport, label: str) -> float | None:

return cast_to_float(series.iloc[0])

def __post_init__(self, report: CrossValidationReport): # type: ignore[override]
def __post_init__(self, report: CrossValidationReport) -> None: # type: ignore[override]
"""Initialize dynamic fields."""
super().__post_init__(report)

Expand Down
34 changes: 25 additions & 9 deletions skore-local-project/src/skore_local_project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ParamSpec, Protocol, TypeVar, cast, runtime_checkable
from uuid import uuid4

import joblib
Expand All @@ -15,7 +15,12 @@
from .metadata import CrossValidationReportMetadata, EstimatorReportMetadata
from .storage import DiskCacheStorage

P = ParamSpec("P")
R = TypeVar("R")


if TYPE_CHECKING:
from collections.abc import Callable
from typing import TypedDict

from skore import CrossValidationReport, EstimatorReport
Expand All @@ -40,17 +45,26 @@ class Metadata(TypedDict): # noqa: D101
predict_time_mean: float | None


def ensure_project_is_not_deleted(method):
def ensure_project_is_not_deleted(method: Callable[P, R]) -> Callable[P, R]:
"""Ensure project is not deleted, before executing any other operation."""

@runtime_checkable
class Project(Protocol):
name: str
_Project__projects_storage: DiskCacheStorage

@wraps(method)
def wrapper(self, *args, **kwargs):
if self.name not in self._Project__projects_storage:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
project = args[0]

assert isinstance(project, Project), "You can only wrap `Project` methods"

if project.name not in project._Project__projects_storage:
raise RuntimeError(
f"Skore could not proceed because {self!r} does not exist anymore."
f"Skore could not proceed because {project!r} does not exist anymore."
)

return method(self, *args, **kwargs)
return method(*args, **kwargs)

return wrapper

Expand Down Expand Up @@ -183,7 +197,7 @@ def pickle(report: EstimatorReport | CrossValidationReport) -> tuple[str, bytes]
return pickle_hash, pickle_bytes

@ensure_project_is_not_deleted
def put(self, key: str, report: EstimatorReport | CrossValidationReport):
def put(self, key: str, report: EstimatorReport | CrossValidationReport) -> None:
"""
Put a key-report pair to the local project.

Expand Down Expand Up @@ -238,7 +252,9 @@ def get(self, id: str) -> EstimatorReport | CrossValidationReport:
"""Get a persisted report by its id."""
if id in self.__artifacts_storage:
with io.BytesIO(self.__artifacts_storage[id]) as stream:
return joblib.load(stream)
return cast(
"EstimatorReport | CrossValidationReport", joblib.load(stream)
)

raise KeyError(id)

Expand Down Expand Up @@ -275,7 +291,7 @@ def __repr__(self) -> str: # noqa: D105
)

@staticmethod
def delete(name: str, *, workspace: Path | None = None):
def delete(name: str, *, workspace: Path | None = None) -> None:
r"""
Delete a local project.

Expand Down
4 changes: 2 additions & 2 deletions skore-local-project/src/skore_local_project/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __getitem__(self, key: str) -> Any:
with Cache(self.directory) as storage:
return storage[key]

def __setitem__(self, key: str, value: Any):
def __setitem__(self, key: str, value: Any) -> None:
"""
Set an item in the storage.

Expand All @@ -120,7 +120,7 @@ def __setitem__(self, key: str, value: Any):
with Cache(self.directory) as storage:
storage[key] = value

def __delitem__(self, key: str):
def __delitem__(self, key: str) -> None:
"""
Delete an item from the storage.

Expand Down
Loading