
Commit cb01d7b

iter
1 parent ecc25fc commit cb01d7b

File tree

3 files changed: +60 additions, -177 deletions

examples/model_evaluation/plot_estimator_report.py

Lines changed: 2 additions & 7 deletions
@@ -430,14 +430,9 @@ def operational_decision_cost(y_true, y_pred, amount):
 plt.show()
 
 # %%
-# The title shows the threshold value used. By default, the threshold closest to
-# the requested value is selected from the available thresholds.
+# Since there are a finite number of thresholds where the predictions change,
+# we plot the confusion matrix associated with the threshold closest to the one provided.
 #
-# We can also compare multiple thresholds side by side:
-cm_threshold_display.plot(threshold=[0.3, 0.5, 0.7])
-plt.show()
-
-# %%
 # The frame method also supports threshold selection:
 cm_threshold_display.frame(threshold=0.7)
 
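As the updated comment explains, the display snaps a requested threshold to the closest of the finitely many thresholds at which the predictions actually change. A minimal standalone sketch of that selection rule, using a made-up array of available thresholds:

import numpy as np

# Hypothetical thresholds at which the predicted labels change.
thresholds = np.array([0.12, 0.34, 0.56, 0.78])
requested = 0.5

# Same selection rule as in the display: take the closest available threshold.
closest = thresholds[np.argmin(np.abs(thresholds - requested))]
print(closest)  # 0.56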
skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py

Lines changed: 16 additions & 100 deletions
@@ -151,20 +151,6 @@ def _plot_single_estimator(
         heatmap_kwargs : dict, default=None
             Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
         """
-        # Handle multiple thresholds
-        if isinstance(threshold, (list, tuple)):
-            if not self.do_threshold:
-                raise ValueError(
-                    "threshold can only be used with binary classification and "
-                    "when `report.metrics.confusion_matrix(threshold=True)` is used."
-                )
-            self._plot_multiple_thresholds(
-                thresholds=threshold,
-                normalize=normalize,
-                heatmap_kwargs=heatmap_kwargs,
-            )
-            return
-
         if threshold is not None:
             if not self.do_threshold:
                 raise ValueError(
@@ -211,67 +197,6 @@ def _plot_single_estimator(
 
         self.figure_.tight_layout()
 
-    def _plot_multiple_thresholds(
-        self,
-        *,
-        thresholds: list[float],
-        normalize: Literal["true", "pred", "all"] | None = None,
-        heatmap_kwargs: dict | None = None,
-    ) -> None:
-        """
-        Plot multiple confusion matrices for different thresholds.
-
-        Parameters
-        ----------
-        thresholds : list of float
-            The decision thresholds to use.
-
-        normalize : {'true', 'pred', 'all'}, default=None
-            Normalizes confusion matrix over the true (rows), predicted (columns)
-            conditions or all the population. If None, the confusion matrix will not be
-            normalized.
-
-        heatmap_kwargs : dict, default=None
-            Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
-        """
-        n_thresholds = len(thresholds)
-        figsize = (5 * n_thresholds, 4)
-        self.figure_, axes = plt.subplots(1, n_thresholds, figsize=figsize)
-
-        # Handle single threshold case (axes won't be an array)
-        if n_thresholds == 1:
-            axes = [axes]
-
-        heatmap_kwargs_validated = _validate_style_kwargs(
-            {"fmt": ".2f" if normalize else "d", **self._default_heatmap_kwargs},
-            heatmap_kwargs or {},
-        )
-        # Disable colorbar for multi-threshold plots to avoid clutter
-        heatmap_kwargs_validated["cbar"] = False
-
-        normalize_by = "normalized_by_" + normalize if normalize else "count"
-
-        for ax, thresh in zip(axes, thresholds, strict=True):
-            # Find the existing threshold that is closest to the given threshold
-            closest_threshold = self.thresholds_[
-                np.argmin(np.abs(self.thresholds_ - thresh))
-            ]
-            cm = self.confusion_matrix[
-                self.confusion_matrix["threshold"] == closest_threshold
-            ]
-
-            sns.heatmap(
-                cm.pivot(
-                    index="True label", columns="Predicted label", values=normalize_by
-                ),
-                ax=ax,
-                **heatmap_kwargs_validated,
-            )
-            ax.set_title(f"threshold: {closest_threshold:.2f}")
-
-        self.ax_ = axes[-1]  # Set ax_ to the last axes for consistency
-        self.figure_.tight_layout()
-
     @classmethod
     def _compute_data_for_display(
         cls,
@@ -342,30 +267,21 @@ def _compute_data_for_display(
 
         confusion_matrix_records = []
         for cm, threshold_value in zip(cms, thresholds, strict=True):
-            # Compute normalized values with proper handling of zero division
-            with np.errstate(all="ignore"):
-                row_sums = cm.sum(axis=1, keepdims=True)
-                col_sums = cm.sum(axis=0, keepdims=True)
-                total_sum = cm.sum()
-
-                cm_true = np.divide(
-                    cm,
-                    row_sums,
-                    out=np.zeros_like(cm, dtype=float),
-                    where=row_sums != 0,
-                )
-                cm_pred = np.divide(
-                    cm,
-                    col_sums,
-                    out=np.zeros_like(cm, dtype=float),
-                    where=col_sums != 0,
-                )
-                cm_all = np.divide(
-                    cm,
-                    total_sum,
-                    out=np.zeros_like(cm, dtype=float),
-                    where=total_sum != 0,
-                )
+            cm_true = np.divide(
+                cm,
+                cm.sum(axis=1, keepdims=True),
+                where=cm.sum(axis=1, keepdims=True) != 0,
+            )
+            cm_pred = np.divide(
+                cm,
+                cm.sum(axis=0, keepdims=True),
+                where=cm.sum(axis=0, keepdims=True) != 0,
+            )
+            cm_all = np.divide(
+                cm,
+                cm.sum(),
+                where=cm.sum() != 0,
+            )
 
             n_classes = len(display_labels)
             true_labels = np.repeat(display_labels, n_classes)
@@ -438,7 +354,7 @@ def frame(
 
         if threshold is not None and not self.do_threshold:
             raise ValueError(
-                "threshold can only be used with binary classification "
+                "threshold can only be used with binary classification and "
                 "when `report.metrics.confusion_matrix(threshold=True)` is used."
             )
         elif threshold is None and self.do_threshold:

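For reference, here is a self-contained sketch of the `np.divide(..., where=...)` normalization pattern used in `_compute_data_for_display` above, mirroring the checks in the normalization test; the 2x2 counts are made up for illustration:

import numpy as np

# Hypothetical binary confusion matrix counts.
cm = np.array([[40, 10], [5, 45]])
row_sums = cm.sum(axis=1, keepdims=True)
col_sums = cm.sum(axis=0, keepdims=True)

cm_true = np.divide(cm, row_sums, where=row_sums != 0)  # normalize over true labels (rows)
cm_pred = np.divide(cm, col_sums, where=col_sums != 0)  # normalize over predictions (columns)
cm_all = np.divide(cm, cm.sum(), where=cm.sum() != 0)   # normalize over the whole population

assert np.allclose(cm_true.sum(axis=1), 1.0)  # each row sums to 1
assert np.allclose(cm_pred.sum(axis=0), 1.0)  # each column sums to 1
assert np.isclose(cm_all.sum(), 1.0)          # all entries sum to 1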
skore/tests/unit/displays/confusion_matrix/test_estimator.py

Lines changed: 42 additions & 70 deletions
@@ -341,7 +341,8 @@ def test_threshold_display_creation(
 def test_threshold_display_without_threshold(
     pyplot, logistic_binary_classification_with_train_test
 ):
-    """Check that do_threshold is False when threshold=False."""
+    """Check that do_threshold is False when threshold=False and that we raise an error
+    when frame or plot is called with threshold."""
     estimator, X_train, X_test, y_train, y_test = (
         logistic_binary_classification_with_train_test
     )
@@ -357,6 +358,18 @@ def test_threshold_display_without_threshold(
     assert display.do_threshold is False
     assert display.thresholds_ is None
 
+    display = report.metrics.confusion_matrix(threshold=False)
+
+    err_msg = (
+        "threshold can only be used with binary classification and "
+        "when `report.metrics.confusion_matrix\\(threshold=True\\)` is used."
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        display.frame(threshold=0.5)
+
+    with pytest.raises(ValueError, match=err_msg):
+        display.plot(threshold=0.5)
+
 
 def test_plot_with_threshold(pyplot, logistic_binary_classification_with_train_test):
     """Check that we can plot with a specific threshold."""
@@ -375,9 +388,6 @@ def test_plot_with_threshold(pyplot, logistic_binary_classification_with_train_t
     display.plot(threshold=0.3)
     assert "threshold" in display.ax_.get_title().lower()
 
-    display.plot(threshold=0.7)
-    assert "threshold" in display.ax_.get_title().lower()
-
 
 def test_plot_with_default_threshold(
     pyplot, logistic_binary_classification_with_train_test
@@ -394,34 +404,15 @@ def test_plot_with_default_threshold(
         y_test=y_test,
     )
     display = report.metrics.confusion_matrix(threshold=True)
-    display.plot()  # Should use default threshold (0.5)
-
-    # The title should include the threshold
-    assert "threshold" in display.ax_.get_title().lower()
-
-
-def test_threshold_error_without_threshold_support(
-    pyplot, forest_binary_classification_with_train_test
-):
-    """Check that we raise an error when threshold is used without threshold support."""
-    estimator, X_train, X_test, y_train, y_test = (
-        forest_binary_classification_with_train_test
-    )
-    report = EstimatorReport(
-        estimator,
-        X_train=X_train,
-        y_train=y_train,
-        X_test=X_test,
-        y_test=y_test,
-    )
-    display = report.metrics.confusion_matrix(threshold=False)
+    display.plot()
 
-    err_msg = (
-        "threshold can only be used with binary classification and "
-        "when `report.metrics.confusion_matrix\\(threshold=True\\)` is used."
+    closest_threshold = display.thresholds_[
+        np.argmin(np.abs(display.thresholds_ - 0.5))
+    ]
+    assert (
+        display.ax_.get_title()
+        == f"Confusion Matrix (threshold: {closest_threshold:.2f})"
     )
-    with pytest.raises(ValueError, match=err_msg):
-        display.plot(threshold=0.5)
 
 
 def test_frame_with_threshold(logistic_binary_classification_with_train_test):
@@ -463,30 +454,6 @@ def test_frame_all_thresholds(logistic_binary_classification_with_train_test):
     assert len(frame) == len(display.thresholds_)
 
 
-def test_frame_threshold_error_without_threshold_support(
-    forest_binary_classification_with_train_test,
-):
-    """Check that we raise an error when threshold is used without threshold support."""
-    estimator, X_train, X_test, y_train, y_test = (
-        forest_binary_classification_with_train_test
-    )
-    report = EstimatorReport(
-        estimator,
-        X_train=X_train,
-        y_train=y_train,
-        X_test=X_test,
-        y_test=y_test,
-    )
-    display = report.metrics.confusion_matrix(threshold=False)
-
-    err_msg = (
-        "threshold can only be used with binary classification "
-        "when `report.metrics.confusion_matrix\\(threshold=True\\)` is used."
-    )
-    with pytest.raises(ValueError, match=err_msg):
-        display.frame(threshold=0.5)
-
-
 def test_threshold_normalization(
     pyplot, logistic_binary_classification_with_train_test
 ):
@@ -503,26 +470,23 @@ def test_threshold_normalization(
     )
     display = report.metrics.confusion_matrix(threshold=True)
 
-    # Test with normalize="true"
     display.plot(threshold=0.5, normalize="true")
     frame = display.frame(threshold=0.5, normalize="true")
     assert np.allclose(frame.sum(axis=1), np.ones(2))
 
-    # Test with normalize="pred"
     display.plot(threshold=0.5, normalize="pred")
     frame = display.frame(threshold=0.5, normalize="pred")
     assert np.allclose(frame.sum(axis=0), np.ones(2))
 
-    # Test with normalize="all"
     display.plot(threshold=0.5, normalize="all")
     frame = display.frame(threshold=0.5, normalize="all")
     assert np.isclose(frame.sum().sum(), 1.0)
 
 
-def test_plot_with_multiple_thresholds(
+def test_threshold_closest_match(
     pyplot, logistic_binary_classification_with_train_test
 ):
-    """Check that we can plot with multiple thresholds."""
+    """Check that the closest threshold is selected."""
     estimator, X_train, X_test, y_train, y_test = (
         logistic_binary_classification_with_train_test
     )
@@ -535,17 +499,25 @@ def test_plot_with_multiple_thresholds(
     )
     display = report.metrics.confusion_matrix(threshold=True)
 
-    # Plot with multiple thresholds
-    display.plot(threshold=[0.3, 0.5, 0.7])
-
-    # Should have 3 subplots
-    assert len(display.figure_.axes) >= 3
+    # Create a threshold that is not in the list to test the closest match
+    middle_index = len(display.thresholds_) // 2
+    threshold = (
+        display.thresholds_[middle_index] + display.thresholds_[middle_index + 1]
+    ) / 2 - 1e-6
+    closest_threshold = display.thresholds_[middle_index]
+    assert threshold not in display.thresholds_
+    display.plot(threshold=threshold)
+    assert (
+        display.ax_.get_title()
+        == f"Confusion Matrix (threshold: {closest_threshold:.2f})"
+    )
 
 
-def test_threshold_closest_match(
+def test_frame_plot_coincidence_with_threshold(
     pyplot, logistic_binary_classification_with_train_test
 ):
-    """Check that the closest threshold is selected."""
+    """Check that the values in the frame and plot coincide when threshold is
+    provided."""
     estimator, X_train, X_test, y_train, y_test = (
         logistic_binary_classification_with_train_test
     )
@@ -557,7 +529,7 @@ def test_threshold_closest_match(
         y_test=y_test,
     )
     display = report.metrics.confusion_matrix(threshold=True)
-
-    # Even with a threshold not in the list, it should work
-    display.plot(threshold=0.12345)
-    assert display.ax_ is not None
+    frame = display.frame(threshold=0.5)
+    frame_values = frame.values.flatten()
+    display.plot(threshold=0.5)
+    assert np.allclose(frame_values, display.ax_.collections[0].get_array().flatten())
