Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions skore-local-project/src/skore_local_project/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Class definition of the ``metadata`` objects used by ``project``."""

from abc import ABC
from contextlib import suppress
from dataclasses import InitVar, dataclass, field, fields
from datetime import datetime, timezone
from math import isfinite
from typing import Any

import joblib
from skore import CrossValidationReport, EstimatorReport

Report = EstimatorReport | CrossValidationReport


def cast_to_float(value: Any) -> float | None:
"""Cast value to float."""
with suppress(TypeError):
if isfinite(value := float(value)):
return value

return None


def report_type(report: Report) -> str:
    """
    Human readable type of a report.

    Parameters
    ----------
    report : EstimatorReport | CrossValidationReport
        The report whose type to describe.

    Returns
    -------
    str
        ``"cross-validation"`` or ``"estimator"``.

    Raises
    ------
    TypeError
        If ``report`` is neither an ``EstimatorReport`` nor a
        ``CrossValidationReport``.
    """
    # Check CrossValidationReport first: the two classes are unrelated in
    # skore, but keeping this order mirrors the dispatch in `project.put`.
    if isinstance(report, CrossValidationReport):
        return "cross-validation"
    if isinstance(report, EstimatorReport):
        return "estimator"

    raise TypeError(
        f"Report must be a `skore.EstimatorReport` or "
        f"`skore.CrossValidationReport` (found '{type(report)}')"
    )


@dataclass(kw_only=True)
class ReportMetadata(ABC):
    """
    Metadata used to persist a report to local storage.

    Attributes
    ----------
    report : EstimatorReport | CrossValidationReport
        The report on which to calculate the metadata to persist.
    artifact_id : str
        ID of the artifact in the artifacts storage.
    project_name : str
        The name of the project the metadata should be associated with.
    run_id : str
        The run the metadata should be associated with.
    date : str
        The date the metadata were created.
    key : str
        The key to associate to the report.
    learner : str
        The name of the report's estimator.
    ml_task : str
        The type of ML task covered by the report.
    report_type : str
        The type of the report.
    dataset : str
        The hash of the targets.
    """

    # `report` is consumed by `__post_init__` and not stored as a field.
    report: InitVar[Report]

    artifact_id: str
    project_name: str
    run_id: str
    key: str
    # Fields below are derived from `report` in `__post_init__`.
    date: str = field(init=False)
    learner: str = field(init=False)
    ml_task: str = field(init=False)
    report_type: str = field(init=False)
    dataset: str = field(init=False)

    def __iter__(self):
        """Yield ``(name, value)`` pairs for each metadata field."""
        # Use `f`, not `field`, to avoid shadowing `dataclasses.field`.
        for f in fields(self):
            yield (f.name, getattr(self, f.name))

    def __post_init__(self, report: Report):
        """Initialize dynamic fields from ``report``."""
        # Timezone-aware UTC timestamp, serialized as ISO-8601.
        self.date = datetime.now(timezone.utc).isoformat()
        self.learner = report.estimator_name_
        self.ml_task = report.ml_task
        self.report_type = report_type(report)
        # Hash the test targets when available (EstimatorReport), otherwise
        # the full targets (CrossValidationReport), to identify the dataset.
        self.dataset = joblib.hash(
            report.y_test if hasattr(report, "y_test") else report.y
        )


@dataclass(kw_only=True)
class EstimatorReportMetadata(ReportMetadata):
    """Metadata used to persist a ``skore.EstimatorReport`` to local storage."""

    rmse: float | None = field(init=False)
    log_loss: float | None = field(init=False)
    roc_auc: float | None = field(init=False)
    fit_time: float | None = field(init=False)
    predict_time: float | None = field(init=False)

    @staticmethod
    def metric(report: EstimatorReport, name: str) -> float | None:
        """
        Compute metric ``name`` on the test data-source.

        Notes
        -----
        Metrics unavailable on the report return None.

        Metrics whose result is not a finite scalar also return None:
        - ``list[float]`` for multi-output ML tasks,
        - ``dict[str, float]`` for multi-class ML tasks.
        """
        if not hasattr(report.metrics, name):
            return None

        return cast_to_float(getattr(report.metrics, name)(data_source="test"))

    def __post_init__(self, report: EstimatorReport):  # type: ignore[override]
        """Initialize dynamic fields from ``report``."""
        super().__post_init__(report)

        self.rmse = self.metric(report, "rmse")
        self.log_loss = self.metric(report, "log_loss")
        self.roc_auc = self.metric(report, "roc_auc")

        # Timings must be calculated last: computing the metrics above is
        # what populates the report's timing information. Fetch the timings
        # dict once instead of recomputing it per lookup.
        timings = report.metrics.timings()
        self.fit_time = timings.get("fit_time")
        self.predict_time = timings.get("predict_time_test")


@dataclass(kw_only=True)
class CrossValidationReportMetadata(ReportMetadata):
    """Metadata used to persist a ``skore.CrossValidationReport`` to local storage."""

    rmse_mean: float | None = field(init=False)
    log_loss_mean: float | None = field(init=False)
    roc_auc_mean: float | None = field(init=False)
    fit_time_mean: float | None = field(init=False)
    predict_time_mean: float | None = field(init=False)

    @staticmethod
    def metric(report: CrossValidationReport, name: str) -> float | None:
        """Compute the cross-split mean of metric ``name``, or None if unavailable."""
        if not hasattr(report.metrics, name):
            return None

        # The metric accessor returns a dataframe; the aggregated mean sits
        # in its first cell.
        aggregated = getattr(report.metrics, name)(
            data_source="test",
            aggregate="mean",
        )
        return cast_to_float(aggregated.iloc[0, 0])

    @staticmethod
    def timing(report: CrossValidationReport, label: str) -> float | None:
        """Compute the cross-split mean of timing ``label``, or None if absent."""
        aggregated = report.metrics.timings(aggregate="mean")

        if label not in aggregated.index:
            return None

        return cast_to_float(aggregated.loc[label].iloc[0])

    def __post_init__(self, report: CrossValidationReport):  # type: ignore[override]
        """Initialize dynamic fields from ``report``."""
        super().__post_init__(report)

        for attribute, name in (
            ("rmse_mean", "rmse"),
            ("log_loss_mean", "log_loss"),
            ("roc_auc_mean", "roc_auc"),
        ):
            setattr(self, attribute, self.metric(report, name))

        # Timings must be calculated last, once the metrics above have run.
        self.fit_time_mean = self.timing(report, "Fit time (s)")
        self.predict_time_mean = self.timing(report, "Predict time test (s)")
90 changes: 27 additions & 63 deletions skore-local-project/src/skore_local_project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import io
import os
from contextlib import suppress
from datetime import datetime, timezone
from operator import itemgetter
from pathlib import Path
from types import SimpleNamespace
Expand All @@ -14,29 +12,14 @@

import joblib
import platformdirs
from skore import CrossValidationReport, EstimatorReport

from .metadata import CrossValidationReportMetadata, EstimatorReportMetadata
from .storage import DiskCacheStorage

if TYPE_CHECKING:
from typing import TypedDict

from skore import EstimatorReport

class PersistedMetadata: # noqa: D101
artifact_id: str
project_name: str
run_id: str
key: str
date: str
learner: str
dataset: str
ml_task: str
rmse: float | None
log_loss: float | None
roc_auc: float | None
fit_time: float
predict_time: float

class Metadata(TypedDict): # noqa: D101
id: str
run_id: str
Expand Down Expand Up @@ -162,30 +145,30 @@ def workspace(self) -> Path:
return self.__workspace

@staticmethod
def pickle(report: EstimatorReport) -> tuple[str, bytes]:
def pickle(report: EstimatorReport | CrossValidationReport) -> tuple[str, bytes]:
"""
Pickle ``report``, return the bytes and the corresponding hash.

