diff --git a/examples/model_evaluation/plot_estimator_report.py b/examples/model_evaluation/plot_estimator_report.py
index 63c022ee86..c5405cff0f 100644
--- a/examples/model_evaluation/plot_estimator_report.py
+++ b/examples/model_evaluation/plot_estimator_report.py
@@ -397,14 +397,12 @@ def operational_decision_cost(y_true, y_pred, amount):
 # %%
 # We can normalize the confusion matrix to get percentages instead of raw counts.
 # Here we normalize by true labels (rows):
-cm_display = report.metrics.confusion_matrix()
 cm_display.plot(normalize="true")
 plt.show()
 
 # %%
 # More plotting options are available via ``heatmap_kwargs``, which are passed to
 # seaborn's heatmap. For example, we can customize the colormap and number format:
-cm_display = report.metrics.confusion_matrix()
 cm_display.plot(heatmap_kwargs={"cmap": "Greens", "fmt": ".2e"})
 plt.show()
 
@@ -414,6 +412,36 @@ def operational_decision_cost(y_true, y_pred, amount):
 cm_frame = cm_display.frame()
 cm_frame
 
+# %%
+# Decision threshold support
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# For binary classification, the confusion matrix can be computed at different
+# decision thresholds. This is useful for understanding how the model's predictions
+# change as the decision threshold varies.
+
+# %%
+# First, we create a display with threshold support enabled:
+cm_threshold_display = report.metrics.confusion_matrix(threshold=True)
+
+# %%
+# Now we can plot the confusion matrix at a specific threshold:
+cm_threshold_display.plot(threshold_value=0.3)
+plt.show()
+
+# %%
+# Since there is a finite number of thresholds at which the predictions change,
+# we plot the confusion matrix associated with the threshold closest to the one provided.
+#
+# The ``frame`` method also supports threshold selection:
+cm_threshold_display.frame(threshold_value=0.7)
+
+# %%
+# When `threshold_value` is set to `"all"`, we get the confusion matrices for all
+# available thresholds:
+cm_all_thresholds = cm_threshold_display.frame(threshold_value="all")
+cm_all_thresholds.head(5)
+
 # %%
 # .. seealso::
 #
diff --git a/skore/src/skore/_externals/_sklearn_compat.py b/skore/src/skore/_externals/_sklearn_compat.py
index 0b981b4749..05db00ec7d 100644
--- a/skore/src/skore/_externals/_sklearn_compat.py
+++ b/skore/src/skore/_externals/_sklearn_compat.py
@@ -869,3 +869,20 @@ def parametrize_with_checks(
         check_X_y,  # noqa: F401
         validate_data,  # noqa: F401
     )
+
+
+########################################################################################
+# Upgrading for scikit-learn 1.8
+########################################################################################
+
+if sklearn_version < parse_version("1.8"):
+
+    def confusion_matrix_at_thresholds(y_true, y_score, pos_label):
+        from sklearn.metrics._ranking import _binary_clf_curve
+
+        fps, tps, thresholds = _binary_clf_curve(y_true, y_score, pos_label=pos_label)
+        fns = tps[-1] - tps
+        tns = fps[-1] - fps
+        return tns, fps, fns, tps, thresholds
+else:
+    from sklearn.metrics._ranking import confusion_matrix_at_thresholds  # noqa: F401
diff --git a/skore/src/skore/_sklearn/_estimator/metrics_accessor.py b/skore/src/skore/_sklearn/_estimator/metrics_accessor.py
index 68bbdaf9ea..5b97650d3e 100644
--- a/skore/src/skore/_sklearn/_estimator/metrics_accessor.py
+++ b/skore/src/skore/_sklearn/_estimator/metrics_accessor.py
@@ -2059,6 +2059,7 @@ def confusion_matrix(
         data_source: DataSource = "test",
         X: ArrayLike | None = None,
         y: ArrayLike | None = None,
+        pos_label: PositiveLabel | None = _DEFAULT,
     ) -> ConfusionMatrixDisplay:
         """Plot the confusion matrix.
 
@@ -2082,6 +2083,11 @@ def confusion_matrix(
             New target on which to compute the metric. By default, we use the target
             provided when creating the report.
 
+        pos_label : int, float, bool, str or None, default=_DEFAULT
+            The label to consider as the positive class when displaying the matrix. Use
+            this parameter to override the positive class. By default, the positive
+            class is set to the one provided when creating the report.
+
         Returns
         -------
         display : :class:`~skore._sklearn._plot.ConfusionMatrixDisplay`
@@ -2097,16 +2103,35 @@ def confusion_matrix(
         >>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True)
         >>> classifier = LogisticRegression(max_iter=10_000)
         >>> report = EstimatorReport(classifier, **split_data)
-        >>> report.metrics.confusion_matrix()
+        >>> display = report.metrics.confusion_matrix()
+        >>> display.plot()
+
+        With a specific threshold for binary classification:
+
+        >>> display = report.metrics.confusion_matrix()
+        >>> display.plot(threshold_value=0.7)
         """
-        display_kwargs = {"display_labels": self._parent.estimator_.classes_.tolist()}
+        if pos_label is _DEFAULT:
+            pos_label = self._parent.pos_label
+
+        display_kwargs = {
+            "display_labels": self._parent.estimator_.classes_.tolist(),
+            "pos_label": pos_label,
+        }
+
+        response_method: str | list[str] | tuple[str, ...]
+        if self._parent._ml_task == "binary-classification":
+            response_method = ("predict_proba", "decision_function")
+        else:
+            response_method = "predict"
+
         display = cast(
             ConfusionMatrixDisplay,
             self._get_display(
                 X=X,
                 y=y,
                 data_source=data_source,
-                response_method="predict",
+                response_method=response_method,
                 display_class=ConfusionMatrixDisplay,
                 display_kwargs=display_kwargs,
             ),
diff --git a/skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py b/skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py
index b4192e32ab..d8aa65e5c6 100644
--- a/skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py
+++ b/skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py
@@ -5,11 +5,13 @@
 import numpy as np
 import pandas as pd
 import seaborn as sns
+from numpy.typing import NDArray
 from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix
 
+from skore._externals._sklearn_compat import confusion_matrix_at_thresholds
 from skore._sklearn._plot.base import DisplayMixin
 from skore._sklearn._plot.utils import _validate_style_kwargs
-from skore._sklearn.types import ReportType, YPlotData
+from skore._sklearn.types import MLTask, PositiveLabel, ReportType, YPlotData
 
 
 class ConfusionMatrixDisplay(DisplayMixin):
@@ -20,7 +22,8 @@ class ConfusionMatrixDisplay(DisplayMixin):
     confusion_matrix : pd.DataFrame
         Confusion matrix data in long format with columns: "True label",
         "Predicted label", "count", "normalized_by_true", "normalized_by_pred",
-        and "normalized_by_all". Each row represents one cell of the confusion matrix.
+        "normalized_by_all" and "threshold". Each row represents one cell of one
+        confusion matrix.
 
     display_labels : list of str
         Display labels for plot axes.
@@ -29,8 +32,19 @@ class ConfusionMatrixDisplay(DisplayMixin):
         "cross-validation", "estimator"}
         The type of report.
 
+    ml_task : {"binary-classification", "multiclass-classification"}
+        The machine learning task.
+
+    pos_label : int, float, bool, str or None
+        The class considered as the positive class when displaying the confusion
+        matrix.
+
     Attributes
     ----------
+    thresholds_ : array-like of shape (n_thresholds,)
+        Thresholds of the decision function. Each threshold is associated with a
+        confusion matrix. Only available for binary classification.
+
     figure_ : matplotlib Figure
         Figure containing the confusion matrix.
 
@@ -44,10 +58,16 @@ def __init__(
         confusion_matrix: pd.DataFrame,
         display_labels: list[str],
         report_type: ReportType,
+        ml_task: MLTask,
+        thresholds: NDArray,
+        pos_label: PositiveLabel,
     ):
         self.confusion_matrix = confusion_matrix
         self.display_labels = display_labels
         self.report_type = report_type
+        self.thresholds_ = thresholds
+        self.ml_task = ml_task
+        self.pos_label = pos_label
 
     _default_heatmap_kwargs: dict = {
         "cmap": "Blues",
@@ -60,9 +80,15 @@ def plot(
         self,
         *,
         normalize: Literal["true", "pred", "all"] | None = None,
+        threshold_value: float | None = None,
         heatmap_kwargs: dict | None = None,
     ):
-        """Plot visualization.
+        """Plot the confusion matrix.
+
+        In binary classification, the confusion matrix can be displayed at various
+        decision thresholds. This is useful for understanding how the model's
+        predictions change as the decision threshold varies. If no threshold is
+        provided, the confusion matrix is displayed at the default threshold (0.5).
 
         Parameters
         ----------
@@ -71,6 +97,11 @@ def plot(
             conditions or all the population. If None, the confusion matrix will not be
             normalized.
 
+        threshold_value : float or None, default=None
+            The decision threshold to use when applicable.
+            If None and thresholds are available, plots the confusion matrix at the
+            default threshold (0.5).
+
         heatmap_kwargs : dict, default=None
             Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
 
@@ -81,6 +112,7 @@
         """
         return self._plot(
             normalize=normalize,
+            threshold_value=threshold_value,
             heatmap_kwargs=heatmap_kwargs,
         )
 
@@ -88,12 +120,14 @@ def _plot_matplotlib(
         self,
         *,
         normalize: Literal["true", "pred", "all"] | None = None,
+        threshold_value: float | None = None,
         heatmap_kwargs: dict | None = None,
     ) -> None:
         """Matplotlib implementation of the `plot` method."""
         if self.report_type == "estimator":
             self._plot_single_estimator(
                 normalize=normalize,
+                threshold_value=threshold_value,
                 heatmap_kwargs=heatmap_kwargs,
             )
         else:
@@ -106,6 +140,7 @@ def _plot_single_estimator(
         self,
         *,
         normalize: Literal["true", "pred", "all"] | None = None,
+        threshold_value: float | None = None,
         heatmap_kwargs: dict | None = None,
     ) -> None:
         """
@@ -118,9 +153,13 @@ def _plot_single_estimator(
             conditions or all the population. If None, the confusion matrix will not be
             normalized.
 
+        threshold_value : float or None, default=None
+            The decision threshold to use when applicable.
+            If None and thresholds are available, plots the confusion matrix at the
+            default threshold (0.5).
+
         heatmap_kwargs : dict, default=None
             Additional keyword arguments to be passed to seaborn's `sns.heatmap`.
-
         """
         self.figure_, self.ax_ = plt.subplots()
 
@@ -128,20 +167,58 @@ def _plot_single_estimator(
             {"fmt": ".2f" if normalize else "d", **self._default_heatmap_kwargs},
             heatmap_kwargs or {},
         )
-
         normalize_by = "normalized_by_" + normalize if normalize else "count"
+
         sns.heatmap(
-            self.confusion_matrix.pivot(
-                index="true_label", columns="predicted_label", values=normalize_by
-            ),
+            self.frame(threshold_value=threshold_value)
+            .pivot(index="true_label", columns="predicted_label", values=normalize_by)
+            .reindex(index=self.display_labels, columns=self.display_labels),
             ax=self.ax_,
             **heatmap_kwargs_validated,
         )
-        self.ax_.set(
-            xlabel="Predicted label",
-            ylabel="True label",
-            title="Confusion Matrix",
-        )
+
+        title = "Confusion Matrix"
+        if self.ml_task == "binary-classification":
+            if threshold_value is None:
+                threshold_value = 0.5
+            title = title + f"\nDecision threshold: {threshold_value:.2f}"
+
+        if self.ml_task == "binary-classification" and self.pos_label is not None:
+            ticklabels = [
+                f"{label}*" if label == str(self.pos_label) else label
+                for label in self.display_labels
+            ]
+
+            self.ax_.set(
+                xlabel="Predicted label",
+                ylabel="True label",
+                title=title,
+                xticklabels=ticklabels,
+                yticklabels=ticklabels,
+            )
+
+            self.ax_.text(
+                -0.15,
+                -0.15,
+                "*: the positive class",
+                fontsize=9,
+                style="italic",
+                verticalalignment="bottom",
+                horizontalalignment="left",
+                transform=self.ax_.transAxes,
+                bbox={
+                    "boxstyle": "round",
+                    "facecolor": "white",
+                    "alpha": 0.8,
+                    "edgecolor": "gray",
+                },
+            )
+        else:
+            self.ax_.set(
+                xlabel="Predicted label",
+                ylabel="True label",
+                title=title,
+            )
 
         self.figure_.tight_layout()
 
@@ -152,7 +229,9 @@ def _compute_data_for_display(
         y_pred: Sequence[YPlotData],
         *,
         report_type: ReportType,
+        ml_task: MLTask,
         display_labels: list[str],
+        pos_label: PositiveLabel,
         **kwargs,
     ) -> "ConfusionMatrixDisplay":
         """Compute the confusion matrix for display.
@@ -163,19 +242,26 @@
         y_true : list of array-like of shape (n_samples,)
             True labels.
 
         y_pred : list of array-like of shape (n_samples,)
-            Predicted labels, as returned by a classifier.
+            Decision scores for binary classification when threshold support is
+            enabled; otherwise, predicted labels.
 
         report_type : {"comparison-cross-validation", "comparison-estimator", \
                 "cross-validation", "estimator"}
             The type of report.
 
+        ml_task : {"binary-classification", "multiclass-classification"}
+            The machine learning task.
+
         display_labels : list of str
             Display labels for plot.
 
+        pos_label : int, float, bool, str or None
+            The class considered as the positive class when displaying the confusion
+            matrix.
+
         **kwargs : dict
             Additional keyword arguments that are ignored for compatibility with
-            other metrics displays. Accepts but ignores `estimators`, `ml_task`,
-            and `data_source`.
+            other metrics displays. Accepts but ignores `estimators` and `data_source`.
 
         Returns
         -------
@@ -185,25 +271,58 @@
         y_true_values = y_true[0].y
         y_pred_values = y_pred[0].y
 
-        cm = sklearn_confusion_matrix(
-            y_true=y_true_values,
-            y_pred=y_pred_values,
-            normalize=None,  # we will normalize later
-        )
+        if ml_task == "binary-classification":
+            if pos_label is not None:
+                neg_label = next(
+                    label for label in display_labels if label != pos_label
+                )
+                display_labels = [str(neg_label), str(pos_label)]
+            tns, fps, fns, tps, thresholds = confusion_matrix_at_thresholds(
+                y_true=y_true_values,
+                y_score=y_pred_values,
+                pos_label=pos_label,
+            )
+            cms = np.column_stack([tns, fps, fns, tps]).reshape(-1, 2, 2).astype(int)
+        else:
+            cms = sklearn_confusion_matrix(
+                y_true=y_true_values,
+                y_pred=y_pred_values,
+                normalize=None,  # we will normalize later
+                labels=display_labels,
+            )[np.newaxis, ...]
+            thresholds = np.array([np.nan])
+
+        row_sums = cms.sum(axis=2, keepdims=True)
+        cm_true = np.divide(cms, row_sums, where=row_sums != 0)
+
+        col_sums = cms.sum(axis=1, keepdims=True)
+        cm_pred = np.divide(cms, col_sums, where=col_sums != 0)
+
+        total_sums = cms.sum(axis=(1, 2), keepdims=True)
+        cm_all = np.divide(cms, total_sums, where=total_sums != 0)
 
-        cm_true = cm / cm.sum(axis=1, keepdims=True)
-        cm_pred = cm / cm.sum(axis=0, keepdims=True)
-        cm_all = cm / cm.sum()
+        n_thresholds = len(thresholds)
         n_classes = len(display_labels)
+        n_cells = n_classes * n_classes
+
+        true_labels = np.tile(np.repeat(display_labels, n_classes), n_thresholds)
+        pred_labels = np.tile(np.tile(display_labels, n_classes), n_thresholds)
+        threshold_values = np.repeat(thresholds, n_cells)
+
+        counts = cms.reshape(-1)
+        normalized_true_values = cm_true.reshape(-1)
+        normalized_pred_values = cm_pred.reshape(-1)
+        normalized_all_values = cm_all.reshape(-1)
 
         confusion_matrix = pd.DataFrame(
             {
-                "true_label": np.repeat(display_labels, n_classes),
-                "predicted_label": np.tile(display_labels, n_classes),
-                "count": cm.flatten(),
-                "normalized_by_true": cm_true.flatten(),
-                "normalized_by_pred": cm_pred.flatten(),
-                "normalized_by_all": cm_all.flatten(),
+                "true_label": true_labels,
+                "predicted_label": pred_labels,
+                "count": counts,
+                "normalized_by_true": normalized_true_values,
+                "normalized_by_pred": normalized_pred_values,
+                "normalized_by_all": normalized_all_values,
+                "threshold": threshold_values,
             }
         )
 
@@ -211,28 +330,56 @@
             confusion_matrix=confusion_matrix,
             display_labels=display_labels,
             report_type=report_type,
+            ml_task=ml_task,
+            pos_label=pos_label,
+            thresholds=np.unique(thresholds),
         )
         return disp
 
-    def frame(self, normalize: Literal["true", "pred", "all"] | None = None):
-        """Return the confusion matrix as a dataframe.
+    def frame(
+        self,
+        threshold_value: float | None = None,
+    ):
+        """Return the confusion matrix as a long format dataframe.
+
+        In binary classification, the confusion matrix can be returned at various
+        decision thresholds. This is useful for understanding how the model's
+        predictions change as the decision threshold varies. If no threshold is
+        provided, the default threshold (0.5) is used.
+
+        The matrix is returned as a long format dataframe where each row represents one
+        cell of the matrix. The columns are "true_label", "predicted_label", "count",
+        "normalized_by_true", "normalized_by_pred", "normalized_by_all" and "threshold".
 
         Parameters
         ----------
-        normalize : {'true', 'pred', 'all'}, default=None
-            Normalizes confusion matrix over the true (rows), predicted (columns)
-            conditions or all the population. If None, the confusion matrix will not be
-            normalized.
+        threshold_value : float or None, default=None
+            The decision threshold to use when applicable.
+            If None and thresholds are available, returns the confusion matrix at the
+            default threshold (0.5).
 
         Returns
         -------
         frame : pandas.DataFrame
-            The confusion matrix as a dataframe in pivot format with true labels as
-            rows and predicted labels as columns. Values are counts or normalized
-            values depending on the `normalize` parameter.
+            The confusion matrix as a dataframe.
         """
-        normalize_by = "normalized_by_" + normalize if normalize else "count"
-        return self.confusion_matrix.pivot(
-            index="true_label", columns="predicted_label", values=normalize_by
-        )
+        if threshold_value is not None and self.ml_task != "binary-classification":
+            raise ValueError(
+                "Threshold support is only available for binary classification."
+            )
+        if threshold_value is None:
+            if self.ml_task == "binary-classification":
+                threshold_value = 0.5
+            else:
+                return self.confusion_matrix
+
+        index_right = np.searchsorted(self.thresholds_, threshold_value)
+        index_left = index_right - 1
+        diff_right = abs(self.thresholds_[index_right] - threshold_value)
+        diff_left = abs(self.thresholds_[index_left] - threshold_value)
+
+        threshold_value = self.thresholds_[
+            index_right if diff_right < diff_left else index_left
+        ]
+        return self.confusion_matrix.query("threshold == @threshold_value")
diff --git a/skore/tests/unit/displays/confusion_matrix/test_estimator.py b/skore/tests/unit/displays/confusion_matrix/test_estimator.py
index 05d2b2f573..b8233948e2 100644
--- a/skore/tests/unit/displays/confusion_matrix/test_estimator.py
+++ b/skore/tests/unit/displays/confusion_matrix/test_estimator.py
@@ -85,6 +85,7 @@ def test_confusion_matrix(pyplot, forest_binary_classification_with_train_test):
         "normalized_by_true",
         "normalized_by_pred",
         "normalized_by_all",
+        "threshold",
     ]
     n_classes = len(display.display_labels)
     assert display.confusion_matrix.shape[0] == (n_classes * n_classes)
@@ -312,3 +313,245 @@ def test_not_implemented_error_for_non_estimator_report(
     )
     with pytest.raises(NotImplementedError, match=err_msg):
         display.plot()
+
+
+def test_threshold_display_creation(
+    pyplot, forest_binary_classification_with_train_test
+):
+    """Check that we can create a confusion matrix display with threshold support."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+
+    assert isinstance(display, ConfusionMatrixDisplay)
+    assert display.threshold is True
+    assert display.thresholds_ is not None
+    assert len(display.thresholds_) > 0
+    assert "threshold" in display.confusion_matrix.columns
+
+
+def test_threshold_display_without_threshold(
+    pyplot, forest_binary_classification_with_train_test
+):
+    """Check that threshold is False when threshold=False and that we raise an error
+    when frame or plot is called with threshold_value."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=False)
+
+    assert display.threshold is False
+    assert display.thresholds_ is None
+
+    display = report.metrics.confusion_matrix(threshold=False)
+
+    for func in ["frame", "plot"]:
+        err_msg = (
+            "Enable threshold support by passing `threshold=True` to "
+            f"`report.metrics.confusion_matrix\\(\\)` before calling `{func}\\(\\)` "
+            "with `threshold_value`. This is only applicable for binary "
+            "classification."
+        )
+        with pytest.raises(ValueError, match=err_msg):
+            getattr(display, func)(threshold_value=0.5)
+
+
+def test_plot_with_threshold(pyplot, forest_binary_classification_with_train_test):
+    """Check that we can plot with a specific threshold."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+
+    display.plot(threshold_value=0.3)
+    assert "threshold" in display.ax_.get_title().lower()
+
+
+def test_plot_with_default_threshold(
+    pyplot, forest_binary_classification_with_train_test
+):
+    """Check that the default threshold (0.5) is used when not specified."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+    display.plot()
+
+    closest_threshold = display.thresholds_[
+        np.argmin(np.abs(display.thresholds_ - 0.5))
+    ]
+    assert (
+        display.ax_.get_title()
+        == f"Confusion Matrix (threshold: {closest_threshold:.2f})"
+    )
+    default_vals = display.ax_.collections[0].get_array().flatten()
+
+    display.plot(threshold_value=0.5)
+    np.testing.assert_allclose(
+        default_vals, display.ax_.collections[0].get_array().flatten()
+    )
+
+
+def test_frame_with_threshold(forest_binary_classification_with_train_test):
+    """Check that we can get a frame at a specific threshold."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+    frame = display.frame(threshold_value=0.5)
+
+    assert isinstance(frame, pd.DataFrame)
+    assert frame.shape == (2, 2)
+
+
+def test_frame_all_thresholds(forest_binary_classification_with_train_test):
+    """Check that we get all thresholds when threshold_value='all'."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator,
+        X_train=X_train,
+        y_train=y_train,
+        X_test=X_test,
+        y_test=y_test,
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+    frame = display.frame(threshold_value="all")
+
+    assert isinstance(frame, pd.DataFrame)
+    assert "threshold" in frame.columns
+    assert len(frame) == len(display.thresholds_)
+
+
+def test_frame_without_threshold(forest_binary_classification_with_train_test):
+    """Check that we get the default confusion matrix when threshold_value is None."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+    frame = display.frame(threshold_value=None)
+
+    assert isinstance(frame, pd.DataFrame)
+    assert "threshold" not in frame.columns
+    assert frame.shape == (estimator.classes_.size, estimator.classes_.size)
+
+
+def test_threshold_normalization(pyplot, forest_binary_classification_with_train_test):
+    """Check that normalization works with threshold support."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+
+    display.plot(threshold_value=0.5, normalize="true")
+    display_vals = display.ax_.collections[0].get_array()
+    frame = display.frame(threshold_value=0.5, normalize="true")
+    np.testing.assert_allclose(display_vals.sum(axis=1), 1.0)
+    np.testing.assert_allclose(frame.sum(axis=1), 1.0)
+
+    display.plot(threshold_value=0.5, normalize="pred")
+    display_vals = display.ax_.collections[0].get_array()
+    frame = display.frame(threshold_value=0.5, normalize="pred")
+    np.testing.assert_allclose(display_vals.sum(axis=0), 1.0)
+    np.testing.assert_allclose(frame.sum(axis=0), 1.0)
+
+    display.plot(threshold_value=0.5, normalize="all")
+    display_vals = display.ax_.collections[0].get_array()
+    frame = display.frame(threshold_value=0.5, normalize="all")
+    assert display_vals.sum() == pytest.approx(1.0)
+    assert frame.sum().sum() == pytest.approx(1.0)
+
+
+def test_threshold_closest_match(pyplot, forest_binary_classification_with_train_test):
+    """Check that the closest threshold is selected."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+
+    # Create a threshold that is not in the list to test the closest match
+    middle_index = len(display.thresholds_) // 2
+    threshold = (
+        display.thresholds_[middle_index] + display.thresholds_[middle_index + 1]
+    ) / 2 - 1e-6
+    closest_threshold = display.thresholds_[middle_index]
+    assert threshold not in display.thresholds_
+    display.plot(threshold_value=threshold)
+    assert (
+        display.ax_.get_title()
+        == f"Confusion Matrix (threshold: {closest_threshold:.2f})"
+    )
+    np.testing.assert_allclose(
+        display.ax_.collections[0].get_array(),
+        display.frame(threshold_value=closest_threshold).values,
+    )
+
+
+def test_frame_plot_coincidence_with_threshold(
+    pyplot, forest_binary_classification_with_train_test
+):
+    """Check that the values in the frame and plot coincide when threshold is
+    provided."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(threshold=True)
+    frame = display.frame(threshold_value=0.5)
+    frame_values = frame.values.flatten()
+    display.plot(threshold_value=0.5)
+    np.testing.assert_allclose(
+        frame_values, display.ax_.collections[0].get_array().flatten()
+    )
+
+
+def test_pos_label(pyplot, forest_binary_classification_with_train_test):
+    """Check that the pos_label parameter works correctly."""
+    estimator, X_train, X_test, y_train, y_test = (
+        forest_binary_classification_with_train_test
+    )
+    labels = np.array(["A", "B"], dtype=object)
+    y_train = labels[y_train]
+    y_test = labels[y_test]
+    estimator.fit(X_train, y_train)
+    report = EstimatorReport(
+        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
+    )
+    display = report.metrics.confusion_matrix(pos_label="A")
+    display.plot()
+    # check that the pos label is in second position on both axes
+    assert display.ax_.get_xticklabels()[1].get_text() == "A"
+    assert display.ax_.get_yticklabels()[1].get_text() == "A"
+
+    display = report.metrics.confusion_matrix(pos_label="B")
+    display.plot()
+    assert display.ax_.get_xticklabels()[1].get_text() == "B"
+    assert display.ax_.get_yticklabels()[1].get_text() == "B"
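
For reference, a minimal end-to-end sketch of the workflow this diff enables, mirroring the added example and tests. It assumes a skore installation with this change applied; the `threshold=True` flag, `threshold_value` argument, and `pos_label` override come from the code above, while the breast-cancer dataset and logistic-regression estimator are illustrative choices only.

# Hypothetical usage sketch of the threshold-aware confusion matrix display.
# The API calls mirror the diff above; dataset and estimator are illustrative.
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

from skore import EstimatorReport, train_test_split

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True)
report = EstimatorReport(LogisticRegression(max_iter=10_000), **split_data)

# Default behaviour: a single confusion matrix from hard predictions.
display = report.metrics.confusion_matrix()
display.plot(normalize="true")

# Threshold support (binary classification only): the display stores one
# confusion matrix per decision threshold of the classifier's scores.
display = report.metrics.confusion_matrix(threshold=True)
display.plot(threshold_value=0.3)  # the closest available threshold is used
frame_07 = display.frame(threshold_value=0.7)  # long-format rows near 0.7
all_frames = display.frame(threshold_value="all")  # one row per cell per threshold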