Skip to content

Commit c59c4ef

Browse files
committed
Merge branch 'main' into skore-make-cache-thread-safe
2 parents b36174b + 83bdcdc commit c59c4ef

File tree

6 files changed

+185
-133
lines changed

6 files changed

+185
-133
lines changed

skore/src/skore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from skore._externals._sklearn_compat import parse_version
1212
from skore._sklearn import (
1313
ComparisonReport,
14+
ConfusionMatrixDisplay,
1415
CrossValidationReport,
1516
EstimatorReport,
1617
MetricsSummaryDisplay,
@@ -31,6 +32,7 @@
3132
__all__ = [
3233
"CrossValidationReport",
3334
"ComparisonReport",
35+
"ConfusionMatrixDisplay",
3436
"Display",
3537
"EstimatorReport",
3638
"PrecisionRecallCurveDisplay",

skore/src/skore/_sklearn/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from skore._sklearn._cross_validation import CrossValidationReport
55
from skore._sklearn._estimator import EstimatorReport
66
from skore._sklearn._plot import (
7+
ConfusionMatrixDisplay,
78
MetricsSummaryDisplay,
89
PrecisionRecallCurveDisplay,
910
PredictionErrorDisplay,
@@ -14,6 +15,7 @@
1415

1516
__all__ = [
1617
"ComparisonReport",
18+
"ConfusionMatrixDisplay",
1719
"CrossValidationReport",
1820
"EstimatorReport",
1921
"PrecisionRecallCurveDisplay",

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,10 +1648,18 @@ def _get_display(
16481648
data_source: DataSource,
16491649
response_method: str | list[str] | tuple[str, ...],
16501650
display_class: type[
1651-
RocCurveDisplay | PrecisionRecallCurveDisplay | PredictionErrorDisplay
1651+
RocCurveDisplay
1652+
| PrecisionRecallCurveDisplay
1653+
| PredictionErrorDisplay
1654+
| ConfusionMatrixDisplay
16521655
],
16531656
display_kwargs: dict[str, Any],
1654-
) -> RocCurveDisplay | PrecisionRecallCurveDisplay | PredictionErrorDisplay:
1657+
) -> (
1658+
RocCurveDisplay
1659+
| PrecisionRecallCurveDisplay
1660+
| PredictionErrorDisplay
1661+
| ConfusionMatrixDisplay
1662+
):
16551663
"""Get the display from the cache or compute it.
16561664
16571665
Parameters
@@ -1676,7 +1684,7 @@ def _get_display(
16761684
The display class.
16771685
16781686
display_kwargs : dict
1679-
The display kwargs used by `display_class._from_predictions`.
1687+
The display kwargs used by `display_class._compute_data_for_display`.
16801688
16811689
Returns
16821690
-------
@@ -1692,7 +1700,11 @@ def _get_display(
16921700
cache_key = None
16931701
else:
16941702
cache_key_parts: list[Any] = [self._parent._hash, display_class.__name__]
1695-
cache_key_parts.extend(display_kwargs.values())
1703+
for kwarg in display_kwargs.values():
1704+
# NOTE: We cannot use lists in cache keys because they are not hashable
1705+
if isinstance(kwarg, list):
1706+
kwarg = tuple(kwarg)
1707+
cache_key_parts.append(kwarg)
16961708
if data_source_hash is not None:
16971709
cache_key_parts.append(data_source_hash)
16981710
else:
@@ -1984,11 +1996,8 @@ def confusion_matrix(
19841996
data_source: DataSource = "test",
19851997
X: ArrayLike | None = None,
19861998
y: ArrayLike | None = None,
1987-
sample_weight: ArrayLike | None = None,
19881999
display_labels: list | None = None,
1989-
include_values: bool = True,
1990-
normalize: Literal["true", "pred", "both"] | None = None,
1991-
values_format: str | None = None,
2000+
normalize: Literal["true", "pred", "all"] | None = None,
19922001
) -> ConfusionMatrixDisplay:
19932002
"""Plot the confusion matrix.
19942003
@@ -2012,25 +2021,15 @@ def confusion_matrix(
20122021
New target on which to compute the metric. By default, we use the target
20132022
provided when creating the report.
20142023
2015-
sample_weight : array-like of shape (n_samples,), default=None
2016-
Sample weights.
2017-
20182024
display_labels : list of str, default=None
20192025
Display labels for plot. If None, display labels are set from 0 to
20202026
``n_classes - 1``.
20212027
2022-
include_values : bool, default=True
2023-
Includes values in confusion matrix.
2024-
20252028
normalize : {'true', 'pred', 'all'}, default=None
20262029
Normalizes confusion matrix over the true (rows), predicted (columns)
20272030
conditions or all the population. If None, confusion matrix will not be
20282031
normalized.
20292032
2030-
values_format : str, default=None
2031-
Format specification for values in confusion matrix. If None, the format
2032-
specification is 'd' or '.2g' whichever is shorter.
2033-
20342033
Returns
20352034
-------
20362035
display : :class:`~skore._sklearn._plot.ConfusionMatrixDisplay`
@@ -2048,21 +2047,16 @@ def confusion_matrix(
20482047
>>> report = EstimatorReport(classifier, **split_data)
20492048
>>> report.metrics.confusion_matrix()
20502049
"""
2051-
X, y, _ = self._get_X_y_and_data_source_hash(data_source=data_source, X=X, y=y)
2052-
2053-
y_pred = self._parent.get_predictions(
2054-
data_source=data_source,
2055-
response_method="predict",
2056-
X=X,
2057-
pos_label=None,
2058-
)
2059-
2060-
return ConfusionMatrixDisplay.from_predictions(
2061-
y_true=y,
2062-
y_pred=y_pred,
2063-
sample_weight=sample_weight,
2064-
display_labels=display_labels,
2065-
include_values=include_values,
2066-
normalize=normalize,
2067-
values_format=values_format,
2050+
display_kwargs = {"display_labels": display_labels, "normalize": normalize}
2051+
display = cast(
2052+
ConfusionMatrixDisplay,
2053+
self._get_display(
2054+
X=X,
2055+
y=y,
2056+
data_source=data_source,
2057+
response_method="predict",
2058+
display_class=ConfusionMatrixDisplay,
2059+
display_kwargs=display_kwargs,
2060+
),
20682061
)
2062+
return display

0 commit comments

Comments
 (0)