Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions skore-local-project/src/skore_local_project/metadata.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
"""Class definition of the ``metadata`` objects used by ``project``."""

from __future__ import annotations

from abc import ABC
from contextlib import suppress
from dataclasses import InitVar, dataclass, field, fields
from datetime import datetime, timezone
from math import isfinite
from typing import Any
from typing import TYPE_CHECKING

from joblib import hash

import joblib
from skore import CrossValidationReport, EstimatorReport
if TYPE_CHECKING:
from typing import Any

Report = EstimatorReport | CrossValidationReport
from skore import CrossValidationReport, EstimatorReport


def cast_to_float(value: Any) -> float | None:
Expand All @@ -22,8 +26,10 @@ def cast_to_float(value: Any) -> float | None:
return None


def report_type(report: Report):
def report_type(report: EstimatorReport | CrossValidationReport):
"""Human readable type of a report."""
from skore import CrossValidationReport, EstimatorReport

if isinstance(report, CrossValidationReport):
return "cross-validation"
if isinstance(report, EstimatorReport):
Expand Down Expand Up @@ -61,7 +67,7 @@ class ReportMetadata(ABC):
The hash of the targets.
"""

report: InitVar[Report]
report: InitVar[EstimatorReport | CrossValidationReport]

artifact_id: str
project_name: str
Expand All @@ -78,15 +84,13 @@ def __iter__(self):
for field in fields(self): # noqa: F402
yield (field.name, getattr(self, field.name))

def __post_init__(self, report: EstimatorReport | CrossValidationReport):
    """
    Initialize the fields that are derived from ``report`` at creation time.

    Parameters
    ----------
    report : EstimatorReport | CrossValidationReport
        The report from which the dynamic metadata is extracted.
    """
    self.date = datetime.now(timezone.utc).isoformat()
    self.learner = report.estimator_name_
    self.ml_task = report.ml_task
    self.report_type = report_type(report)

    # ``hash`` is ``joblib.hash`` (imported at module level), not the builtin.
    # Reports exposing ``y_test`` are hashed on it, the others on ``y``.
    targets = report.y_test if hasattr(report, "y_test") else report.y
    self.dataset = hash(targets)


@dataclass(kw_only=True)
Expand Down
100 changes: 66 additions & 34 deletions skore-local-project/src/skore_local_project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,41 @@
import io
import os
from functools import wraps
from operator import itemgetter
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING
from uuid import uuid4

import joblib
import platformdirs
from skore import CrossValidationReport, EstimatorReport

from .metadata import CrossValidationReportMetadata, EstimatorReportMetadata
from .storage import DiskCacheStorage

if TYPE_CHECKING:
from typing import TypedDict

from skore import CrossValidationReport, EstimatorReport

class Metadata(TypedDict): # noqa: D101
id: str
run_id: str
key: str
date: str
learner: str
dataset: str
ml_task: str
report_type: str
dataset: str
rmse: float | None
log_loss: float | None
roc_auc: float | None
fit_time: float
predict_time: float
fit_time: float | None
predict_time: float | None
rmse_mean: float | None
log_loss_mean: float | None
roc_auc_mean: float | None
fit_time_mean: float | None
predict_time_mean: float | None


def ensure_project_is_not_deleted(method):
Expand Down Expand Up @@ -204,6 +210,8 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport):
TypeError
If the combination of parameters are not valid.
"""
from skore import CrossValidationReport, EstimatorReport

if not isinstance(key, str):
raise TypeError(f"Key must be a string (found '{type(key)}')")

Expand Down Expand Up @@ -234,42 +242,66 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport):
)
)

@ensure_project_is_not_deleted
def get(self, id: str) -> EstimatorReport | CrossValidationReport:
    """
    Get a persisted report by its id.

    Parameters
    ----------
    id : str
        The id under which the report is stored in the artifacts storage.

    Returns
    -------
    EstimatorReport | CrossValidationReport
        The report deserialized from the artifacts storage.

    Raises
    ------
    KeyError
        If no artifact is stored under ``id``.
    """
    if id not in self.__artifacts_storage:
        raise KeyError(id)

    # NOTE(review): ``joblib.load`` unpickles the stored bytes; this presumes
    # the artifacts storage only holds content written by this project.
    with io.BytesIO(self.__artifacts_storage[id]) as stream:
        return joblib.load(stream)