Notes
-----
The report is pickled without its cache, to avoid salting the hash.
"""
cache = report._cache
reports = [report] + getattr(report, "estimator_reports_", [])
caches = [report_to_clear.__dict__.pop("_cache") for report_to_clear in reports]

try:
report._cache = {}

with io.BytesIO() as stream:
joblib.dump(report, stream)

pickle_bytes = stream.getvalue()
pickle_hash = joblib.hash(pickle_bytes)
finally:
report._cache = cache
for report, cache in zip(reports, caches, strict=True):
report._cache = cache

return pickle_hash, pickle_bytes

def put(self, key: str, report: EstimatorReport):
def put(self, key: str, report: EstimatorReport | CrossValidationReport):
"""
Put a key-report pair to the local project.

Expand All @@ -196,7 +179,7 @@ def put(self, key: str, report: EstimatorReport):
----------
key : str
The key to associate with ``report`` in the local project.
report : skore.EstimatorReport
report : skore.EstimatorReport | skore.CrossValidationReport
The report to associate with ``key`` in the local project.

Raises
Expand All @@ -207,11 +190,16 @@ def put(self, key: str, report: EstimatorReport):
if not isinstance(key, str):
raise TypeError(f"Key must be a string (found '{type(key)}')")

from skore import EstimatorReport
Metadata: type[EstimatorReportMetadata] | type[CrossValidationReportMetadata]

