@@ -73,7 +73,7 @@ def plot(
7373 self ,
7474 * ,
7575 normalize : Literal ["true" , "pred" , "all" ] | None = None ,
76- threshold : float | None = None ,
76+ threshold : float | list [ float ] | None = None ,
7777 heatmap_kwargs : dict | None = None ,
7878 ):
7979 """Plot visualization.
@@ -85,10 +85,11 @@ def plot(
8585 conditions or all the population. If None, the confusion matrix will not be
8686 normalized.
8787
88- threshold : float or None, default=None
89- The decision threshold to use for binary classification. If None,
88+ threshold : float, list of float, or None, default=None
89+ The decision threshold(s) to use for binary classification. If None,
9090 uses the default threshold (0.5 if thresholds are available, or the
91- predicted labels if not).
91+ predicted labels if not). If a list of floats is provided, multiple
92+ confusion matrices will be plotted side by side.
9293
9394 heatmap_kwargs : dict, default=None
9495 Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
@@ -108,7 +109,7 @@ def _plot_matplotlib(
108109 self ,
109110 * ,
110111 normalize : Literal ["true" , "pred" , "all" ] | None = None ,
111- threshold : float | None = None ,
112+ threshold : float | list [ float ] | None = None ,
112113 heatmap_kwargs : dict | None = None ,
113114 ) -> None :
114115 """Matplotlib implementation of the `plot` method."""
@@ -128,7 +129,7 @@ def _plot_single_estimator(
128129 self ,
129130 * ,
130131 normalize : Literal ["true" , "pred" , "all" ] | None = None ,
131- threshold : float | None = None ,
132+ threshold : float | list [ float ] | None = None ,
132133 heatmap_kwargs : dict | None = None ,
133134 ) -> None :
134135 """
@@ -141,14 +142,29 @@ def _plot_single_estimator(
141142 conditions or all the population. If None, the confusion matrix will not be
142143 normalized.
143144
144- threshold : float or None, default=None
145- The decision threshold to use for binary classification. If None,
145+ threshold : float, list of float, or None, default=None
146+ The decision threshold(s) to use for binary classification. If None,
146147 uses the default threshold (0.5 if thresholds are available, or the
147- predicted labels if not).
148+ predicted labels if not). If a list of floats is provided, multiple
149+ confusion matrices will be plotted side by side.
148150
149151 heatmap_kwargs : dict, default=None
150152 Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
151153 """
154+ # Handle multiple thresholds
155+ if isinstance (threshold , (list , tuple )):
156+ if not self .do_threshold :
157+ raise ValueError (
158+ "threshold can only be used with binary classification and "
159+ "when `report.metrics.confusion_matrix(threshold=True)` is used."
160+ )
161+ self ._plot_multiple_thresholds (
162+ thresholds = threshold ,
163+ normalize = normalize ,
164+ heatmap_kwargs = heatmap_kwargs ,
165+ )
166+ return
167+
152168 if threshold is not None :
153169 if not self .do_threshold :
154170 raise ValueError (
@@ -189,6 +205,67 @@ def _plot_single_estimator(
189205 self .ax_ .set_title (title )
190206 self .figure_ .tight_layout ()
191207
208+ def _plot_multiple_thresholds (
209+ self ,
210+ * ,
211+ thresholds : list [float ],
212+ normalize : Literal ["true" , "pred" , "all" ] | None = None ,
213+ heatmap_kwargs : dict | None = None ,
214+ ) -> None :
215+ """
216+ Plot multiple confusion matrices for different thresholds.
217+
218+ Parameters
219+ ----------
220+ thresholds : list of float
221+ The decision thresholds to use.
222+
223+ normalize : {'true', 'pred', 'all'}, default=None
224+ Normalizes confusion matrix over the true (rows), predicted (columns)
225+ conditions or all the population. If None, the confusion matrix will not be
226+ normalized.
227+
228+ heatmap_kwargs : dict, default=None
229+ Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
230+ """
231+ n_thresholds = len (thresholds )
232+ figsize = (5 * n_thresholds , 4 )
233+ self .figure_ , axes = plt .subplots (1 , n_thresholds , figsize = figsize )
234+
235+ # Handle single threshold case (axes won't be an array)
236+ if n_thresholds == 1 :
237+ axes = [axes ]
238+
239+ heatmap_kwargs_validated = _validate_style_kwargs (
240+ {"fmt" : ".2f" if normalize else "d" , ** self ._default_heatmap_kwargs },
241+ heatmap_kwargs or {},
242+ )
243+ # Disable colorbar for multi-threshold plots to avoid clutter
244+ heatmap_kwargs_validated ["cbar" ] = False
245+
246+ normalize_by = "normalized_by_" + normalize if normalize else "count"
247+
248+ for ax , thresh in zip (axes , thresholds , strict = True ):
249+ # Find the existing threshold that is closest to the given threshold
250+ closest_threshold = self .thresholds_ [
251+ np .argmin (np .abs (self .thresholds_ - thresh ))
252+ ]
253+ cm = self .confusion_matrix [
254+ self .confusion_matrix ["threshold" ] == closest_threshold
255+ ]
256+
257+ sns .heatmap (
258+ cm .pivot (
259+ index = "True label" , columns = "Predicted label" , values = normalize_by
260+ ),
261+ ax = ax ,
262+ ** heatmap_kwargs_validated ,
263+ )
264+ ax .set_title (f"threshold: { closest_threshold :.2f} " )
265+
266+ self .ax_ = axes [- 1 ] # Set ax_ to the last axes for consistency
267+ self .figure_ .tight_layout ()
268+
192269 @classmethod
193270 def _compute_data_for_display (
194271 cls ,
@@ -258,9 +335,31 @@ def _compute_data_for_display(
258335
259336 confusion_matrix_records = []
260337 for cm , threshold_value in zip (cms , thresholds , strict = True ):
261- cm_true = cm / cm .sum (axis = 1 , keepdims = True )
262- cm_pred = cm / cm .sum (axis = 0 , keepdims = True )
263- cm_all = cm / cm .sum ()
338+ # Compute normalized values with proper handling of zero division
339+ with np .errstate (all = "ignore" ):
340+ row_sums = cm .sum (axis = 1 , keepdims = True )
341+ col_sums = cm .sum (axis = 0 , keepdims = True )
342+ total_sum = cm .sum ()
343+
344+ cm_true = np .divide (
345+ cm ,
346+ row_sums ,
347+ out = np .zeros_like (cm , dtype = float ),
348+ where = row_sums != 0 ,
349+ )
350+ cm_pred = np .divide (
351+ cm ,
352+ col_sums ,
353+ out = np .zeros_like (cm , dtype = float ),
354+ where = col_sums != 0 ,
355+ )
356+ cm_all = np .divide (
357+ cm ,
358+ total_sum ,
359+ out = np .zeros_like (cm , dtype = float ),
360+ where = total_sum != 0 ,
361+ )
362+
264363 n_classes = len (display_labels )
265364 true_labels = np .repeat (display_labels , n_classes )
266365 pred_labels = np .tile (display_labels , n_classes )
0 commit comments