@ensure_project_is_not_deleted
def summarize(self) -> list[Metadata]:
    """
    Obtain metadata/metrics for all persisted reports in insertion order.

    Only the entries of the metadata storage whose ``project_name`` matches the
    name of this project are returned. Metric keys that are missing from an
    entry are reported as ``None``.
    """
    # Keys that every stored entry must provide.
    mandatory = (
        "run_id",
        "key",
        "date",
        "learner",
        "ml_task",
        "report_type",
        "dataset",
    )
    # Metric keys that depend on the type of the report.
    optional = (
        "rmse",
        "log_loss",
        "roc_auc",
        "fit_time",
        "predict_time",
        "rmse_mean",
        "log_loss_mean",
        "roc_auc_mean",
        "fit_time_mean",
        "predict_time_mean",
    )

    summaries = []

    for record in self.__metadata_storage.values():
        if record["project_name"] != self.name:
            continue

        summary = {"id": record["artifact_id"]}
        summary.update((name, record[name]) for name in mandatory)
        summary.update((name, record.get(name)) for name in optional)
        summaries.append(summary)

    return summaries

@property
def reports(self):
    """
    Accessor for interaction with the persisted reports.

    Returns
    -------
    types.SimpleNamespace
        Namespace exposing the ``get`` and ``metadata`` functions, which
        delegate to ``Project.get`` and ``Project.summarize``.
    """

    def get(id: str) -> EstimatorReport | CrossValidationReport:
        """
        Get a persisted report by its id.

        .. deprecated
          The ``Project.reports.get`` function will be removed in favor of
          ``Project.get`` in the near future.
        """
        return self.get(id)

    def metadata() -> list[Metadata]:
        """
        Obtain metadata/metrics for all persisted reports in insertion order.

        .. deprecated
          The ``Project.reports.metadata`` function will be removed in favor of
          ``Project.summarize`` in the near future.
        """
        return self.summarize()

    return SimpleNamespace(get=get, metadata=metadata)

Expand Down
11 changes: 11 additions & 0 deletions skore-local-project/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,22 @@ def nowstr(now):

@fixture
def Datetime(now):
    """
    Fake ``datetime`` class recording every value returned by its ``now``.

    The first call to ``now`` returns the value of the ``now`` fixture; the
    following calls return the real current UTC time. Each returned value and
    its ISO representation are appended to ``nows`` and ``nows_isoformat``.
    """
    # Alias needed because the static method ``now`` hides the fixture argument.
    now_from_fixture = now

    class Datetime:
        nows = []
        nows_isoformat = []

        def __init__(self, *args, **kwargs):
            """Accept and ignore any argument."""

        @staticmethod
        def now(*args, **kwargs):
            if Datetime.nows:
                current = datetime.now(tz=timezone.utc)
            else:
                current = now_from_fixture

            current_isoformat = current.isoformat()

            Datetime.nows.append(current)
            Datetime.nows_isoformat.append(current_isoformat)

            return current

    return Datetime
Expand Down
109 changes: 80 additions & 29 deletions skore-local-project/tests/unit/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,20 @@ def test_reports(self, tmp_path):
assert hasattr(project.reports, "get")
assert hasattr(project.reports, "metadata")

def test_reports_get(self, tmp_path, regression):
    """Check that ``reports.get`` reloads a report persisted in the project."""
    project = Project("<project>", workspace=tmp_path)
    project.put("<key>", regression)
    project.put("<key>", regression)

    report = project.reports.get(next(project._Project__artifacts_storage.keys()))

    # The same report put twice is stored once, but referenced by two metadata.
    assert len(project._Project__artifacts_storage) == 1
    assert len(project._Project__metadata_storage) == 2
    assert isinstance(report, EstimatorReport)
    assert report.estimator_name_ == regression.estimator_name_
    assert report._ml_task == regression._ml_task

