Skip to content

Commit cfb2905

Browse files
committed
init
1 parent d232d29 commit cfb2905

File tree

4 files changed

+216
-42
lines changed

4 files changed

+216
-42
lines changed

examples/model_evaluation/plot_estimator_report.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,14 +397,12 @@ def operational_decision_cost(y_true, y_pred, amount):
397397
# %%
398398
# We can normalize the confusion matrix to get percentages instead of raw counts.
399399
# Here we normalize by true labels (rows):
400-
cm_display = report.metrics.confusion_matrix()
401400
cm_display.plot(normalize="true")
402401
plt.show()
403402

404403
# %%
405404
# More plotting options are available via ``heatmap_kwargs``, which are passed to
406405
# seaborn's heatmap. For example, we can customize the colormap and number format:
407-
cm_display = report.metrics.confusion_matrix()
408406
cm_display.plot(heatmap_kwargs={"cmap": "Greens", "fmt": ".2e"})
409407
plt.show()
410408

skore/src/skore/_externals/_sklearn_compat.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,3 +869,20 @@ def parametrize_with_checks(
869869
check_X_y, # noqa: F401
870870
validate_data, # noqa: F401
871871
)
872+
873+
874+
########################################################################################
875+
# Upgrading for scikit-learn 1.8
876+
########################################################################################
877+
878+
if sklearn_version < parse_version("1.8"):
879+
880+
def confusion_matrix_at_thresholds(y_true, y_pred, pos_label):
881+
from sklearn.metrics._ranking import _binary_clf_curve
882+
883+
fps, tps, thresholds = _binary_clf_curve(y_true, y_pred, pos_label=pos_label)
884+
fns = tps[-1] - tps
885+
tns = fps[-1] - fps
886+
return tns, fps, fns, tps, thresholds
887+
else:
888+
from sklearn.metrics._ranking import confusion_matrix_at_thresholds # noqa: F401

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2057,6 +2057,7 @@ def confusion_matrix(
20572057
data_source: DataSource = "test",
20582058
X: ArrayLike | None = None,
20592059
y: ArrayLike | None = None,
2060+
threshold: bool = False,
20602061
) -> ConfusionMatrixDisplay:
20612062
"""Plot the confusion matrix.
20622063
@@ -2080,6 +2081,14 @@ def confusion_matrix(
20802081
New target on which to compute the metric. By default, we use the target
20812082
provided when creating the report.
20822083
2084+
threshold : bool, default=False
2085+
Whether to enable decision threshold support for binary classification.
2086+
When True, the display will precompute confusion matrices at all possible
2087+
decision thresholds, allowing you to specify a threshold in `.plot()` or
2088+
`.frame()` methods. This is only applicable for binary classification and
2089+
requires the estimator to have `predict_proba` or `decision_function`
2090+
methods.
2091+
20832092
Returns
20842093
-------
20852094
display : :class:`~skore._sklearn._plot.ConfusionMatrixDisplay`
@@ -2095,16 +2104,32 @@ def confusion_matrix(
20952104
>>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True)
20962105
>>> classifier = LogisticRegression(max_iter=10_000)
20972106
>>> report = EstimatorReport(classifier, **split_data)
2098-
>>> report.metrics.confusion_matrix()
2107+
>>> display = report.metrics.confusion_matrix()
2108+
>>> display.plot()
2109+
2110+
With decision threshold support for binary classification:
2111+
2112+
>>> display = report.metrics.confusion_matrix(threshold=True)
2113+
>>> display.plot(threshold=0.7)
20992114
"""
2100-
display_kwargs = {"display_labels": self._parent.estimator_.classes_.tolist()}
2115+
display_kwargs = {
2116+
"display_labels": self._parent.estimator_.classes_.tolist(),
2117+
"pos_label": self._parent.pos_label,
2118+
"threshold": threshold,
2119+
}
2120+
2121+
response_method: str | list[str] | tuple[str, ...]
2122+
if threshold and self._parent._ml_task == "binary-classification":
2123+
response_method = ("predict_proba", "decision_function")
2124+
else:
2125+
response_method = "predict"
21012126
display = cast(
21022127
ConfusionMatrixDisplay,
21032128
self._get_display(
21042129
X=X,
21052130
y=y,
21062131
data_source=data_source,
2107-
response_method="predict",
2132+
response_method=response_method,
21082133
display_class=ConfusionMatrixDisplay,
21092134
display_kwargs=display_kwargs,
21102135
),

0 commit comments

Comments
 (0)