Skip to content

Commit d232d29

Browse files
committed
Merge branch 'main' into followup_1759
2 parents e2c7e2c + b0b2c9f commit d232d29

File tree

16 files changed

+442
-110
lines changed

16 files changed

+442
-110
lines changed

skore/src/skore/_sklearn/_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from skore._externals._sklearn_compat import is_clusterer
1616
from skore._sklearn.types import PositiveLabel
17+
from skore._utils._cache import Cache
1718
from skore._utils._measure_time import MeasureTime
1819

1920

@@ -120,7 +121,7 @@ class _BaseReport(_HelpMixin):
120121
_X_test: ArrayLike | None
121122
_y_train: ArrayLike | None
122123
_y_test: ArrayLike | None
123-
_cache: dict[tuple[Any, ...], Any]
124+
_cache: Cache
124125
estimator_: BaseEstimator
125126

126127
def _get_help_panel_title(self) -> str:
@@ -323,7 +324,7 @@ def _get_X_y_and_data_source_hash(
323324

324325
def _get_cached_response_values(
325326
*,
326-
cache: dict[tuple[Any, ...], Any],
327+
cache: Cache,
327328
estimator_hash: int,
328329
estimator: BaseEstimator,
329330
X: ArrayLike | None,

skore/src/skore/_sklearn/_comparison/metrics_accessor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,7 @@ def _get_display(
12241224
y_true.append(
12251225
YPlotData(
12261226
estimator_name=report_name,
1227+
data_source=data_source,
12271228
split=None,
12281229
y=report_y,
12291230
)
@@ -1245,6 +1246,7 @@ def _get_display(
12451246
y_pred.append(
12461247
YPlotData(
12471248
estimator_name=report_name,
1249+
data_source=data_source,
12481250
split=None,
12491251
y=value,
12501252
)
@@ -1278,6 +1280,7 @@ def _get_display(
12781280
y_true.append(
12791281
YPlotData(
12801282
estimator_name=report_name,
1283+
data_source=data_source,
12811284
split=split,
12821285
y=report_y,
12831286
)
@@ -1300,6 +1303,7 @@ def _get_display(
13001303
y_pred.append(
13011304
YPlotData(
13021305
estimator_name=report_name,
1306+
data_source=data_source,
13031307
split=split,
13041308
y=value,
13051309
)

skore/src/skore/_sklearn/_comparison/report.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from skore._sklearn._cross_validation.report import CrossValidationReport
1515
from skore._sklearn._estimator.report import EstimatorReport
1616
from skore._sklearn.types import _DEFAULT, PositiveLabel
17+
from skore._utils._cache import Cache
1718
from skore._utils._progress_bar import progress_decorator
1819

1920
if TYPE_CHECKING:
@@ -251,7 +252,7 @@ def __init__(
251252
self._hash = self._rng.integers(
252253
low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max
253254
)
254-
self._cache: dict[tuple[Any, ...], Any] = {}
255+
self._cache = Cache()
255256
self._ml_task = next(iter(self.reports_.values()))._ml_task # type: ignore
256257

257258
def clear_cache(self) -> None:
@@ -277,7 +278,8 @@ def clear_cache(self) -> None:
277278
"""
278279
for report in self.reports_.values():
279280
report.clear_cache()
280-
self._cache = {}
281+
282+
self._cache = Cache()
281283

282284
@progress_decorator(description="Estimator predictions")
283285
def cache_predictions(

skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1129,7 +1129,7 @@ def _get_display(
11291129
total_estimators = len(self._parent.estimator_reports_)
11301130
progress.update(main_task, total=total_estimators)
11311131

1132-
if cache_key in self._parent._cache:
1132+
if cache_key and cache_key in self._parent._cache:
11331133
display = self._parent._cache[cache_key]
11341134
else:
11351135
y_true: list[YPlotData] = []
@@ -1145,6 +1145,7 @@ def _get_display(
11451145
y_true.append(
11461146
YPlotData(
11471147
estimator_name=self._parent.estimator_name_,
1148+
data_source=data_source,
11481149
split=report_idx,
11491150
y=cast(ArrayLike, y),
11501151
)
@@ -1166,6 +1167,7 @@ def _get_display(
11661167
y_pred.append(
11671168
YPlotData(
11681169
estimator_name=self._parent.estimator_name_,
1170+
data_source=data_source,
11691171
split=report_idx,
11701172
y=value,
11711173
)

skore/src/skore/_sklearn/_cross_validation/report.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from skore._sklearn._estimator.report import EstimatorReport
1818
from skore._sklearn.find_ml_task import _find_ml_task
1919
from skore._sklearn.types import _DEFAULT, PositiveLabel, SKLearnCrossValidator
20+
from skore._utils._cache import Cache
2021
from skore._utils._fixes import _validate_joblib_parallel_params
2122
from skore._utils._parallel import Parallel, delayed
2223
from skore._utils._progress_bar import progress_decorator
@@ -177,7 +178,7 @@ def __init__(
177178
self._hash = self._rng.integers(
178179
low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max
179180
)
180-
self._cache: dict[tuple[Any, ...], Any] = {}
181+
self._cache = Cache()
181182
self._ml_task = _find_ml_task(
182183
y, estimator=self.estimator_reports_[0]._estimator
183184
)
@@ -296,7 +297,8 @@ def clear_cache(self) -> None:
296297
"""
297298
for report in self.estimator_reports_:
298299
report.clear_cache()
299-
self._cache = {}
300+
301+
self._cache = Cache()
300302

301303
@progress_decorator(description="Cross-validation predictions")
302304
def cache_predictions(

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 102 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from skore._sklearn.types import (
3030
_DEFAULT,
31+
DataSource,
3132
PositiveLabel,
3233
Scoring,
3334
YPlotData,
@@ -40,8 +41,6 @@
4041
)
4142
from skore._utils._index import flatten_multi_index
4243

43-
DataSource = Literal["test", "train", "X_y"]
44-
4544

4645
class _MetricsAccessor(
4746
_BaseMetricsAccessor, _BaseAccessor["EstimatorReport"], DirNamesMixin
@@ -1645,7 +1644,7 @@ def _get_display(
16451644
*,
16461645
X: ArrayLike | None,
16471646
y: ArrayLike | None,
1648-
data_source: DataSource,
1647+
data_source: DataSource | Literal["both"],
16491648
response_method: str | list[str] | tuple[str, ...],
16501649
display_class: type[
16511650
RocCurveDisplay
@@ -1670,12 +1669,13 @@ def _get_display(
16701669
y : array-like of shape (n_samples,)
16711670
The target.
16721671
1673-
data_source : {"test", "train", "X_y"}, default="test"
1672+
data_source : {"test", "train", "X_y", "both"}, default="test"
16741673
The data source to use.
16751674
16761675
- "test" : use the test set provided when creating the report.
16771676
- "train" : use the train set provided when creating the report.
16781677
- "X_y" : use the provided `X` and `y` to compute the metric.
1678+
- "both" : use both the train set and the test set to compute the metric.
16791679
16801680
response_method : str, list of str or tuple of str
16811681
The response method.
@@ -1691,11 +1691,98 @@ def _get_display(
16911691
display : display_class
16921692
The display.
16931693
"""
1694-
X, y, data_source_hash = self._get_X_y_and_data_source_hash(
1695-
data_source=data_source, X=X, y=y
1696-
)
1697-
assert y is not None, "y must be provided"
1694+
pos_label = display_kwargs.get("pos_label")
1695+
1696+
def get_ys(
1697+
*,
1698+
X,
1699+
y_true,
1700+
data_source,
1701+
data_source_hash,
1702+
cache=self._parent._cache,
1703+
estimator_hash=int(self._parent._hash),
1704+
estimator=self._parent.estimator_,
1705+
estimator_name=self._parent.estimator_name_,
1706+
response_method=response_method,
1707+
pos_label=pos_label,
1708+
) -> tuple[list[YPlotData], list[YPlotData]]:
1709+
"""Get predictions and format y_true and y_pred using YPlotData."""
1710+
results = _get_cached_response_values(
1711+
cache=cache,
1712+
estimator_hash=estimator_hash,
1713+
estimator=estimator,
1714+
X=X,
1715+
response_method=response_method,
1716+
pos_label=pos_label,
1717+
data_source=data_source,
1718+
data_source_hash=data_source_hash,
1719+
)
1720+
for key, value, is_cached in results:
1721+
key = cast(tuple[Any, ...], key)
1722+
if not is_cached:
1723+
cache[key] = value
1724+
if key[-1] != "predict_time":
1725+
y_pred = value
1726+
1727+
y_true = [
1728+
YPlotData(
1729+
estimator_name=estimator_name,
1730+
data_source=data_source,
1731+
split=None,
1732+
y=y_true,
1733+
)
1734+
]
1735+
y_pred = [
1736+
YPlotData(
1737+
estimator_name=estimator_name,
1738+
data_source=data_source,
1739+
split=None,
1740+
y=y_pred,
1741+
)
1742+
]
1743+
return y_true, y_pred
1744+
1745+
if data_source == "both":
1746+
X_train, y_train, data_source_hash_train = (
1747+
self._get_X_y_and_data_source_hash(data_source="train", X=X, y=y)
1748+
)
1749+
y_train_true, y_train_pred = get_ys(
1750+
X=X_train,
1751+
y_true=y_train,
1752+
data_source="train",
1753+
data_source_hash=data_source_hash_train,
1754+
)
1755+
assert y_train_true is not None, "y must be provided"
1756+
1757+
X_test, y_test, data_source_hash_test = self._get_X_y_and_data_source_hash(
1758+
data_source="test", X=X, y=y
1759+
)
1760+
y_test_true, y_test_pred = get_ys(
1761+
X=X_test,
1762+
y_true=y_test,
1763+
data_source="test",
1764+
data_source_hash=data_source_hash_test,
1765+
)
1766+
assert y_test_true is not None, "y must be provided"
16981767

1768+
y_true = y_train_true + y_test_true
1769+
y_pred = y_train_pred + y_test_pred
1770+
data_source_hash = None
1771+
else:
1772+
X, y_data_source, data_source_hash = self._get_X_y_and_data_source_hash(
1773+
data_source=data_source, X=X, y=y
1774+
)
1775+
1776+
y_true, y_pred = get_ys(
1777+
X=X,
1778+
y_true=y_data_source,
1779+
data_source=data_source,
1780+
data_source_hash=data_source_hash,
1781+
)
1782+
1783+
assert y_true is not None, "y must be provided"
1784+
1785+
# Compute cache key
16991786
if "seed" in display_kwargs and display_kwargs["seed"] is None:
17001787
cache_key = None
17011788
else:
@@ -1711,40 +1798,12 @@ def _get_display(
17111798
cache_key_parts.append(data_source)
17121799
cache_key = tuple(cache_key_parts)
17131800

1714-
if cache_key in self._parent._cache:
1801+
if cache_key and cache_key in self._parent._cache:
17151802
display = self._parent._cache[cache_key]
17161803
else:
1717-
results = _get_cached_response_values(
1718-
cache=self._parent._cache,
1719-
estimator_hash=int(self._parent._hash),
1720-
estimator=self._parent.estimator_,
1721-
X=X,
1722-
response_method=response_method,
1723-
pos_label=display_kwargs.get("pos_label"),
1724-
data_source=data_source,
1725-
data_source_hash=data_source_hash,
1726-
)
1727-
for key, value, is_cached in results:
1728-
if not is_cached:
1729-
self._parent._cache[cast(tuple[Any, ...], key)] = value
1730-
if cast(tuple[Any, ...], key)[-1] != "predict_time":
1731-
y_pred = value
1732-
17331804
display = display_class._compute_data_for_display(
1734-
y_true=[
1735-
YPlotData(
1736-
estimator_name=self._parent.estimator_name_,
1737-
split=None,
1738-
y=y,
1739-
)
1740-
],
1741-
y_pred=[
1742-
YPlotData(
1743-
estimator_name=self._parent.estimator_name_,
1744-
split=None,
1745-
y=y_pred,
1746-
)
1747-
],
1805+
y_true=y_true,
1806+
y_pred=y_pred,
17481807
report_type="estimator",
17491808
estimators=[self._parent.estimator_],
17501809
ml_task=self._parent._ml_task,
@@ -1767,7 +1826,7 @@ def _get_display(
17671826
def roc(
17681827
self,
17691828
*,
1770-
data_source: DataSource = "test",
1829+
data_source: DataSource | Literal["both"] = "test",
17711830
X: ArrayLike | None = None,
17721831
y: ArrayLike | None = None,
17731832
pos_label: PositiveLabel | None = _DEFAULT,
@@ -1776,12 +1835,14 @@ def roc(
17761835
17771836
Parameters
17781837
----------
1779-
data_source : {"test", "train", "X_y"}, default="test"
1838+
data_source : {"test", "train", "X_y", "both"}, default="test"
17801839
The data source to use.
17811840
17821841
- "test" : use the test set provided when creating the report.
17831842
- "train" : use the train set provided when creating the report.
17841843
- "X_y" : use the provided `X` and `y` to compute the metric.
1844+
- "both" : use both the train and test sets to compute the metrics and
1845+
present them side-by-side.
17851846
17861847
X : array-like of shape (n_samples, n_features), default=None
17871848
New data on which to compute the metric. By default, we use the validation

skore/src/skore/_sklearn/_estimator/report.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from skore._sklearn._base import _BaseReport, _get_cached_response_values
1919
from skore._sklearn.find_ml_task import _find_ml_task
2020
from skore._sklearn.types import _DEFAULT, PositiveLabel
21+
from skore._utils._cache import Cache
2122
from skore._utils._fixes import _validate_joblib_parallel_params
2223
from skore._utils._measure_time import MeasureTime
2324
from skore._utils._parallel import Parallel, delayed
@@ -186,7 +187,7 @@ def _initialize_state(self) -> None:
186187
self._hash = self._rng.integers(
187188
low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max
188189
)
189-
self._cache: dict[tuple[Any, ...], Any] = {}
190+
self._cache = Cache()
190191
self._ml_task = _find_ml_task(self._y_test, estimator=self._estimator)
191192

192193
# NOTE:
@@ -214,7 +215,7 @@ def clear_cache(self) -> None:
214215
>>> report._cache
215216
{}
216217
"""
217-
self._cache = {}
218+
self._cache = Cache()
218219

219220
@progress_decorator(description="Caching predictions")
220221
def cache_predictions(

0 commit comments

Comments
 (0)