Skip to content

Commit d0e5f1f

Browse files
committed
iter
1 parent cfb2905 commit d0e5f1f

File tree

3 files changed

+395
-12
lines changed

3 files changed

+395
-12
lines changed

examples/model_evaluation/plot_estimator_report.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,41 @@ def operational_decision_cost(y_true, y_pred, amount):
412412
cm_frame = cm_display.frame()
413413
cm_frame
414414

415+
# %%
416+
# Decision threshold support
417+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
418+
#
419+
# For binary classification, the confusion matrix can be computed at different
420+
# decision thresholds. This is useful for understanding how the model's predictions
421+
# change as the decision threshold varies.
422+
423+
# %%
424+
# First, we create a display with threshold support enabled:
425+
cm_threshold_display = report.metrics.confusion_matrix(threshold=True)
426+
427+
# %%
428+
# Now we can plot the confusion matrix at a specific threshold:
429+
cm_threshold_display.plot(threshold=0.3)
430+
plt.show()
431+
432+
# %%
433+
# The title shows the threshold actually used: the available threshold closest
434+
# to the requested value is selected, so it may differ slightly from the request.
435+
#
436+
# We can also compare multiple thresholds side by side:
437+
cm_threshold_display.plot(threshold=[0.3, 0.5, 0.7])
438+
plt.show()
439+
440+
# %%
441+
# The frame method also supports threshold selection:
442+
cm_threshold_display.frame(threshold=0.7)
443+
444+
# %%
445+
# When no threshold is specified for a threshold-enabled display, the frame
446+
# contains the confusion matrices for all available thresholds:
447+
cm_all_thresholds = cm_threshold_display.frame()
448+
cm_all_thresholds.head(10)
449+
415450
# %%
416451
# .. seealso::
417452
#

skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py