if not isinstance(report, EstimatorReport):
if isinstance(report, EstimatorReport):
Metadata = EstimatorReportMetadata
elif isinstance(report, CrossValidationReport):
Metadata = CrossValidationReportMetadata
else:
raise TypeError(
f"Report must be a `skore.EstimatorReport` (found '{type(report)}')"
f"Report must be a `skore.EstimatorReport` or `skore.CrossValidationRep"
f"ort` (found '{type(report)}')"
)

if self.name not in self.__projects_storage:
Expand All @@ -225,39 +213,15 @@ def put(self, key: str, report: EstimatorReport):
if pickle_hash not in self.__artifacts_storage:
self.__artifacts_storage[pickle_hash] = pickle_bytes

def metric(name):
"""
Compute ``report.metrics.name``.

Notes
-----
Unavailable metrics return None.

All metrics whose report is not a scalar return None:
- ignore ``list[float]`` for multi-output ML task,
- ignore ``dict[str: float]`` for multi-classes ML task.
"""
if hasattr(report.metrics, name):
with suppress(TypeError):
return float(getattr(report.metrics, name)(data_source="test"))
return None

self.__metadata_storage[uuid4().hex] = {
"project_name": self.name,
"run_id": self.run_id,
"key": key,
"artifact_id": pickle_hash,
"date": datetime.now(timezone.utc).isoformat(),
"learner": report.estimator_name_,
"dataset": joblib.hash(report.y_test),
"ml_task": report._ml_task,
"rmse": metric("rmse"),
"log_loss": metric("log_loss"),
"roc_auc": metric("roc_auc"),
# timings must be calculated last
"fit_time": report.metrics.timings().get("fit_time"),
"predict_time": report.metrics.timings().get("predict_time_test"),
}
self.__metadata_storage[uuid4().hex] = dict(
Metadata(
report=report,
artifact_id=pickle_hash,
project_name=self.name,
run_id=self.run_id,
key=key,
)
)

@property
def reports(self):
Expand All @@ -277,7 +241,7 @@ def get(id: str) -> EstimatorReport:
raise KeyError(id)

def metadata() -> list[Metadata]:
"""Obtain metadata for all persisted reports regardless of their run."""
"""Obtain metadata/metrics for all persisted reports."""
return sorted(
(
{
Expand Down
Loading
Loading