@@ -2057,6 +2057,7 @@ def confusion_matrix(
20572057 data_source : DataSource = "test" ,
20582058 X : ArrayLike | None = None ,
20592059 y : ArrayLike | None = None ,
2060+ threshold : bool = False ,
20602061 ) -> ConfusionMatrixDisplay :
20612062 """Plot the confusion matrix.
20622063
@@ -2080,6 +2081,14 @@ def confusion_matrix(
20802081 New target on which to compute the metric. By default, we use the target
20812082 provided when creating the report.
20822083
2084+ threshold : bool, default=False
2085+ Whether to enable decision threshold support for binary classification.
2086+ When True, the display will precompute confusion matrices at all possible
2087+ decision thresholds, allowing you to specify a threshold in `.plot()` or
2088+ `.frame()` methods. This is only applicable for binary classification and
2089+ requires the estimator to have `predict_proba` or `decision_function`
2090+ methods.
2091+
20832092 Returns
20842093 -------
20852094 display : :class:`~skore._sklearn._plot.ConfusionMatrixDisplay`
@@ -2095,16 +2104,32 @@ def confusion_matrix(
20952104 >>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True)
20962105 >>> classifier = LogisticRegression(max_iter=10_000)
20972106 >>> report = EstimatorReport(classifier, **split_data)
2098- >>> report.metrics.confusion_matrix()
2107+ >>> display = report.metrics.confusion_matrix()
2108+ >>> display.plot()
2109+
2110+ With decision threshold support for binary classification:
2111+
2112+ >>> display = report.metrics.confusion_matrix(threshold=True)
2113+ >>> display.plot(threshold=0.7)
20992114 """
2100- display_kwargs = {"display_labels" : self ._parent .estimator_ .classes_ .tolist ()}
2115+ display_kwargs = {
2116+ "display_labels" : self ._parent .estimator_ .classes_ .tolist (),
2117+ "pos_label" : self ._parent .pos_label ,
2118+ "threshold" : threshold ,
2119+ }
2120+
2121+ response_method : str | list [str ] | tuple [str , ...]
2122+ if threshold and self ._parent ._ml_task == "binary-classification" :
2123+ response_method = ("predict_proba" , "decision_function" )
2124+ else :
2125+ response_method = "predict"
21012126 display = cast (
21022127 ConfusionMatrixDisplay ,
21032128 self ._get_display (
21042129 X = X ,
21052130 y = y ,
21062131 data_source = data_source ,
2107- response_method = "predict" ,
2132+ response_method = response_method ,
21082133 display_class = ConfusionMatrixDisplay ,
21092134 display_kwargs = display_kwargs ,
21102135 ),
0 commit comments