Lines changed: 111 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def plot(
7373
self,
7474
*,
7575
normalize: Literal["true", "pred", "all"] | None = None,
76-
threshold: float | None = None,
76+
threshold: float | list[float] | None = None,
7777
heatmap_kwargs: dict | None = None,
7878
):
7979
"""Plot visualization.
@@ -85,10 +85,11 @@ def plot(
8585
conditions or all the population. If None, the confusion matrix will not be
8686
normalized.
8787
88-
threshold : float or None, default=None
89-
The decision threshold to use for binary classification. If None,
88+
threshold : float, list of float, or None, default=None
89+
The decision threshold(s) to use for binary classification. If None,
9090
uses the default threshold (0.5 if thresholds are available, or the
91-
predicted labels if not).
91+
predicted labels if not). If a list of floats is provided, multiple
92+
confusion matrices will be plotted side by side.
9293
9394
heatmap_kwargs : dict, default=None
9495
Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
@@ -108,7 +109,7 @@ def _plot_matplotlib(
108109
self,
109110
*,
110111
normalize: Literal["true", "pred", "all"] | None = None,
111-
threshold: float | None = None,
112+
threshold: float | list[float] | None = None,
112113
heatmap_kwargs: dict | None = None,
113114
) -> None:
114115
"""Matplotlib implementation of the `plot` method."""
@@ -128,7 +129,7 @@ def _plot_single_estimator(
128129
self,
129130
*,
130131
normalize: Literal["true", "pred", "all"] | None = None,
131-
threshold: float | None = None,
132+
threshold: float | list[float] | None = None,
132133
heatmap_kwargs: dict | None = None,
133134
) -> None:
134135
"""
@@ -141,14 +142,29 @@ def _plot_single_estimator(
141142
conditions or all the population. If None, the confusion matrix will not be
142143
normalized.
143144
144-
threshold : float or None, default=None
145-
The decision threshold to use for binary classification. If None,
145+
threshold : float, list of float, or None, default=None
146+
The decision threshold(s) to use for binary classification. If None,
146147
uses the default threshold (0.5 if thresholds are available, or the
147-
predicted labels if not).
148+
predicted labels if not). If a list of floats is provided, multiple
149+
confusion matrices will be plotted side by side.
148150
149151
heatmap_kwargs : dict, default=None
150152
Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
151153
"""
154+
# Handle multiple thresholds
155+
if isinstance(threshold, (list, tuple)):
156+
if not self.do_threshold:
157+
raise ValueError(
158+
"threshold can only be used with binary classification and "
159+
"when `report.metrics.confusion_matrix(threshold=True)` is used."
160+
)
161+
self._plot_multiple_thresholds(
162+
thresholds=threshold,
163+
normalize=normalize,
164+
heatmap_kwargs=heatmap_kwargs,
165+
)
166+
return
167+
152168
if threshold is not None:
153169
if not self.do_threshold:
154170
raise ValueError(
@@ -189,6 +205,67 @@ def _plot_single_estimator(
189205
self.ax_.set_title(title)
190206
self.figure_.tight_layout()
191207

208+
def _plot_multiple_thresholds(
209+
self,
210+
*,
211+
thresholds: list[float],
212+
normalize: Literal["true", "pred", "all"] | None = None,
213+
heatmap_kwargs: dict | None = None,
214+
) -> None:
215+
"""
216+
Plot multiple confusion matrices for different thresholds.
217+
218+
Parameters
219+
----------
220+
thresholds : list of float
221+
The decision thresholds to use.
222+
223+
normalize : {'true', 'pred', 'all'}, default=None
224+
Normalizes confusion matrix over the true (rows), predicted (columns)
225+
conditions or all the population. If None, the confusion matrix will not be
226+
normalized.
227+
228+
heatmap_kwargs : dict, default=None
229+
Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
230+
"""
231+
n_thresholds = len(thresholds)
232+
figsize = (5 * n_thresholds, 4)
233+
self.figure_, axes = plt.subplots(1, n_thresholds, figsize=figsize)
234+
235+
# Handle single threshold case (axes won't be an array)
236+
if n_thresholds == 1:
237+
axes = [axes]
238+
239+
heatmap_kwargs_validated = _validate_style_kwargs(
240+
{"fmt": ".2f" if normalize else "d", **self._default_heatmap_kwargs},
241+
heatmap_kwargs or {},
242+
)
243+
# Disable colorbar for multi-threshold plots to avoid clutter
244+
heatmap_kwargs_validated["cbar"] = False
245+
246+
normalize_by = "normalized_by_" + normalize if normalize else "count"
247+
248+
for ax, thresh in zip(axes, thresholds, strict=True):
249+
# Find the existing threshold that is closest to the given threshold
250+
closest_threshold = self.thresholds_[
251+
np.argmin(np.abs(self.thresholds_ - thresh))
252+
]
253+
cm = self.confusion_matrix[
254+
self.confusion_matrix["threshold"] == closest_threshold
255+
]
256+
257+
sns.heatmap(
258+
cm.pivot(
259+
index="True label", columns="Predicted label", values=normalize_by
260+
),
261+
ax=ax,
262+
**heatmap_kwargs_validated,
263+
)
264+
ax.set_title(f"threshold: {closest_threshold:.2f}")
265+
266+
self.ax_ = axes[-1] # Set ax_ to the last axes for consistency
267+
self.figure_.tight_layout()
268+
192269
@classmethod
193270
def _compute_data_for_display(
194271
cls,
@@ -258,9 +335,31 @@ def _compute_data_for_display(
258335

259336
confusion_matrix_records = []
260337
for cm, threshold_value in zip(cms, thresholds, strict=True):
261-
cm_true = cm / cm.sum(axis=1, keepdims=True)
262-
cm_pred = cm / cm.sum(axis=0, keepdims=True)
263-
cm_all = cm / cm.sum()
338+
# Compute normalized values with proper handling of zero division
339+
with np.errstate(all="ignore"):
340+
row_sums = cm.sum(axis=1, keepdims=True)
341+
col_sums = cm.sum(axis=0, keepdims=True)
342+
total_sum = cm.sum()
343+
344+
cm_true = np.divide(
345+
cm,
346+
row_sums,
347+
out=np.zeros_like(cm, dtype=float),
348+
where=row_sums != 0,
349+
)
350+
cm_pred = np.divide(
351+
cm,
352+
col_sums,
353+
out=np.zeros_like(cm, dtype=float),
354+
where=col_sums != 0,
355+
)
356+
cm_all = np.divide(
357+
cm,
358+
total_sum,
359+
out=np.zeros_like(cm, dtype=float),
360+
where=total_sum != 0,
361+
)
362+
264363
n_classes = len(display_labels)
265364
true_labels = np.repeat(display_labels, n_classes)
266365
pred_labels = np.tile(display_labels, n_classes)

0 commit comments

Comments
 (0)