@@ -1648,10 +1648,18 @@ def _get_display(
16481648 data_source : DataSource ,
16491649 response_method : str | list [str ] | tuple [str , ...],
16501650 display_class : type [
1651- RocCurveDisplay | PrecisionRecallCurveDisplay | PredictionErrorDisplay
1651+ RocCurveDisplay
1652+ | PrecisionRecallCurveDisplay
1653+ | PredictionErrorDisplay
1654+ | ConfusionMatrixDisplay
16521655 ],
16531656 display_kwargs : dict [str , Any ],
1654- ) -> RocCurveDisplay | PrecisionRecallCurveDisplay | PredictionErrorDisplay :
1657+ ) -> (
1658+ RocCurveDisplay
1659+ | PrecisionRecallCurveDisplay
1660+ | PredictionErrorDisplay
1661+ | ConfusionMatrixDisplay
1662+ ):
16551663 """Get the display from the cache or compute it.
16561664
16571665 Parameters
@@ -1676,7 +1684,7 @@ def _get_display(
16761684 The display class.
16771685
16781686 display_kwargs : dict
1679- The display kwargs used by `display_class._from_predictions `.
1687+ The display kwargs used by `display_class._compute_data_for_display `.
16801688
16811689 Returns
16821690 -------
@@ -1692,7 +1700,11 @@ def _get_display(
16921700 cache_key = None
16931701 else :
16941702 cache_key_parts : list [Any ] = [self ._parent ._hash , display_class .__name__ ]
1695- cache_key_parts .extend (display_kwargs .values ())
1703+ for kwarg in display_kwargs .values ():
1704+ # NOTE: We cannot use lists in cache keys because they are not hashable
1705+ if isinstance (kwarg , list ):
1706+ kwarg = tuple (kwarg )
1707+ cache_key_parts .append (kwarg )
16961708 if data_source_hash is not None :
16971709 cache_key_parts .append (data_source_hash )
16981710 else :
@@ -1984,11 +1996,8 @@ def confusion_matrix(
19841996 data_source : DataSource = "test" ,
19851997 X : ArrayLike | None = None ,
19861998 y : ArrayLike | None = None ,
1987- sample_weight : ArrayLike | None = None ,
19881999 display_labels : list | None = None ,
1989- include_values : bool = True ,
1990- normalize : Literal ["true" , "pred" , "both" ] | None = None ,
1991- values_format : str | None = None ,
2000+ normalize : Literal ["true" , "pred" , "all" ] | None = None ,
19922001 ) -> ConfusionMatrixDisplay :
19932002 """Plot the confusion matrix.
19942003
@@ -2012,25 +2021,15 @@ def confusion_matrix(
20122021 New target on which to compute the metric. By default, we use the target
20132022 provided when creating the report.
20142023
2015- sample_weight : array-like of shape (n_samples,), default=None
2016- Sample weights.
2017-
20182024 display_labels : list of str, default=None
20192025 Display labels for plot. If None, display labels are set from 0 to
20202026 ``n_classes - 1``.
20212027
2022- include_values : bool, default=True
2023- Includes values in confusion matrix.
2024-
20252028 normalize : {'true', 'pred', 'all'}, default=None
20262029 Normalizes confusion matrix over the true (rows), predicted (columns)
20272030 conditions or all the population. If None, confusion matrix will not be
20282031 normalized.
20292032
2030- values_format : str, default=None
2031- Format specification for values in confusion matrix. If None, the format
2032- specification is 'd' or '.2g' whichever is shorter.
2033-
20342033 Returns
20352034 -------
20362035 display : :class:`~skore._sklearn._plot.ConfusionMatrixDisplay`
@@ -2048,21 +2047,16 @@ def confusion_matrix(
20482047 >>> report = EstimatorReport(classifier, **split_data)
20492048 >>> report.metrics.confusion_matrix()
20502049 """
2051- X , y , _ = self ._get_X_y_and_data_source_hash (data_source = data_source , X = X , y = y )
2052-
2053- y_pred = self ._parent .get_predictions (
2054- data_source = data_source ,
2055- response_method = "predict" ,
2056- X = X ,
2057- pos_label = None ,
2058- )
2059-
2060- return ConfusionMatrixDisplay .from_predictions (
2061- y_true = y ,
2062- y_pred = y_pred ,
2063- sample_weight = sample_weight ,
2064- display_labels = display_labels ,
2065- include_values = include_values ,
2066- normalize = normalize ,
2067- values_format = values_format ,
2050+ display_kwargs = {"display_labels" : display_labels , "normalize" : normalize }
2051+ display = cast (
2052+ ConfusionMatrixDisplay ,
2053+ self ._get_display (
2054+ X = X ,
2055+ y = y ,
2056+ data_source = data_source ,
2057+ response_method = "predict" ,
2058+ display_class = ConfusionMatrixDisplay ,
2059+ display_kwargs = display_kwargs ,
2060+ ),
20682061 )
2062+ return display
0 commit comments