Skip to content

Commit 8c79262

Browse files
committed
feat(skore-hub-project): Thread the compute of report metrics
1 parent ad39cff commit 8c79262

File tree

18 files changed

+865
-1170
lines changed

18 files changed

+865
-1170
lines changed

skore-hub-project/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies = [
1212
"numpy",
1313
"orjson",
1414
"pydantic",
15-
"rich",
15+
"rich>=14.2.0",
1616
"scikit-learn",
1717
]
1818

skore-hub-project/src/skore_hub_project/metric/metric.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from abc import ABC, abstractmethod
66
from collections.abc import Callable
77
from contextlib import suppress
8-
from functools import cached_property, reduce
8+
from functools import reduce
99
from math import isfinite
1010
from typing import (
1111
TYPE_CHECKING,
@@ -59,7 +59,7 @@ class Metric(BaseModel, ABC, Generic[Report]):
5959
default None to disable its display.
6060
"""
6161

62-
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
62+
model_config = ConfigDict(arbitrary_types_allowed=True)
6363

6464
report: Report = Field(repr=False, exclude=True)
6565
name: str = Field(init=False)
@@ -70,9 +70,26 @@ class Metric(BaseModel, ABC, Generic[Report]):
7070

7171
@computed_field # type: ignore[prop-decorator]
7272
@property
73-
@abstractmethod
7473
def value(self) -> float | None:
7574
"""The value of the metric."""
75+
try:
76+
return self.__value
77+
except AttributeError:
78+
message = (
79+
"You cannot access the value of a metric "
80+
"without explicitly calculating it. "
81+
"Please use `metric.compute()` before."
82+
)
83+
84+
raise RuntimeError(message) from None
85+
86+
@value.setter
87+
def value(self, value: float | None) -> None:
88+
self.__value = value
89+
90+
@abstractmethod
91+
def compute(self) -> None:
92+
"""Compute the value of the metric."""
7693

7794

7895
class EstimatorReportMetric(Metric[EstimatorReport]):
@@ -100,19 +117,17 @@ class EstimatorReportMetric(Metric[EstimatorReport]):
100117

101118
accessor: ClassVar[str]
102119

103-
@computed_field # type: ignore[prop-decorator]
104-
@cached_property
105-
def value(self) -> float | None:
106-
"""The value of the metric."""
120+
def compute(self) -> None:
121+
"""Compute the value of the metric."""
107122
try:
108123
function = cast(
109124
Callable[..., float | None],
110125
reduce(getattr, self.accessor.split("."), self.report),
111126
)
112127
except AttributeError:
113-
return None
114-
115-
return cast_to_float(function(data_source=self.data_source))
128+
self.value = None
129+
else:
130+
self.value = cast_to_float(function(data_source=self.data_source))
116131

117132

118133
class CrossValidationReportMetric(Metric[CrossValidationReport]):
@@ -143,18 +158,15 @@ class CrossValidationReportMetric(Metric[CrossValidationReport]):
143158
accessor: ClassVar[str]
144159
aggregate: ClassVar[Literal["mean", "std"]]
145160

146-
@computed_field # type: ignore[prop-decorator]
147-
@cached_property
148-
def value(self) -> float | None:
149-
"""The value of the metric."""
161+
def compute(self) -> None:
162+
"""Compute the value of the metric."""
150163
try:
151164
function = cast(
152165
"Callable[..., DataFrame]",
153166
reduce(getattr, self.accessor.split("."), self.report),
154167
)
155168
except AttributeError:
156-
return None
157-
158-
dataframe = function(data_source=self.data_source, aggregate=self.aggregate)
159-
160-
return cast_to_float(dataframe.iloc[0, 0])
169+
self.value = None
170+
else:
171+
dataframe = function(data_source=self.data_source, aggregate=self.aggregate)
172+
self.value = cast_to_float(dataframe.iloc[0, 0])

skore-hub-project/src/skore_hub_project/metric/precision.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@
22

33
from __future__ import annotations
44

5-
from functools import cached_property
65
from typing import ClassVar, Literal
76

8-
from pydantic import computed_field
9-
107
from .metric import CrossValidationReportMetric, EstimatorReportMetric, cast_to_float
118

129

@@ -17,15 +14,16 @@ class Precision(EstimatorReportMetric): # noqa: D101
1714
greater_is_better: bool = True
1815
position: None = None
1916

20-
@computed_field # type: ignore[prop-decorator]
21-
@cached_property
22-
def value(self) -> float | None: # noqa: D102
17+
def compute(self) -> None:
18+
"""Compute the value of the metric."""
2319
try:
2420
function = self.report.metrics.precision
2521
except AttributeError:
26-
return None
27-
28-
return cast_to_float(function(data_source=self.data_source, average="macro"))
22+
self.value = None
23+
else:
24+
self.value = cast_to_float(
25+
function(data_source=self.data_source, average="macro")
26+
)
2927

3028

3129
class PrecisionTrain(Precision): # noqa: D101
@@ -43,19 +41,17 @@ class PrecisionMean(CrossValidationReportMetric): # noqa: D101
4341
greater_is_better: bool = True
4442
position: None = None
4543

46-
@computed_field # type: ignore[prop-decorator]
47-
@cached_property
48-
def value(self) -> float | None: # noqa: D102
44+
def compute(self) -> None:
45+
"""Compute the value of the metric."""
4946
try:
5047
function = self.report.metrics.precision
5148
except AttributeError:
52-
return None
53-
54-
dataframe = function(
55-
data_source=self.data_source, aggregate="mean", average="macro"
56-
)
57-
58-
return cast_to_float(dataframe.iloc[0, 0])
49+
self.value = None
50+
else:
51+
dataframe = function(
52+
data_source=self.data_source, aggregate="mean", average="macro"
53+
)
54+
self.value = cast_to_float(dataframe.iloc[0, 0])
5955

6056

6157
class PrecisionTrainMean(PrecisionMean): # noqa: D101
@@ -73,19 +69,17 @@ class PrecisionStd(CrossValidationReportMetric): # noqa: D101
7369
greater_is_better: bool = False
7470
position: None = None
7571

76-
@computed_field # type: ignore[prop-decorator]
77-
@cached_property
78-
def value(self) -> float | None: # noqa: D102
72+
def compute(self) -> None:
73+
"""Compute the value of the metric."""
7974
try:
8075
function = self.report.metrics.precision
8176
except AttributeError:
82-
return None
83-
84-
dataframe = function(
85-
data_source=self.data_source, aggregate="std", average="macro"
86-
)
87-
88-
return cast_to_float(dataframe.iloc[0, 0])
77+
self.value = None
78+
else:
79+
dataframe = function(
80+
data_source=self.data_source, aggregate="std", average="macro"
81+
)
82+
self.value = cast_to_float(dataframe.iloc[0, 0])
8983

9084

9185
class PrecisionTrainStd(PrecisionStd): # noqa: D101

skore-hub-project/src/skore_hub_project/metric/recall.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@
22

33
from __future__ import annotations
44

5-
from functools import cached_property
65
from typing import ClassVar, Literal
76

8-
from pydantic import computed_field
9-
107
from .metric import CrossValidationReportMetric, EstimatorReportMetric, cast_to_float
118

129

@@ -17,15 +14,16 @@ class Recall(EstimatorReportMetric): # noqa: D101
1714
greater_is_better: bool = True
1815
position: None = None
1916

20-
@computed_field # type: ignore[prop-decorator]
21-
@cached_property
22-
def value(self) -> float | None: # noqa: D102
17+
def compute(self) -> None:
18+
"""Compute the value of the metric."""
2319
try:
2420
function = self.report.metrics.recall
2521
except AttributeError:
26-
return None
27-
28-
return cast_to_float(function(data_source=self.data_source, average="macro"))
22+
self.value = None
23+
else:
24+
self.value = cast_to_float(
25+
function(data_source=self.data_source, average="macro")
26+
)
2927

3028

3129
class RecallTrain(Recall): # noqa: D101
@@ -44,19 +42,17 @@ class RecallMean(CrossValidationReportMetric): # noqa: D101
4442
greater_is_better: bool = True
4543
position: None = None
4644

47-
@computed_field # type: ignore[prop-decorator]
48-
@cached_property
49-
def value(self) -> float | None: # noqa: D102
45+
def compute(self) -> None:
46+
"""Compute the value of the metric."""
5047
try:
5148
function = self.report.metrics.recall
5249
except AttributeError:
53-
return None
54-
55-
dataframe = function(
56-
data_source=self.data_source, aggregate=self.aggregate, average="macro"
57-
)
58-
59-
return cast_to_float(dataframe.iloc[0, 0])
50+
self.value = None
51+
else:
52+
dataframe = function(
53+
data_source=self.data_source, aggregate=self.aggregate, average="macro"
54+
)
55+
self.value = cast_to_float(dataframe.iloc[0, 0])
6056

6157

6258
class RecallTrainMean(RecallMean): # noqa: D101
@@ -75,19 +71,17 @@ class RecallStd(CrossValidationReportMetric): # noqa: D101
7571
greater_is_better: bool = False
7672
position: None = None
7773

78-
@computed_field # type: ignore[prop-decorator]
79-
@cached_property
80-
def value(self) -> float | None: # noqa: D102
74+
def compute(self) -> None:
75+
"""Compute the value of the metric."""
8176
try:
8277
function = self.report.metrics.recall
8378
except AttributeError:
84-
return None
85-
86-
dataframe = function(
87-
data_source=self.data_source, aggregate=self.aggregate, average="macro"
88-
)
89-
90-
return cast_to_float(dataframe.iloc[0, 0])
79+
self.value = None
80+
else:
81+
dataframe = function(
82+
data_source=self.data_source, aggregate=self.aggregate, average="macro"
83+
)
84+
self.value = cast_to_float(dataframe.iloc[0, 0])
9185

9286

9387
class RecallTrainStd(RecallStd): # noqa: D101

skore-hub-project/src/skore_hub_project/metric/timing.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@
22

33
from __future__ import annotations
44

5-
from functools import cached_property
65
from typing import ClassVar, Literal
76

8-
from pydantic import computed_field
9-
107
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
118

129
from .metric import Metric, cast_to_float
@@ -19,13 +16,12 @@ class FitTime(Metric[EstimatorReport]): # noqa: D101
1916
position: int = 1
2017
data_source: None = None
2118

22-
@computed_field # type: ignore[prop-decorator]
23-
@cached_property
24-
def value(self) -> float | None: # noqa: D102
19+
def compute(self) -> None:
20+
"""Compute the value of the metric."""
2521
timings = self.report.metrics.timings()
2622
fit_time = timings.get("fit_time")
2723

28-
return cast_to_float(fit_time)
24+
self.value = cast_to_float(fit_time)
2925

3026

3127
class FitTimeAggregate(Metric[CrossValidationReport]): # noqa: D101
@@ -40,17 +36,16 @@ class FitTimeAggregate(Metric[CrossValidationReport]): # noqa: D101
4036
greater_is_better: bool = False
4137
data_source: None = None
4238

43-
@computed_field # type: ignore[prop-decorator]
44-
@cached_property
45-
def value(self) -> float | None: # noqa: D102
39+
def compute(self) -> None:
40+
"""Compute the value of the metric."""
4641
timings = self.report.metrics.timings(aggregate=self.aggregate)
4742

4843
try:
4944
fit_times = timings.loc["Fit time (s)"]
5045
except KeyError:
51-
return None
52-
53-
return cast_to_float(fit_times.iloc[0])
46+
self.value = None
47+
else:
48+
self.value = cast_to_float(fit_times.iloc[0])
5449

5550

5651
class FitTimeMean(FitTimeAggregate): # noqa: D101
@@ -73,13 +68,12 @@ class PredictTime(Metric[EstimatorReport]): # noqa: D101
7368
greater_is_better: bool = False
7469
position: int = 2
7570

76-
@computed_field # type: ignore[prop-decorator]
77-
@cached_property
78-
def value(self) -> float | None: # noqa: D102
71+
def compute(self) -> None:
72+
"""Compute the value of the metric."""
7973
timings = self.report.metrics.timings()
8074
predict_time = timings.get(f"predict_time_{self.data_source}")
8175

82-
return cast_to_float(predict_time)
76+
self.value = cast_to_float(predict_time)
8377

8478

8579
class PredictTimeTrain(PredictTime): # noqa: D101
@@ -101,17 +95,16 @@ class PredictTimeAggregate(Metric[CrossValidationReport]): # noqa: D101
10195
aggregate: ClassVar[Literal["mean", "std"]]
10296
greater_is_better: bool = False
10397

104-
@computed_field # type: ignore[prop-decorator]
105-
@cached_property
106-
def value(self) -> float | None: # noqa: D102
98+
def compute(self) -> None:
99+
"""Compute the value of the metric."""
107100
timings = self.report.metrics.timings(aggregate=self.aggregate)
108101

109102
try:
110103
predict_times = timings.loc[f"Predict time {self.data_source} (s)"]
111104
except KeyError:
112-
return None
113-
114-
return cast_to_float(predict_times.iloc[0])
105+
self.value = None
106+
else:
107+
self.value = cast_to_float(predict_times.iloc[0])
115108

116109

117110
class PredictTimeMean(PredictTimeAggregate): # noqa: D101

skore-hub-project/src/skore_hub_project/protocol.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class EstimatorReport(Protocol):
2525
"""Protocol equivalent to ``skore.EstimatorReport``."""
2626

2727
_hash: int
28+
cache_predictions: Any
2829
clear_cache: Any
2930
_cache: Any
3031
metrics: Any
@@ -44,6 +45,7 @@ class CrossValidationReport(Protocol):
4445
"""Protocol equivalent to ``skore.CrossValidationReport``."""
4546

4647
_hash: int
48+
cache_predictions: Any
4749
clear_cache: Any
4850
_cache: Any
4951
metrics: Any

0 commit comments

Comments
 (0)