diff --git a/skore-local-project/pyproject.toml b/skore-local-project/pyproject.toml index 8b8b4f451..23a91124a 100644 --- a/skore-local-project/pyproject.toml +++ b/skore-local-project/pyproject.toml @@ -100,6 +100,7 @@ convention = "numpy" [tool.mypy] exclude = ["hatch/", "tests/"] +strict = true [[tool.mypy.overrides]] follow_untyped_imports = true diff --git a/skore-local-project/src/skore_local_project/metadata.py b/skore-local-project/src/skore_local_project/metadata.py index 9d64dd758..1dfcf376d 100644 --- a/skore-local-project/src/skore_local_project/metadata.py +++ b/skore-local-project/src/skore_local_project/metadata.py @@ -12,7 +12,8 @@ from joblib import hash if TYPE_CHECKING: - from typing import Any + from collections.abc import Generator + from typing import Any, Literal from skore import CrossValidationReport, EstimatorReport @@ -20,13 +21,17 @@ def cast_to_float(value: Any) -> float | None: """Cast value to float.""" with suppress(TypeError): - if isfinite(value := float(value)): - return value + float_value = float(value) + + if isfinite(float_value): + return float_value return None -def report_type(report: EstimatorReport | CrossValidationReport): +def report_type( + report: EstimatorReport | CrossValidationReport, +) -> Literal["cross-validation", "estimator"]: """Human readable type of a report.""" from skore import CrossValidationReport, EstimatorReport @@ -76,12 +81,12 @@ class ReportMetadata(ABC): report_type: str = field(init=False) dataset: str = field(init=False) - def __iter__(self): + def __iter__(self) -> Generator[tuple[str, str], None, None]: """Iterate over the metadata.""" for field in fields(self): # noqa: F402 yield (field.name, getattr(self, field.name)) - def __post_init__(self, report: EstimatorReport | CrossValidationReport): + def __post_init__(self, report: EstimatorReport | CrossValidationReport) -> None: """Initialize dynamic fields.""" self.date = datetime.now(timezone.utc).isoformat() self.learner = report.estimator_name_ @@ -106,7 +111,7 @@ def metric(report: EstimatorReport, name: str) -> float | None: return cast_to_float(getattr(report.metrics, name)(data_source="test")) - def __post_init__(self, report: EstimatorReport): # type: ignore[override] + def __post_init__(self, report: EstimatorReport) -> None: # type: ignore[override] """Initialize dynamic fields.""" super().__post_init__(report) @@ -152,7 +157,7 @@ def timing(report: CrossValidationReport, label: str) -> float | None: return cast_to_float(series.iloc[0]) - def __post_init__(self, report: CrossValidationReport): # type: ignore[override] + def __post_init__(self, report: CrossValidationReport) -> None: # type: ignore[override] """Initialize dynamic fields.""" super().__post_init__(report) diff --git a/skore-local-project/src/skore_local_project/project.py b/skore-local-project/src/skore_local_project/project.py index fdd726bfa..5aad0a719 100644 --- a/skore-local-project/src/skore_local_project/project.py +++ b/skore-local-project/src/skore_local_project/project.py @@ -6,7 +6,7 @@ import os from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ParamSpec, Protocol, TypeVar, cast, runtime_checkable from uuid import uuid4 import joblib @@ -15,7 +15,12 @@ from .metadata import CrossValidationReportMetadata, EstimatorReportMetadata from .storage import DiskCacheStorage +P = ParamSpec("P") +R = TypeVar("R") + + if TYPE_CHECKING: + from collections.abc import Callable from typing import TypedDict from skore import CrossValidationReport, EstimatorReport @@ -40,17 +45,26 @@ class Metadata(TypedDict): # noqa: D101 predict_time_mean: float | None -def ensure_project_is_not_deleted(method): +def ensure_project_is_not_deleted(method: Callable[P, R]) -> Callable[P, R]: """Ensure project is not deleted, before executing any other operation.""" + @runtime_checkable + class Project(Protocol): + name: str + _Project__projects_storage: DiskCacheStorage + @wraps(method) - def wrapper(self, *args, **kwargs): - if self.name not in self._Project__projects_storage: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + project = args[0] + + assert isinstance(project, Project), "You can only wrap `Project` methods" + + if project.name not in project._Project__projects_storage: raise RuntimeError( - f"Skore could not proceed because {self!r} does not exist anymore." + f"Skore could not proceed because {project!r} does not exist anymore." ) - return method(self, *args, **kwargs) + return method(*args, **kwargs) return wrapper @@ -183,7 +197,7 @@ def pickle(report: EstimatorReport | CrossValidationReport) -> tuple[str, bytes] return pickle_hash, pickle_bytes @ensure_project_is_not_deleted - def put(self, key: str, report: EstimatorReport | CrossValidationReport): + def put(self, key: str, report: EstimatorReport | CrossValidationReport) -> None: """ Put a key-report pair to the local project. @@ -238,7 +252,9 @@ def get(self, id: str) -> EstimatorReport | CrossValidationReport: """Get a persisted report by its id.""" if id in self.__artifacts_storage: with io.BytesIO(self.__artifacts_storage[id]) as stream: - return joblib.load(stream) + return cast( + "EstimatorReport | CrossValidationReport", joblib.load(stream) + ) raise KeyError(id) @@ -275,7 +291,7 @@ def __repr__(self) -> str: # noqa: D105 ) @staticmethod - def delete(name: str, *, workspace: Path | None = None): + def delete(name: str, *, workspace: Path | None = None) -> None: r""" Delete a local project. diff --git a/skore-local-project/src/skore_local_project/storage.py b/skore-local-project/src/skore_local_project/storage.py index f7b9df86a..4f2a5a2ec 100644 --- a/skore-local-project/src/skore_local_project/storage.py +++ b/skore-local-project/src/skore_local_project/storage.py @@ -106,7 +106,7 @@ def __getitem__(self, key: str) -> Any: with Cache(self.directory) as storage: return storage[key] - def __setitem__(self, key: str, value: Any): + def __setitem__(self, key: str, value: Any) -> None: """ Set an item in the storage. @@ -120,7 +120,7 @@ def __setitem__(self, key: str, value: Any): with Cache(self.directory) as storage: storage[key] = value - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: """ Delete an item from the storage.