diff --git a/skore-local-project/src/skore_local_project/metadata.py b/skore-local-project/src/skore_local_project/metadata.py index a25bad3ee7..4350eca4b6 100644 --- a/skore-local-project/src/skore_local_project/metadata.py +++ b/skore-local-project/src/skore_local_project/metadata.py @@ -1,16 +1,20 @@ """Class definition of the ``metadata`` objects used by ``project``.""" +from __future__ import annotations + from abc import ABC from contextlib import suppress from dataclasses import InitVar, dataclass, field, fields from datetime import datetime, timezone from math import isfinite -from typing import Any +from typing import TYPE_CHECKING + +from joblib import hash -import joblib -from skore import CrossValidationReport, EstimatorReport +if TYPE_CHECKING: + from typing import Any -Report = EstimatorReport | CrossValidationReport + from skore import CrossValidationReport, EstimatorReport def cast_to_float(value: Any) -> float | None: @@ -22,8 +26,10 @@ def cast_to_float(value: Any) -> float | None: return None -def report_type(report: Report): +def report_type(report: EstimatorReport | CrossValidationReport): """Human readable type of a report.""" + from skore import CrossValidationReport, EstimatorReport + if isinstance(report, CrossValidationReport): return "cross-validation" if isinstance(report, EstimatorReport): @@ -61,7 +67,7 @@ class ReportMetadata(ABC): The hash of the targets. """ - report: InitVar[Report] + report: InitVar[EstimatorReport | CrossValidationReport] artifact_id: str project_name: str @@ -78,15 +84,13 @@ def __iter__(self): for field in fields(self): # noqa: F402 yield (field.name, getattr(self, field.name)) - def __post_init__(self, report: Report): + def __post_init__(self, report: EstimatorReport | CrossValidationReport): """Initialize dynamic fields.""" self.date = datetime.now(timezone.utc).isoformat() self.learner = report.estimator_name_ self.ml_task = report.ml_task self.report_type = report_type(report) - self.dataset = joblib.hash( - report.y_test if hasattr(report, "y_test") else report.y - ) + self.dataset = hash(report.y_test if hasattr(report, "y_test") else report.y) @dataclass(kw_only=True) diff --git a/skore-local-project/src/skore_local_project/project.py b/skore-local-project/src/skore_local_project/project.py index cd2b707270..afd9a00ff9 100644 --- a/skore-local-project/src/skore_local_project/project.py +++ b/skore-local-project/src/skore_local_project/project.py @@ -5,7 +5,6 @@ import io import os from functools import wraps -from operator import itemgetter from pathlib import Path from types import SimpleNamespace from typing import TYPE_CHECKING @@ -13,7 +12,6 @@ import joblib import platformdirs -from skore import CrossValidationReport, EstimatorReport from .metadata import CrossValidationReportMetadata, EstimatorReportMetadata from .storage import DiskCacheStorage @@ -21,19 +19,27 @@ if TYPE_CHECKING: from typing import TypedDict + from skore import CrossValidationReport, EstimatorReport + class Metadata(TypedDict): # noqa: D101 id: str run_id: str key: str date: str learner: str - dataset: str ml_task: str + report_type: str + dataset: str rmse: float | None log_loss: float | None roc_auc: float | None - fit_time: float - predict_time: float + fit_time: float | None + predict_time: float | None + rmse_mean: float | None + log_loss_mean: float | None + roc_auc_mean: float | None + fit_time_mean: float | None + predict_time_mean: float | None def ensure_project_is_not_deleted(method): @@ -204,6 +210,8 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport): TypeError If the combination of parameters are not valid. """ + from skore import CrossValidationReport, EstimatorReport + if not isinstance(key, str): raise TypeError(f"Key must be a string (found '{type(key)}')") @@ -234,42 +242,66 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport): ) ) - @property @ensure_project_is_not_deleted + def get(self, id: str) -> EstimatorReport | CrossValidationReport: + """Get a persisted report by its id.""" + if id in self.__artifacts_storage: + with io.BytesIO(self.__artifacts_storage[id]) as stream: + return joblib.load(stream) + + raise KeyError(id) + + @ensure_project_is_not_deleted + def summarize(self) -> list[Metadata]: + """Obtain metadata/metrics for all persisted reports in insertion order.""" + return [ + { + "id": value["artifact_id"], + "run_id": value["run_id"], + "key": value["key"], + "date": value["date"], + "learner": value["learner"], + "ml_task": value["ml_task"], + "report_type": value["report_type"], + "dataset": value["dataset"], + "rmse": value.get("rmse"), + "log_loss": value.get("log_loss"), + "roc_auc": value.get("roc_auc"), + "fit_time": value.get("fit_time"), + "predict_time": value.get("predict_time"), + "rmse_mean": value.get("rmse_mean"), + "log_loss_mean": value.get("log_loss_mean"), + "roc_auc_mean": value.get("roc_auc_mean"), + "fit_time_mean": value.get("fit_time_mean"), + "predict_time_mean": value.get("predict_time_mean"), + } + for value in self.__metadata_storage.values() + if value["project_name"] == self.name + ] + + @property def reports(self): """Accessor for interaction with the persisted reports.""" - def get(id: str) -> EstimatorReport: - """Get a persisted report by its id.""" - if id in self.__artifacts_storage: - with io.BytesIO(self.__artifacts_storage[id]) as stream: - return joblib.load(stream) + def get(id: str) -> EstimatorReport | CrossValidationReport: + """ + Get a persisted report by its id. - raise KeyError(id) + .. deprecated + The ``Project.reports.get`` function will be removed in favor of + ``Project.get`` in a near future. + """ + return self.get(id) def metadata() -> list[Metadata]: - """Obtain metadata/metrics for all persisted reports.""" - return sorted( - ( - { - "id": value["artifact_id"], - "run_id": value["run_id"], - "key": value["key"], - "date": value["date"], - "learner": value["learner"], - "dataset": value["dataset"], - "ml_task": value["ml_task"], - "rmse": value["rmse"], - "log_loss": value["log_loss"], - "roc_auc": value["roc_auc"], - "fit_time": value["fit_time"], - "predict_time": value["predict_time"], - } - for value in self.__metadata_storage.values() - if value["project_name"] == self.name - ), - key=itemgetter("date"), - ) + """ + Obtain metadata/metrics for all persisted reports in insertion order. + + .. deprecated + The ``Project.reports.metadata`` function will be removed in favor of + ``Project.summarize`` in a near future. + """ + return self.summarize() return SimpleNamespace(get=get, metadata=metadata) diff --git a/skore-local-project/tests/conftest.py b/skore-local-project/tests/conftest.py index a616a4b4ce..9476bf75ab 100644 --- a/skore-local-project/tests/conftest.py +++ b/skore-local-project/tests/conftest.py @@ -15,11 +15,22 @@ def nowstr(now): @fixture def Datetime(now): + now_from_fixture = now + class Datetime: + nows = [] + nows_isoformat = [] + def __init__(self, *args, **kwargs): ... @staticmethod def now(*args, **kwargs): + now = datetime.now(tz=timezone.utc) if Datetime.nows else now_from_fixture + now_isoformat = now.isoformat() + + Datetime.nows.append(now) + Datetime.nows_isoformat.append(now_isoformat) + return now return Datetime diff --git a/skore-local-project/tests/unit/test_project.py b/skore-local-project/tests/unit/test_project.py index 918422172f..ed7be545dd 100644 --- a/skore-local-project/tests/unit/test_project.py +++ b/skore-local-project/tests/unit/test_project.py @@ -288,7 +288,20 @@ def test_reports(self, tmp_path): assert hasattr(project.reports, "get") assert hasattr(project.reports, "metadata") - def test_reports_exception(self, tmp_path): + def test_reports_get(self, tmp_path, regression): + project = Project("", workspace=tmp_path) + project.put("", regression) + project.put("", regression) + + report = project.reports.get(next(project._Project__artifacts_storage.keys())) + + assert len(project._Project__artifacts_storage) == 1 + assert len(project._Project__metadata_storage) == 2 + assert isinstance(report, EstimatorReport) + assert report.estimator_name_ == regression.estimator_name_ + assert report._ml_task == regression._ml_task + + def test_reports_get_exception(self, tmp_path, regression): import re project = Project("", workspace=tmp_path) @@ -302,60 +315,98 @@ def test_reports_exception(self, tmp_path): f"does not exist anymore." ), ): - project.reports # noqa: B018 + project.reports.get(None) - def test_reports_get(self, tmp_path, regression): + def test_reports_metadata(self, tmp_path, Datetime, regression, cv_regression): project = Project("", workspace=tmp_path) - project.put("", regression) - project.put("", regression) - - report = project.reports.get(next(project._Project__artifacts_storage.keys())) - assert len(project._Project__artifacts_storage) == 1 - assert len(project._Project__metadata_storage) == 2 - assert isinstance(report, EstimatorReport) - assert report.estimator_name_ == regression.estimator_name_ - assert report._ml_task == regression._ml_task + project.put("", regression) + project.put("", regression) + project.put("", cv_regression) - def test_reports_metadata(self, tmp_path, nowstr, regression): - project = Project("", workspace=tmp_path) - - project.put("", regression) - project.put("", regression) + artifact_ids = list(project._Project__artifacts_storage.keys()) - assert len(project._Project__artifacts_storage) == 1 - assert len(project._Project__metadata_storage) == 2 + assert len(project._Project__artifacts_storage) == 2 + assert len(project._Project__metadata_storage) == 3 assert project.reports.metadata() == [ { - "id": next(project._Project__artifacts_storage.keys()), + "id": artifact_ids[0], "run_id": project.run_id, - "key": "", - "date": nowstr, - "learner": regression.estimator_name_, + "key": "", + "date": Datetime.nows_isoformat[0], + "learner": "Ridge", + "ml_task": "regression", + "report_type": "estimator", "dataset": joblib.hash(regression.y_test), - "ml_task": regression._ml_task, "rmse": float(hash("")), "log_loss": None, "roc_auc": None, "fit_time": float(hash("")), "predict_time": float(hash("")), + "rmse_mean": None, + "log_loss_mean": None, + "roc_auc_mean": None, + "fit_time_mean": None, + "predict_time_mean": None, }, { - "id": next(project._Project__artifacts_storage.keys()), + "id": artifact_ids[0], "run_id": project.run_id, - "key": "", - "date": nowstr, - "learner": regression.estimator_name_, + "key": "", + "date": Datetime.nows_isoformat[1], + "learner": "Ridge", + "ml_task": "regression", + "report_type": "estimator", "dataset": joblib.hash(regression.y_test), - "ml_task": regression._ml_task, "rmse": float(hash("")), "log_loss": None, "roc_auc": None, "fit_time": float(hash("")), "predict_time": float(hash("")), + "rmse_mean": None, + "log_loss_mean": None, + "roc_auc_mean": None, + "fit_time_mean": None, + "predict_time_mean": None, + }, + { + "id": artifact_ids[1], + "run_id": project.run_id, + "key": "", + "date": Datetime.nows_isoformat[2], + "learner": "Ridge", + "ml_task": "regression", + "report_type": "cross-validation", + "dataset": joblib.hash(cv_regression.y), + "rmse": None, + "log_loss": None, + "roc_auc": None, + "fit_time": None, + "predict_time": None, + "rmse_mean": float(hash("")), + "log_loss_mean": None, + "roc_auc_mean": None, + "fit_time_mean": float(hash("")), + "predict_time_mean": float(hash("")), }, ] + def test_reports_metadata_exception(self, tmp_path, regression): + import re + + project = Project("", workspace=tmp_path) + Project.delete("", workspace=tmp_path) + + with raises( + RuntimeError, + match=re.escape( + f"Skore could not proceed because " + f"Project(mode='local', name='', workspace='{tmp_path}') " + f"does not exist anymore." + ), + ): + project.reports.metadata() + def test_delete(self, tmp_path, binary_classification, regression): project1 = Project("", workspace=tmp_path) project1.put("", binary_classification)