def test_reports_get_exception(self, tmp_path, regression):
import re

project = Project("<project>", workspace=tmp_path)
Expand All @@ -302,60 +315,98 @@ def test_reports_exception(self, tmp_path):
f"does not exist anymore."
),
):
project.reports # noqa: B018
project.reports.get(None)

def test_reports_metadata(self, tmp_path, Datetime, regression, cv_regression):
    """Check the content and the order of the entries of ``reports.metadata``."""
    project = Project("<project>", workspace=tmp_path)

    project.put("<key1>", regression)
    project.put("<key1>", regression)
    project.put("<key2>", cv_regression)

    artifact_ids = list(project._Project__artifacts_storage.keys())

    # The same report put twice is stored once, but referenced by two metadata.
    assert len(project._Project__artifacts_storage) == 2
    assert len(project._Project__metadata_storage) == 3
    assert project.reports.metadata() == [
        {
            "id": artifact_ids[0],
            "run_id": project.run_id,
            "key": "<key1>",
            "date": Datetime.nows_isoformat[0],
            "learner": "Ridge",
            "ml_task": "regression",
            "report_type": "estimator",
            "dataset": joblib.hash(regression.y_test),
            "rmse": float(hash("<rmse_test>")),
            "log_loss": None,
            "roc_auc": None,
            "fit_time": float(hash("<fit_time>")),
            "predict_time": float(hash("<predict_time_test>")),
            "rmse_mean": None,
            "log_loss_mean": None,
            "roc_auc_mean": None,
            "fit_time_mean": None,
            "predict_time_mean": None,
        },
        {
            "id": artifact_ids[0],
            "run_id": project.run_id,
            "key": "<key1>",
            "date": Datetime.nows_isoformat[1],
            "learner": "Ridge",
            "ml_task": "regression",
            "report_type": "estimator",
            "dataset": joblib.hash(regression.y_test),
            "rmse": float(hash("<rmse_test>")),
            "log_loss": None,
            "roc_auc": None,
            "fit_time": float(hash("<fit_time>")),
            "predict_time": float(hash("<predict_time_test>")),
            "rmse_mean": None,
            "log_loss_mean": None,
            "roc_auc_mean": None,
            "fit_time_mean": None,
            "predict_time_mean": None,
        },
        {
            "id": artifact_ids[1],
            "run_id": project.run_id,
            "key": "<key2>",
            "date": Datetime.nows_isoformat[2],
            "learner": "Ridge",
            "ml_task": "regression",
            "report_type": "cross-validation",
            "dataset": joblib.hash(cv_regression.y),
            "rmse": None,
            "log_loss": None,
            "roc_auc": None,
            "fit_time": None,
            "predict_time": None,
            "rmse_mean": float(hash("<rmse_mean_test>")),
            "log_loss_mean": None,
            "roc_auc_mean": None,
            "fit_time_mean": float(hash("<fit_time_mean>")),
            "predict_time_mean": float(hash("<predict_time_mean_test>")),
        },
    ]

def test_reports_metadata_exception(self, tmp_path, regression):
    """Check that ``reports.metadata`` raises once the project is deleted."""
    import re

    project = Project("<project>", workspace=tmp_path)
    Project.delete("<project>", workspace=tmp_path)

    # The message is escaped because it is used as a regular expression.
    expected_message = re.escape(
        "Skore could not proceed because "
        f"Project(mode='local', name='<project>', workspace='{tmp_path}') "
        "does not exist anymore."
    )

    with raises(RuntimeError, match=expected_message):
        project.reports.metadata()

def test_delete(self, tmp_path, binary_classification, regression):
project1 = Project("<project1>", workspace=tmp_path)
project1.put("<project1-key1>", binary_classification)
Expand Down
Loading