diff --git a/skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py b/skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py index 0d7ccdce62..dbddf848a1 100644 --- a/skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py +++ b/skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py @@ -137,6 +137,11 @@ def _plot_single_estimator( *, estimator_name: str, pr_curve_kwargs: list[dict[str, Any]], + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, ) -> tuple[Axes, list[Line2D], Union[str, None]]: """Plot precision-recall curve for a single estimator. @@ -150,6 +155,24 @@ def _plot_single_estimator( binary case, we should have a single dict. In multiclass case, we should have a list of dicts, one per class. + ax : matplotlib.axes.Axes, default=None + The axes to plot on. If None, self.ax_ is used. + + subplots : bool, default=False + If True, plot each class on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + Returns ------- ax : matplotlib.axes.Axes @@ -165,6 +188,116 @@ def _plot_single_estimator( lines: list[Line2D] = [] line_kwargs: dict[str, Any] = {"drawstyle": "steps-post"} + if subplots: + if self.ml_task == "multiclass-classification": + num_plots = len(self.precision) + else: + num_plots = 1 + + # Calculate grid dimensions + if nrows is None and ncols is None: + if num_plots == 1: + ncols = 1 + nrows = 1 + else: + ncols = min(2, num_plots) + nrows = (num_plots + ncols - 1) // ncols + elif nrows is None: + nrows = (num_plots + (ncols or 1) - 1) // (ncols or 1) + elif ncols is None: + ncols = (num_plots + (nrows or 1) - 1) // (nrows or 1) + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + axes: list[Axes] = [] + for i in range(num_plots): + if i == 0: + ax = self.figure_.add_subplot(nrows, ncols, i + 1) + else: + ax = self.figure_.add_subplot( + nrows, ncols, i + 1, sharex=axes[0], sharey=axes[0] + ) + axes.append(ax) + + # Set self.ax_ to the first axis for backwards compatibility + self.ax_ = axes[0] if axes else None + + # Plot in each subplot + for idx, axi in enumerate(axes): + if self.ml_task == "multiclass-classification": + class_label = list(self.precision.keys())[idx] + class_precision = {class_label: [self.precision[class_label][0]]} + class_recall = {class_label: [self.recall[class_label][0]]} + class_avg_precision = { + class_label: [self.average_precision[class_label][0]] + } + class_display = PrecisionRecallCurveDisplay( + precision=class_precision, + recall=class_recall, + average_precision=class_avg_precision, + estimator_names=[estimator_name], + pos_label=class_label, + data_source=self.data_source, + ml_task="binary-classification", + report_type="estimator", + ) + class_display.ax_ = axi + class_display.figure_ = self.figure_ + _, class_lines, _ = class_display._plot_single_estimator( + estimator_name=estimator_name, + pr_curve_kwargs=[pr_curve_kwargs[idx]], + ax=axi, + ) + lines.extend(class_lines) + axi.set_title(f"Class: 
{class_label}") + else: + # Binary classification case + pos_label = cast(PositiveLabel, self.pos_label) + line_kwargs_validated = _validate_style_kwargs( + line_kwargs, pr_curve_kwargs[0] + ) + if self.data_source in ("train", "test"): + line_kwargs_validated["label"] = ( + f"{self.data_source.title()} set " + f"(AP = {self.average_precision[pos_label][0]:0.2f})" + ) + else: + line_kwargs_validated["label"] = ( + f"AP = {self.average_precision[pos_label][0]:0.2f}" + ) + + (line,) = axi.plot( + self.recall[pos_label][0], + self.precision[pos_label][0], + **line_kwargs_validated, + ) + lines.append(line) + axi.set_title(f"Model: {estimator_name}") + + # Set axis labels and limits + xlabel = "Recall" + ylabel = "Precision" + if self.pos_label is not None: + xlabel += f"\n(Positive label: {self.pos_label})" + ylabel += f"\n(Positive label: {self.pos_label})" + + axi.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + axi.legend(bbox_to_anchor=(1.02, 1), title=estimator_name) + + return self.ax_, lines, None + + # Single plot case + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() + if self.ml_task == "binary-classification": pos_label = cast(PositiveLabel, self.pos_label) @@ -181,7 +314,7 @@ def _plot_single_estimator( f"AP = {self.average_precision[pos_label][0]:0.2f}" ) - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( self.recall[pos_label][0], self.precision[pos_label][0], **line_kwargs_validated, @@ -219,22 +352,27 @@ def _plot_single_estimator( f"AP = {average_precision_class:0.2f}" ) - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( recall_class, precision_class, **line_kwargs_validated ) lines.append(line) - info_pos_label = None # irrelevant for multiclass + info_pos_label = None # not relevant for multiclass - self.ax_.legend(bbox_to_anchor=(1.02, 1), title=estimator_name) + plot_ax.legend(bbox_to_anchor=(1.02, 1), title=estimator_name) - return self.ax_, lines, info_pos_label + return plot_ax, lines, info_pos_label def _plot_cross_validated_estimator( self, *, estimator_name: str, pr_curve_kwargs: list[dict[str, Any]], + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, ) -> tuple[Axes, list[Line2D], Union[str, None]]: """Plot precision-recall curve for a cross-validated estimator. @@ -248,6 +386,24 @@ def _plot_cross_validated_estimator( precision-recall curves. The length of the list should match the number of curves to plot. + ax : matplotlib.axes.Axes, default=None + The axes to plot on. If None, self.ax_ is used. + + subplots : bool, default=False + If True, plot each fold or class on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. 
+ Returns ------- ax : matplotlib.axes.Axes @@ -263,6 +419,126 @@ def _plot_cross_validated_estimator( lines: list[Line2D] = [] line_kwargs: dict[str, Any] = {"drawstyle": "steps-post"} + if subplots: + if self.ml_task == "binary-classification": + pos_label = cast(PositiveLabel, self.pos_label) + num_plots = len(self.precision[pos_label]) + else: # multiclass + num_plots = len(self.precision) + + # Calculate grid dimensions + if nrows is None and ncols is None: + if num_plots == 1: + ncols = 1 + nrows = 1 + else: + ncols = min(2, num_plots) + nrows = (num_plots + ncols - 1) // ncols + elif nrows is None: + nrows = (num_plots + (ncols or 1) - 1) // (ncols or 1) + elif ncols is None: + ncols = (num_plots + (nrows or 1) - 1) // (nrows or 1) + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + axes: list[Axes] = [] + for i in range(num_plots): + if i == 0: + ax = self.figure_.add_subplot(nrows, ncols, i + 1) + else: + ax = self.figure_.add_subplot( + nrows, ncols, i + 1, sharex=axes[0], sharey=axes[0] + ) + axes.append(ax) + + # Set self.ax_ to the first axis for backwards compatibility + self.ax_ = axes[0] if axes else None + + # Plot in each subplot + for idx, axi in enumerate(axes): + if self.ml_task == "binary-classification": + # Plot just one fold in this subplot + pos_label = cast(PositiveLabel, self.pos_label) + fold_precision = {pos_label: [self.precision[pos_label][idx]]} + fold_recall = {pos_label: [self.recall[pos_label][idx]]} + fold_avg_precision = { + pos_label: [self.average_precision[pos_label][idx]] + } + fold_display = PrecisionRecallCurveDisplay( + precision=fold_precision, + recall=fold_recall, + average_precision=fold_avg_precision, + estimator_names=[estimator_name], + pos_label=self.pos_label, + data_source=self.data_source, + ml_task=self.ml_task, + report_type="estimator", + ) + fold_display.ax_ = axi + fold_display.figure_ = self.figure_ + _, fold_lines, _ = fold_display._plot_single_estimator( + estimator_name=estimator_name, + pr_curve_kwargs=[pr_curve_kwargs[idx]], + ax=axi, + ) + lines.extend(fold_lines) + axi.set_title(f"Fold #{idx + 1}") + else: # multiclass + # Plot one class in this subplot + class_label = list(self.precision.keys())[idx] + class_precision = {class_label: self.precision[class_label]} + class_recall = {class_label: self.recall[class_label]} + class_avg_precision = { + class_label: self.average_precision[class_label] + } + class_display = PrecisionRecallCurveDisplay( + precision=class_precision, + recall=class_recall, + average_precision=class_avg_precision, + estimator_names=[estimator_name], + pos_label=class_label, + data_source=self.data_source, + ml_task="binary-classification", + report_type="cross-validation", + ) + class_display.ax_ = axi + class_display.figure_ = self.figure_ + _, class_lines, _ = class_display._plot_cross_validated_estimator( + estimator_name=estimator_name, + pr_curve_kwargs=pr_curve_kwargs, + ax=axi, + ) + lines.extend(class_lines) + axi.set_title(f"Class: {class_label}") + + # Set axis labels and limits + xlabel = "Recall" + ylabel = "Precision" + if self.pos_label is not None: + xlabel += f"\n(Positive label: {self.pos_label})" + ylabel += f"\n(Positive label: {self.pos_label})" + + axi.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + if self.data_source in ("train", "test"): + title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set" + else: + title = f"{estimator_name} on $\\bf{{external}}$ set" + 
axi.legend(bbox_to_anchor=(1.02, 1), title=title) + + return self.ax_, lines, None + + # Single plot case + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() + if self.ml_task == "binary-classification": pos_label = cast(PositiveLabel, self.pos_label) for split_idx in range(len(self.precision[pos_label])): @@ -278,7 +554,7 @@ def _plot_cross_validated_estimator( f"(AP = {average_precision_split:0.2f})" ) - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( recall_split, precision_split, **line_kwargs_validated ) lines.append(line) @@ -287,7 +563,7 @@ def _plot_cross_validated_estimator( f"\n(Positive label: {pos_label})" if pos_label is not None else "" ) else: # multiclass-classification - info_pos_label = None # irrelevant for multiclass + info_pos_label = None # not relevant for multiclass class_colors = sample_mpl_colormap( colormaps.get_cmap("tab10"), 10 if len(self.precision) < 10 else len(self.precision), @@ -319,7 +595,7 @@ def _plot_cross_validated_estimator( else: line_kwargs_validated["label"] = None - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( recall_split, precision_split, **line_kwargs_validated ) lines.append(line) @@ -328,15 +604,20 @@ def _plot_cross_validated_estimator( title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set" else: title = f"{estimator_name} on $\\bf{{external}}$ set" - self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title) + plot_ax.legend(bbox_to_anchor=(1.02, 1), title=title) - return self.ax_, lines, info_pos_label + return plot_ax, lines, info_pos_label def _plot_comparison_estimator( self, *, estimator_names: list[str], pr_curve_kwargs: list[dict[str, Any]], + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, ) -> tuple[Axes, list[Line2D], Union[str, None]]: """Plot precision-recall curve of several estimators. @@ -350,6 +631,24 @@ def _plot_comparison_estimator( precision-recall curves. The length of the list should match the number of curves to plot. + ax : matplotlib.axes.Axes, default=None + The axes to plot on. If None, self.ax_ is used. + + subplots : bool, default=False + If True, plot each estimator on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. 
+ Returns ------- ax : matplotlib.axes.Axes @@ -365,6 +664,100 @@ def _plot_comparison_estimator( lines: list[Line2D] = [] line_kwargs: dict[str, Any] = {"drawstyle": "steps-post"} + if subplots: + num_plots = len(estimator_names) + + # Calculate grid dimensions + if nrows is None and ncols is None: + if num_plots == 1: + ncols = 1 + nrows = 1 + else: + ncols = min(2, num_plots) + nrows = (num_plots + ncols - 1) // ncols + elif nrows is None: + nrows = (num_plots + (ncols or 1) - 1) // (ncols or 1) + elif ncols is None: + ncols = (num_plots + (nrows or 1) - 1) // (nrows or 1) + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + axes: list[Axes] = [] + for i in range(num_plots): + if i == 0: + ax = self.figure_.add_subplot(nrows, ncols, i + 1) + else: + ax = self.figure_.add_subplot( + nrows, ncols, i + 1, sharex=axes[0], sharey=axes[0] + ) + axes.append(ax) + + # Set self.ax_ to the first axis for backwards compatibility + self.ax_ = axes[0] if axes else None + + # Plot in each subplot + for idx, (est_name, axi) in enumerate(zip(estimator_names, axes)): + if self.ml_task == "binary-classification": + pos_label = cast(PositiveLabel, self.pos_label) + est_precision = {pos_label: [self.precision[pos_label][idx]]} + est_recall = {pos_label: [self.recall[pos_label][idx]]} + est_avg_precision = { + pos_label: [self.average_precision[pos_label][idx]] + } + else: # multiclass + # Extract data for this estimator across all classes + est_precision = {} + est_recall = {} + est_avg_precision = {} + for class_label in self.precision: + est_precision[class_label] = [self.precision[class_label][idx]] + est_recall[class_label] = [self.recall[class_label][idx]] + est_avg_precision[class_label] = [ + self.average_precision[class_label][idx] + ] + + est_display = PrecisionRecallCurveDisplay( + precision=est_precision, + recall=est_recall, + average_precision=est_avg_precision, + estimator_names=[est_name], + pos_label=self.pos_label, + data_source=self.data_source, + ml_task=self.ml_task, + report_type="estimator", + ) + est_display.ax_ = axi + est_display.figure_ = self.figure_ + _, est_lines, _ = est_display._plot_single_estimator( + estimator_name=est_name, + pr_curve_kwargs=[pr_curve_kwargs[idx]], + ax=axi, + ) + lines.extend(est_lines) + axi.set_title(f"Model: {est_name}") + + # Set axis labels and limits + xlabel = "Recall" + ylabel = "Precision" + if self.pos_label is not None: + xlabel += f"\n(Positive label: {self.pos_label})" + ylabel += f"\n(Positive label: {self.pos_label})" + + axi.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + return self.ax_, lines, None + + # Single plot case + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() + if self.ml_task == "binary-classification": pos_label = cast(PositiveLabel, self.pos_label) for est_idx, est_name in enumerate(estimator_names): @@ -378,7 +771,7 @@ def _plot_comparison_estimator( line_kwargs_validated["label"] = ( f"{est_name} (AP = {average_precision_est:0.2f})" ) - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( recall_est, precision_est, **line_kwargs_validated ) lines.append(line) @@ -387,7 +780,7 @@ def _plot_comparison_estimator( f"\n(Positive label: {pos_label})" if pos_label is not None else "" ) else: # multiclass-classification - info_pos_label = None # irrelevant for multiclass + info_pos_label = None # not relevant for multiclass class_colors = sample_mpl_colormap( colormaps.get_cmap("tab10"), 10 if 
len(self.precision) < 10 else len(self.precision), @@ -414,17 +807,17 @@ def _plot_comparison_estimator( f"(AP = {average_precision_mean:0.2f})" ) - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( recall_est_class, precision_est_class, **line_kwargs_validated ) lines.append(line) - self.ax_.legend( + plot_ax.legend( bbox_to_anchor=(1.02, 1), title=f"{self.ml_task.title()} on $\\bf{{{self.data_source}}}$ set", ) - return self.ax_, lines, info_pos_label + return plot_ax, lines, info_pos_label @StyleDisplayMixin.style_plot def plot( @@ -434,6 +827,10 @@ def plot( estimator_name: Optional[str] = None, pr_curve_kwargs: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None, despine: bool = True, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, ) -> None: """Plot visualization. @@ -456,6 +853,21 @@ def plot( despine : bool, default=True Whether to remove the top and right spines from the plot. + subplots : bool, default=False + If True, plot each estimator or fold on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + Notes ----- The average precision (cf. :func:`~sklearn.metrics.average_precision_score`) @@ -479,8 +891,16 @@ def plot( >>> report = EstimatorReport(classifier, **split_data) >>> display = report.metrics.precision_recall() >>> display.plot(pr_curve_kwargs={"color": "tab:red"}) + + With subplots: + + >>> display.plot(subplots=True) """ - self.figure_, self.ax_ = (ax.figure, ax) if ax is not None else plt.subplots() + if ax is not None and subplots: + raise ValueError( + "Cannot specify both 'ax' and 'subplots=True'. " + "Either provide an axes object or use subplots, but not both." + ) if pr_curve_kwargs is None: pr_curve_kwargs = self._default_pr_curve_kwargs @@ -492,6 +912,12 @@ def plot( report_type=self.report_type, ) + # Initialize figure and axes for non-subplot case + if not subplots: + self.figure_, self.ax_ = ( + (ax.figure, ax) if ax is not None else plt.subplots() + ) + if self.report_type == "estimator": self.ax_, self.lines_, info_pos_label = self._plot_single_estimator( estimator_name=( @@ -500,6 +926,11 @@ def plot( else estimator_name ), pr_curve_kwargs=pr_curve_kwargs, + ax=ax, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, ) elif self.report_type == "cross-validation": self.ax_, self.lines_, info_pos_label = ( @@ -510,12 +941,22 @@ def plot( else estimator_name ), pr_curve_kwargs=pr_curve_kwargs, + ax=ax, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, ) ) elif self.report_type == "comparison-estimator": self.ax_, self.lines_, info_pos_label = self._plot_comparison_estimator( estimator_names=self.estimator_names, pr_curve_kwargs=pr_curve_kwargs, + ax=ax, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, ) else: raise ValueError( @@ -523,22 +964,25 @@ def plot( f"or 'comparison-estimator'. Got '{self.report_type}' instead." 
) - xlabel = "Recall" - ylabel = "Precision" - if info_pos_label: - xlabel += info_pos_label - ylabel += info_pos_label - - self.ax_.set( - xlabel=xlabel, - xlim=(-0.01, 1.01), - ylabel=ylabel, - ylim=(-0.01, 1.01), - aspect="equal", - ) + if not subplots: + xlabel = "Recall" + ylabel = "Precision" + if info_pos_label: + xlabel += info_pos_label + ylabel += info_pos_label + + self.ax_.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + if despine: + _despine_matplotlib_axis(self.ax_) - if despine: - _despine_matplotlib_axis(self.ax_) + return self.figure_ @classmethod def _compute_data_for_display( diff --git a/skore/src/skore/sklearn/_plot/metrics/prediction_error.py b/skore/src/skore/sklearn/_plot/metrics/prediction_error.py index e2d787f27a..62330411ca 100644 --- a/skore/src/skore/sklearn/_plot/metrics/prediction_error.py +++ b/skore/src/skore/sklearn/_plot/metrics/prediction_error.py @@ -172,12 +172,207 @@ def _validate_data_points_kwargs( return data_points_kwargs + @StyleDisplayMixin.style_plot + def plot( + self, + ax: Optional[Axes] = None, + *, + estimator_name: Optional[str] = None, + kind: Literal[ + "actual_vs_predicted", "residual_vs_predicted" + ] = "residual_vs_predicted", + data_points_kwargs: Optional[ + Union[dict[str, Any], list[dict[str, Any]]] + ] = None, + perfect_model_kwargs: Optional[dict[str, Any]] = None, + despine: bool = True, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, + ) -> None: + """Plot visualization. + + Extra keyword arguments will be passed to matplotlib's ``plot``. + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + estimator_name : str + Name of the estimator used to plot the prediction error. If `None`, + we used the inferred name from the estimator. + + kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ + default="residual_vs_predicted" + The type of plot to draw: + + - "actual_vs_predicted" draws the observed values (y-axis) vs. + the predicted values (x-axis). + - "residual_vs_predicted" draws the residuals, i.e. difference + between observed and predicted values, (y-axis) vs. the predicted + values (x-axis). + + data_points_kwargs : dict, default=None + Dictionary with keywords passed to the `matplotlib.pyplot.scatter` + call. + + perfect_model_kwargs : dict, default=None + Dictionary with keyword passed to the `matplotlib.pyplot.plot` + call to draw the optimal line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + subplots : bool, default=False + If True, plot each estimator or fold on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. 
+ + Examples + -------- + >>> from sklearn.datasets import load_diabetes + >>> from sklearn.linear_model import Ridge + >>> from skore import train_test_split + >>> from skore import EstimatorReport + >>> X, y = load_diabetes(return_X_y=True) + >>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True) + >>> classifier = Ridge() + >>> report = EstimatorReport(classifier, **split_data) + >>> display = report.metrics.prediction_error() + >>> display.plot(kind="actual_vs_predicted") + + With subplots: + + >>> display.plot(kind="actual_vs_predicted", subplots=True) + """ + expected_kind = ("actual_vs_predicted", "residual_vs_predicted") + if kind not in expected_kind: + raise ValueError( + f"`kind` must be one of {', '.join(expected_kind)}. " + f"Got {kind!r} instead." + ) + + if ax is not None and subplots: + raise ValueError( + "Cannot specify both 'ax' and 'subplots=True'. " + "Either provide an axes object or use subplots, but not both." + ) + + if kind == "actual_vs_predicted": + xlabel, ylabel = "Predicted values", "Actual values" + else: # kind == "residual_vs_predicted" + xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)" + + perfect_model_kwargs_validated = _validate_style_kwargs( + { + "color": "black", + "alpha": 0.7, + "linestyle": "--", + "label": "Perfect predictions", + }, + perfect_model_kwargs or self._default_perfect_model_kwargs or {}, + ) + + if data_points_kwargs is None: + data_points_kwargs = self._default_data_points_kwargs + data_points_kwargs = self._validate_data_points_kwargs( + data_points_kwargs=data_points_kwargs + ) + + if not subplots: + self.figure_, self.ax_ = ( + (ax.figure, ax) if ax is not None else plt.subplots() + ) + + if self.report_type == "estimator": + self.scatter_ = self._plot_single_estimator( + kind=kind, + estimator_name=( + self.estimator_names[0] + if estimator_name is None + else estimator_name + ), + samples_kwargs=data_points_kwargs, + ax=ax if not subplots else None, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, + perfect_model_kwargs=perfect_model_kwargs_validated, + xlabel=xlabel, + ylabel=ylabel, + despine=despine, + ) + elif self.report_type == "cross-validation": + self.scatter_ = self._plot_cross_validated_estimator( + kind=kind, + estimator_name=( + self.estimator_names[0] + if estimator_name is None + else estimator_name + ), + samples_kwargs=data_points_kwargs, + ax=ax if not subplots else None, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, + perfect_model_kwargs=perfect_model_kwargs_validated, + xlabel=xlabel, + ylabel=ylabel, + despine=despine, + ) + elif self.report_type == "comparison-estimator": + self.scatter_ = self._plot_comparison_estimator( + kind=kind, + estimator_names=self.estimator_names, + samples_kwargs=data_points_kwargs, + ax=ax if not subplots else None, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, + perfect_model_kwargs=perfect_model_kwargs_validated, + xlabel=xlabel, + ylabel=ylabel, + despine=despine, + ) + else: + raise ValueError( + f"`report_type` should be one of 'estimator', 'cross-validation', " + f"or 'comparison-estimator'. Got '{self.report_type}' instead." 
+ ) + + return self.figure_ + def _plot_single_estimator( self, *, kind: Literal["actual_vs_predicted", "residual_vs_predicted"], estimator_name: str, samples_kwargs: list[dict[str, Any]], + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, + perfect_model_kwargs: dict[str, Any], + xlabel: str, + ylabel: str, + despine: bool = True, ) -> list[Artist]: """Plot the prediction error for a single estimator. @@ -192,6 +387,37 @@ def _plot_single_estimator( samples_kwargs : list of dict Keyword arguments for the scatter plot. + ax : matplotlib Axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + subplots : bool, default=False + If True, plot each estimator or fold on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + + perfect_model_kwargs : dict + Keyword arguments for the perfect model line. + + xlabel : str + Label for the x-axis. + + ylabel : str + Label for the y-axis. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + Returns ------- scatter : list of matplotlib Artist @@ -214,9 +440,74 @@ def _plot_single_estimator( else: # data_source == "X_y" scatter_label = "Data set" + if subplots: + # Calculate grid dimensions + if nrows is None and ncols is None: + ncols = 1 + nrows = 1 + elif nrows is None: + nrows = 1 + elif ncols is None: + ncols = 1 + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + ax = self.figure_.add_subplot(nrows, ncols, 1) + self.ax_ = ax + + # Use the provided axis or self.ax_ if available, + # otherwise create a new one + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() + + if kind == "actual_vs_predicted": + # For actual vs predicted, using range for both axes + min_value = min(self.range_y_pred.min, self.range_y_true.min) + max_value = max(self.range_y_pred.max, self.range_y_true.max) + x_range_perfect_pred = [min_value, max_value] + y_range_perfect_pred = [min_value, max_value] + + self.line_ = plot_ax.plot( + x_range_perfect_pred, + y_range_perfect_pred, + **perfect_model_kwargs, + )[0] + plot_ax.set( + aspect="equal", + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace( + x_range_perfect_pred[0], x_range_perfect_pred[1], num=5 + ), + yticks=np.linspace( + y_range_perfect_pred[0], y_range_perfect_pred[1], num=5 + ), + ) + + else: # kind == "residual_vs_predicted" + x_range_perfect_pred = [self.range_y_pred.min, self.range_y_pred.max] + y_range_perfect_pred = [self.range_residuals.min, self.range_residuals.max] + + self.line_ = plot_ax.plot( + x_range_perfect_pred, [0, 0], **perfect_model_kwargs + )[0] + plot_ax.set( + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace( + x_range_perfect_pred[0], x_range_perfect_pred[1], num=5 + ), + yticks=np.linspace( + y_range_perfect_pred[0], y_range_perfect_pred[1], num=5 + ), + ) + + plot_ax.set(xlabel=xlabel, ylabel=ylabel) + if kind == 
"actual_vs_predicted": scatter.append( - self.ax_.scatter( + plot_ax.scatter( y_pred, y_true, label=scatter_label, @@ -225,7 +516,7 @@ def _plot_single_estimator( ) else: # kind == "residual_vs_predicted" scatter.append( - self.ax_.scatter( + plot_ax.scatter( y_pred, residuals, label=scatter_label, @@ -233,7 +524,13 @@ def _plot_single_estimator( ) ) - self.ax_.legend(bbox_to_anchor=(1.02, 1), title=estimator_name) + plot_ax.legend(bbox_to_anchor=(1.02, 1), title=estimator_name) + plot_ax.set_title(f"Model: {estimator_name}") + + if despine: + x_range = plot_ax.get_xlim() + y_range = plot_ax.get_ylim() + _despine_matplotlib_axis(plot_ax, x_range=x_range, y_range=y_range) return scatter @@ -243,6 +540,15 @@ def _plot_cross_validated_estimator( kind: Literal["actual_vs_predicted", "residual_vs_predicted"], estimator_name: str, samples_kwargs: list[dict[str, Any]], + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, + perfect_model_kwargs: dict[str, Any], + xlabel: str, + ylabel: str, + despine: bool = True, ) -> list[Artist]: """Plot the prediction error for a cross-validated estimator. @@ -257,6 +563,37 @@ def _plot_cross_validated_estimator( samples_kwargs : list of dict Keyword arguments for the scatter plot. + ax : matplotlib Axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + subplots : bool, default=False + If True, plot each estimator or fold on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + + perfect_model_kwargs : dict + Keyword arguments for the perfect model line. + + xlabel : str + Label for the x-axis. + + ylabel : str + Label for the y-axis. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. 
+ Returns ------- scatter : list of matplotlib Artist @@ -269,6 +606,175 @@ def _plot_cross_validated_estimator( len(self.y_true) if len(self.y_true) > 10 else 10, ) + if subplots: + num_plots = len(self.y_true) + + # Calculate grid dimensions + if nrows is None and ncols is None: + if num_plots == 1: + ncols = 1 + nrows = 1 + else: + ncols = min(2, num_plots) + nrows = (num_plots + ncols - 1) // ncols + elif nrows is None: + nrows = (num_plots + (ncols or 1) - 1) // (ncols or 1) + elif ncols is None: + ncols = (num_plots + (nrows or 1) - 1) // (nrows or 1) + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + axes: list[Axes] = [] + for i in range(num_plots): + if i == 0: + ax = self.figure_.add_subplot(nrows, ncols, i + 1) + else: + ax = self.figure_.add_subplot( + nrows, ncols, i + 1, sharex=axes[0], sharey=axes[0] + ) + axes.append(ax) + + # Plot each fold in its own subplot + for idx, axi in enumerate(axes): + if kind == "actual_vs_predicted": + # For actual vs predicted, we want the same range for both axes + min_value = min(self.range_y_pred.min, self.range_y_true.min) + max_value = max(self.range_y_pred.max, self.range_y_true.max) + x_range_perfect_pred = [min_value, max_value] + y_range_perfect_pred = [min_value, max_value] + + _ = axi.plot( + x_range_perfect_pred, + y_range_perfect_pred, + **perfect_model_kwargs, + )[0] + axi.set( + aspect="equal", + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace( + x_range_perfect_pred[0], x_range_perfect_pred[1], num=5 + ), + yticks=np.linspace( + y_range_perfect_pred[0], y_range_perfect_pred[1], num=5 + ), + ) + + else: # kind == "residual_vs_predicted" + x_range_perfect_pred = [ + self.range_y_pred.min, + self.range_y_pred.max, + ] + y_range_perfect_pred = [ + self.range_residuals.min, + self.range_residuals.max, + ] + + _ = axi.plot(x_range_perfect_pred, [0, 0], **perfect_model_kwargs)[ + 0 + ] + axi.set( + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace( + x_range_perfect_pred[0], x_range_perfect_pred[1], num=5 + ), + yticks=np.linspace( + y_range_perfect_pred[0], y_range_perfect_pred[1], num=5 + ), + ) + + axi.set(xlabel=xlabel, ylabel=ylabel) + + data_points_kwargs_fold = { + "color": colors_markers[idx], + **data_points_kwargs, + } + + data_points_kwargs_validated = _validate_style_kwargs( + data_points_kwargs_fold, samples_kwargs[idx] + ) + + label = f"Estimator of fold #{idx + 1}" + + if kind == "actual_vs_predicted": + scatter.append( + axi.scatter( + self.y_pred[idx], + self.y_true[idx], + label=label, + **data_points_kwargs_validated, + ) + ) + else: # kind == "residual_vs_predicted" + scatter.append( + axi.scatter( + self.y_pred[idx], + self.residuals[idx], + label=label, + **data_points_kwargs_validated, + ) + ) + + axi.set_title(f"Fold #{idx + 1}") + + if despine: + x_range = axi.get_xlim() + y_range = axi.get_ylim() + _despine_matplotlib_axis(axi, x_range=x_range, y_range=y_range) + + # Set the first axis as the main axis for backward compatibility + self.ax_ = axes[0] if axes else None + return scatter + + # Single plot case + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() + + if kind == "actual_vs_predicted": + min_value = min(self.range_y_pred.min, self.range_y_true.min) + max_value = max(self.range_y_pred.max, self.range_y_true.max) + x_range_perfect_pred = [min_value, max_value] + y_range_perfect_pred = [min_value, max_value] + + self.line_ = plot_ax.plot( + x_range_perfect_pred, + 
y_range_perfect_pred, + **perfect_model_kwargs, + )[0] + plot_ax.set( + aspect="equal", + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace( + x_range_perfect_pred[0], x_range_perfect_pred[1], num=5 + ), + yticks=np.linspace( + y_range_perfect_pred[0], y_range_perfect_pred[1], num=5 + ), + ) + + else: # kind == "residual_vs_predicted" + x_range_perfect_pred = [self.range_y_pred.min, self.range_y_pred.max] + y_range_perfect_pred = [self.range_residuals.min, self.range_residuals.max] + + self.line_ = plot_ax.plot( + x_range_perfect_pred, [0, 0], **perfect_model_kwargs + )[0] + plot_ax.set( + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace( + x_range_perfect_pred[0], x_range_perfect_pred[1], num=5 + ), + yticks=np.linspace( + y_range_perfect_pred[0], y_range_perfect_pred[1], num=5 + ), + ) + + plot_ax.set(xlabel=xlabel, ylabel=ylabel) + for split_idx in range(len(self.y_true)): data_points_kwargs_fold = { "color": colors_markers[split_idx], @@ -283,7 +789,7 @@ def _plot_cross_validated_estimator( if kind == "actual_vs_predicted": scatter.append( - self.ax_.scatter( + plot_ax.scatter( self.y_pred[split_idx], self.y_true[split_idx], label=label, @@ -292,7 +798,7 @@ def _plot_cross_validated_estimator( ) else: # kind == "residual_vs_predicted" scatter.append( - self.ax_.scatter( + plot_ax.scatter( self.y_pred[split_idx], self.residuals[split_idx], label=label, @@ -304,7 +810,12 @@ def _plot_cross_validated_estimator( title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set" else: title = f"{estimator_name} on $\\bf{{external}}$ set" - self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title) + plot_ax.legend(bbox_to_anchor=(1.02, 1), title=title) + + if despine: + x_range = plot_ax.get_xlim() + y_range = plot_ax.get_ylim() + _despine_matplotlib_axis(plot_ax, x_range=x_range, y_range=y_range) return scatter @@ -314,6 +825,15 @@ def _plot_comparison_estimator( kind: Literal["actual_vs_predicted", "residual_vs_predicted"], estimator_names: list[str], samples_kwargs: list[dict[str, Any]], + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, + perfect_model_kwargs: dict[str, Any], + xlabel: str, + ylabel: str, + despine: bool = True, ) -> list[Artist]: """Plot the prediction error of several estimators. @@ -328,6 +848,37 @@ def _plot_comparison_estimator( samples_kwargs : list of dict Keyword arguments for the scatter plot. + ax : matplotlib Axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + subplots : bool, default=False + If True, plot each estimator or fold on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + + perfect_model_kwargs : dict + Keyword arguments for the perfect model line. + + xlabel : str + Label for the x-axis. + + ylabel : str + Label for the y-axis. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. 
+ Returns ------- scatter : list of matplotlib Artist @@ -340,129 +891,131 @@ def _plot_comparison_estimator( len(self.y_true) if len(self.y_true) > 10 else 10, ) - for estimator_idx in range(len(self.y_true)): - data_points_kwargs_fold = { - "color": colors_markers[estimator_idx], - **data_points_kwargs, - } - - data_points_kwargs_validated = _validate_style_kwargs( - data_points_kwargs_fold, samples_kwargs[estimator_idx] - ) - - label = f"{estimator_names[estimator_idx]}" - - if kind == "actual_vs_predicted": - scatter.append( - self.ax_.scatter( - self.y_pred[estimator_idx], - self.y_true[estimator_idx], - label=label, - **data_points_kwargs_validated, + if subplots: + num_plots = len(estimator_names) + + # Calculate grid dimensions + if nrows is None and ncols is None: + if num_plots == 1: + ncols = 1 + nrows = 1 + else: + ncols = min(2, num_plots) + nrows = (num_plots + ncols - 1) // ncols + elif nrows is None: + nrows = (num_plots + (ncols or 1) - 1) // (ncols or 1) + elif ncols is None: + ncols = (num_plots + (nrows or 1) - 1) // (nrows or 1) + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + axes: list[Axes] = [] + for i in range(num_plots): + if i == 0: + ax = self.figure_.add_subplot(nrows, ncols, i + 1) + else: + ax = self.figure_.add_subplot( + nrows, ncols, i + 1, sharex=axes[0], sharey=axes[0] ) - ) - else: # kind == "residual_vs_predicted" - scatter.append( - self.ax_.scatter( - self.y_pred[estimator_idx], - self.residuals[estimator_idx], - label=label, - **data_points_kwargs_validated, + axes.append(ax) + + # Plot each estimator in its own subplot + for idx, (est_name, axi) in enumerate(zip(estimator_names, axes)): + if kind == "actual_vs_predicted": + # For actual vs predicted, we want the same range for both axes + min_value = min(self.range_y_pred.min, self.range_y_true.min) + max_value = max(self.range_y_pred.max, self.range_y_true.max) + x_range_perfect_pred = [min_value, max_value] + y_range_perfect_pred = [min_value, max_value] + + _ = axi.plot( + x_range_perfect_pred, + y_range_perfect_pred, + **perfect_model_kwargs, + )[0] + axi.set( + aspect="equal", + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace( + x_range_perfect_pred[0], x_range_perfect_pred[1], num=5 + ), + yticks=np.linspace( + y_range_perfect_pred[0], y_range_perfect_pred[1], num=5 + ), ) - ) - - self.ax_.legend( - bbox_to_anchor=(1.02, 1), - title=f"Prediction errors on $\\bf{{{self.data_source}}}$ set", - ) - - return scatter - - @StyleDisplayMixin.style_plot - def plot( - self, - ax: Optional[Axes] = None, - *, - estimator_name: Optional[str] = None, - kind: Literal[ - "actual_vs_predicted", "residual_vs_predicted" - ] = "residual_vs_predicted", - data_points_kwargs: Optional[ - Union[dict[str, Any], list[dict[str, Any]]] - ] = None, - perfect_model_kwargs: Optional[dict[str, Any]] = None, - despine: bool = True, - ) -> None: - """Plot visualization. - - Extra keyword arguments will be passed to matplotlib's ``plot``. - Parameters - ---------- - ax : matplotlib axes, default=None - Axes object to plot on. If `None`, a new figure and axes is - created. 
+ else: # kind == "residual_vs_predicted" + x_range_perfect_pred = [ + self.range_y_pred.min, + self.range_y_pred.max, + ] + y_range_perfect_pred = [ + self.range_residuals.min, + self.range_residuals.max, + ] + + _ = axi.plot(x_range_perfect_pred, [0, 0], **perfect_model_kwargs)[ + 0 + ] + axi.set( + xlim=x_range_perfect_pred, + ylim=y_range_perfect_pred, + xticks=np.linspace( + x_range_perfect_pred[0], x_range_perfect_pred[1], num=5 + ), + yticks=np.linspace( + y_range_perfect_pred[0], y_range_perfect_pred[1], num=5 + ), + ) - estimator_name : str - Name of the estimator used to plot the prediction error. If `None`, - we used the inferred name from the estimator. + axi.set(xlabel=xlabel, ylabel=ylabel) - kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ - default="residual_vs_predicted" - The type of plot to draw: + data_points_kwargs_fold = { + "color": colors_markers[idx], + **data_points_kwargs, + } - - "actual_vs_predicted" draws the observed values (y-axis) vs. - the predicted values (x-axis). - - "residual_vs_predicted" draws the residuals, i.e. difference - between observed and predicted values, (y-axis) vs. the predicted - values (x-axis). + data_points_kwargs_validated = _validate_style_kwargs( + data_points_kwargs_fold, samples_kwargs[idx] + ) - data_points_kwargs : dict, default=None - Dictionary with keywords passed to the `matplotlib.pyplot.scatter` - call. + label = f"{est_name}" - perfect_model_kwargs : dict, default=None - Dictionary with keyword passed to the `matplotlib.pyplot.plot` - call to draw the optimal line. + if kind == "actual_vs_predicted": + scatter.append( + axi.scatter( + self.y_pred[idx], + self.y_true[idx], + label=label, + **data_points_kwargs_validated, + ) + ) + else: # kind == "residual_vs_predicted" + scatter.append( + axi.scatter( + self.y_pred[idx], + self.residuals[idx], + label=label, + **data_points_kwargs_validated, + ) + ) - despine : bool, default=True - Whether to remove the top and right spines from the plot. + axi.set_title(f"Model: {est_name}") - Examples - -------- - >>> from sklearn.datasets import load_diabetes - >>> from sklearn.linear_model import Ridge - >>> from skore import train_test_split - >>> from skore import EstimatorReport - >>> X, y = load_diabetes(return_X_y=True) - >>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True) - >>> classifier = Ridge() - >>> report = EstimatorReport(classifier, **split_data) - >>> display = report.metrics.prediction_error() - >>> display.plot(kind="actual_vs_predicted") - """ - expected_kind = ("actual_vs_predicted", "residual_vs_predicted") - if kind not in expected_kind: - raise ValueError( - f"`kind` must be one of {', '.join(expected_kind)}. " - f"Got {kind!r} instead." 
- ) - if kind == "actual_vs_predicted": - xlabel, ylabel = "Predicted values", "Actual values" - else: # kind == "residual_vs_predicted" - xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)" + if despine: + x_range = axi.get_xlim() + y_range = axi.get_ylim() + _despine_matplotlib_axis(axi, x_range=x_range, y_range=y_range) - self.figure_, self.ax_ = (ax.figure, ax) if ax is not None else plt.subplots() + # Set the first axis as the main axis for backward compatibility + self.ax_ = axes[0] if axes else None + return scatter - perfect_model_kwargs_validated = _validate_style_kwargs( - { - "color": "black", - "alpha": 0.7, - "linestyle": "--", - "label": "Perfect predictions", - }, - perfect_model_kwargs or self._default_perfect_model_kwargs or {}, - ) + # Single plot case + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() if kind == "actual_vs_predicted": # For actual vs predicted, we want the same range for both axes @@ -471,12 +1024,12 @@ def plot( x_range_perfect_pred = [min_value, max_value] y_range_perfect_pred = [min_value, max_value] - self.line_ = self.ax_.plot( + self.line_ = plot_ax.plot( x_range_perfect_pred, y_range_perfect_pred, - **perfect_model_kwargs_validated, + **perfect_model_kwargs, )[0] - self.ax_.set( + plot_ax.set( aspect="equal", xlim=x_range_perfect_pred, ylim=y_range_perfect_pred, @@ -492,10 +1045,10 @@ def plot( x_range_perfect_pred = [self.range_y_pred.min, self.range_y_pred.max] y_range_perfect_pred = [self.range_residuals.min, self.range_residuals.max] - self.line_ = self.ax_.plot( - x_range_perfect_pred, [0, 0], **perfect_model_kwargs_validated + self.line_ = plot_ax.plot( + x_range_perfect_pred, [0, 0], **perfect_model_kwargs )[0] - self.ax_.set( + plot_ax.set( xlim=x_range_perfect_pred, ylim=y_range_perfect_pred, xticks=np.linspace( @@ -506,52 +1059,50 @@ def plot( ), ) - self.ax_.set(xlabel=xlabel, ylabel=ylabel) + plot_ax.set(xlabel=xlabel, ylabel=ylabel) - # make the scatter plot afterwards since it should take into account the line - # for the perfect predictions - if data_points_kwargs is None: - data_points_kwargs = self._default_data_points_kwargs - data_points_kwargs = self._validate_data_points_kwargs( - data_points_kwargs=data_points_kwargs - ) + for estimator_idx in range(len(self.y_true)): + data_points_kwargs_fold = { + "color": colors_markers[estimator_idx], + **data_points_kwargs, + } - if self.report_type == "estimator": - self.scatter_ = self._plot_single_estimator( - kind=kind, - estimator_name=( - self.estimator_names[0] - if estimator_name is None - else estimator_name - ), - samples_kwargs=data_points_kwargs, - ) - elif self.report_type == "cross-validation": - self.scatter_ = self._plot_cross_validated_estimator( - kind=kind, - estimator_name=( - self.estimator_names[0] - if estimator_name is None - else estimator_name - ), - samples_kwargs=data_points_kwargs, - ) - elif self.report_type == "comparison-estimator": - self.scatter_ = self._plot_comparison_estimator( - kind=kind, - estimator_names=self.estimator_names, - samples_kwargs=data_points_kwargs, - ) - else: - raise ValueError( - f"`report_type` should be one of 'estimator', 'cross-validation', " - f"or 'comparison-estimator'. Got '{self.report_type}' instead." 
+ data_points_kwargs_validated = _validate_style_kwargs( + data_points_kwargs_fold, samples_kwargs[estimator_idx] ) + label = f"{estimator_names[estimator_idx]}" + + if kind == "actual_vs_predicted": + scatter.append( + plot_ax.scatter( + self.y_pred[estimator_idx], + self.y_true[estimator_idx], + label=label, + **data_points_kwargs_validated, + ) + ) + else: # kind == "residual_vs_predicted" + scatter.append( + plot_ax.scatter( + self.y_pred[estimator_idx], + self.residuals[estimator_idx], + label=label, + **data_points_kwargs_validated, + ) + ) + + plot_ax.legend( + bbox_to_anchor=(1.02, 1), + title=f"Prediction errors on $\\bf{{{self.data_source}}}$ set", + ) + if despine: - x_range = self.ax_.get_xlim() - y_range = self.ax_.get_ylim() - _despine_matplotlib_axis(self.ax_, x_range=x_range, y_range=y_range) + x_range = plot_ax.get_xlim() + y_range = plot_ax.get_ylim() + _despine_matplotlib_axis(plot_ax, x_range=x_range, y_range=y_range) + + return scatter @classmethod def _compute_data_for_display( diff --git a/skore/src/skore/sklearn/_plot/metrics/roc_curve.py b/skore/src/skore/sklearn/_plot/metrics/roc_curve.py index c776d7d015..58a5b69b72 100644 --- a/skore/src/skore/sklearn/_plot/metrics/roc_curve.py +++ b/skore/src/skore/sklearn/_plot/metrics/roc_curve.py @@ -181,6 +181,11 @@ def _plot_single_estimator( roc_curve_kwargs: list[dict[str, Any]], plot_chance_level: bool = True, chance_level_kwargs: Optional[dict[str, Any]] = None, + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, ) -> tuple[Axes, list[Line2D], Union[str, None]]: """Plot ROC curve for a single estimator. @@ -201,6 +206,24 @@ def _plot_single_estimator( Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. + ax : matplotlib.axes.Axes, default=None + The axes to plot on. If None, self.ax_ is used. + + subplots : bool, default=False + If True, plot each class on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + Returns ------- ax : matplotlib.axes.Axes @@ -216,6 +239,75 @@ def _plot_single_estimator( lines: list[Line2D] = [] line_kwargs: dict[str, Any] = {} + if ax is not None and subplots: + raise ValueError( + "Cannot specify both 'ax' and 'subplots=True'. " + "Either provide an axes object or use subplots, but not both." 
+ ) + + # Handle subplot creation for multiclass case + if subplots and self.ml_task == "multiclass-classification": + num_plots = len(self.fpr) + + # Calculate grid dimensions + if nrows is None and ncols is None: + if num_plots == 1: + ncols = 1 + nrows = 1 + else: + ncols = min(2, num_plots) + nrows = (num_plots + ncols - 1) // ncols + elif nrows is None: + nrows = (num_plots + (ncols or 1) - 1) // (ncols or 1) + elif ncols is None: + ncols = (num_plots + (nrows or 1) - 1) // (nrows or 1) + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + axes: list[Axes] = [] + for i in range(num_plots): + if i == 0: + ax = self.figure_.add_subplot(nrows, ncols, i + 1) + else: + ax = self.figure_.add_subplot( + nrows, ncols, i + 1, sharex=axes[0], sharey=axes[0] + ) + axes.append(ax) + + # Plot each class in its own subplot + for idx, (class_label, axi) in enumerate(zip(self.fpr.keys(), axes)): + fpr_class = self.fpr[class_label][0] + tpr_class = self.tpr[class_label][0] + roc_auc_class = self.roc_auc[class_label][0] + roc_curve_kwargs_class = roc_curve_kwargs[idx] + + line_kwargs_validated = _validate_style_kwargs( + {"label": f"AUC = {roc_auc_class:0.2f}"}, roc_curve_kwargs_class + ) + + (line,) = axi.plot(fpr_class, tpr_class, **line_kwargs_validated) + lines.append(line) + + if plot_chance_level: + _add_chance_level( + axi, + chance_level_kwargs, + self._default_chance_level_kwargs, + ) + + axi.set_title(f"Class: {class_label}") + _set_axis_labels(axi, None) + _despine_matplotlib_axis(axi) + + # Set the first axis as the main axis for backward compatibility + self.ax_ = axes[0] if axes else None + return self.ax_, lines, None + + # Single plot case (binary or multiclass) + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() + if self.ml_task == "binary-classification": pos_label = cast(PositiveLabel, self.pos_label) if self.data_source in ("train", "test"): @@ -230,7 +322,7 @@ def _plot_single_estimator( line_kwargs, roc_curve_kwargs[0] ) - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( self.fpr[pos_label][0], self.tpr[pos_label][0], **line_kwargs_validated ) lines.append(line) @@ -265,23 +357,23 @@ def _plot_single_estimator( default_line_kwargs, roc_curve_kwargs_class ) - (line,) = self.ax_.plot(fpr_class, tpr_class, **line_kwargs) + (line,) = plot_ax.plot(fpr_class, tpr_class, **line_kwargs) lines.append(line) info_pos_label = None # irrelevant for multiclass if plot_chance_level: self.chance_level_ = _add_chance_level( - self.ax_, + plot_ax, chance_level_kwargs, self._default_chance_level_kwargs, ) else: self.chance_level_ = None - self.ax_.legend(bbox_to_anchor=(1.02, 1), title=estimator_name) + plot_ax.legend(bbox_to_anchor=(1.02, 1), title=estimator_name) - return self.ax_, lines, info_pos_label + return plot_ax, lines, info_pos_label def _plot_cross_validated_estimator( self, @@ -290,6 +382,11 @@ def _plot_cross_validated_estimator( roc_curve_kwargs: list[dict[str, Any]], plot_chance_level: bool = True, chance_level_kwargs: Optional[dict[str, Any]] = None, + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, ) -> tuple[Axes, list[Line2D], Union[str, None]]: """Plot ROC curve for a cross-validated estimator. @@ -309,6 +406,24 @@ def _plot_cross_validated_estimator( Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. 
+ ax : matplotlib.axes.Axes, default=None + The axes to plot on. If None, self.ax_ is used. + + subplots : bool, default=False + If True, plot each fold or class on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + Returns ------- ax : matplotlib.axes.Axes @@ -324,6 +439,129 @@ def _plot_cross_validated_estimator( lines: list[Line2D] = [] line_kwargs: dict[str, Any] = {} + if ax is not None and subplots: + raise ValueError( + "Cannot specify both 'ax' and 'subplots=True'. " + "Either provide an axes object or use subplots, but not both." + ) + + if subplots: + if self.ml_task == "binary-classification": + pos_label = cast(PositiveLabel, self.pos_label) + num_plots = len(self.fpr[pos_label]) + else: # multiclass + num_plots = len(self.fpr) + + # Calculate grid dimensions + if nrows is None and ncols is None: + if num_plots == 1: + ncols = 1 + nrows = 1 + else: + ncols = min(2, num_plots) + nrows = (num_plots + ncols - 1) // ncols + elif nrows is None: + nrows = (num_plots + (ncols or 1) - 1) // (ncols or 1) + elif ncols is None: + ncols = (num_plots + (nrows or 1) - 1) // (nrows or 1) + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + axes: list[Axes] = [] + for i in range(num_plots): + if i == 0: + ax = self.figure_.add_subplot(nrows, ncols, i + 1) + else: + ax = self.figure_.add_subplot( + nrows, ncols, i + 1, sharex=axes[0], sharey=axes[0] + ) + axes.append(ax) + + # Plot in subplots + if self.ml_task == "binary-classification": + pos_label = cast(PositiveLabel, self.pos_label) + for idx, axi in enumerate(axes): + fpr_split = self.fpr[pos_label][idx] + tpr_split = self.tpr[pos_label][idx] + roc_auc_split = self.roc_auc[pos_label][idx] + + line_kwargs_validated = _validate_style_kwargs( + line_kwargs, roc_curve_kwargs[idx] + ) + line_kwargs_validated["label"] = ( + f"Estimator of fold #{idx + 1} (AUC = {roc_auc_split:0.2f})" + ) + + (line,) = axi.plot(fpr_split, tpr_split, **line_kwargs_validated) + lines.append(line) + + if plot_chance_level: + _add_chance_level( + axi, + chance_level_kwargs, + self._default_chance_level_kwargs, + ) + + axi.set_title(f"Fold #{idx + 1}") + _set_axis_labels(axi, f"\n(Positive label: {pos_label})") + _despine_matplotlib_axis(axi) + + else: # multiclass + class_colors = sample_mpl_colormap( + colormaps.get_cmap("tab10"), + 10 if len(self.fpr) < 10 else len(self.fpr), + ) + + for idx, (class_label, axi) in enumerate(zip(self.fpr.keys(), axes)): + fpr_class = self.fpr[class_label] + tpr_class = self.tpr[class_label] + roc_auc_class = self.roc_auc[class_label] + roc_auc_mean = np.mean(roc_auc_class) + roc_auc_std = np.std(roc_auc_class) + + for split_idx in range(len(fpr_class)): + fpr_split = fpr_class[split_idx] + tpr_split = tpr_class[split_idx] + + line_kwargs_validated = _validate_style_kwargs( + { + "color": class_colors[idx], + "alpha": 0.3, + }, + roc_curve_kwargs[idx], + ) + if split_idx == 0: + line_kwargs_validated["label"] = ( + f"AUC = {roc_auc_mean:0.2f} +/- {roc_auc_std:0.2f}" + ) + else: + line_kwargs_validated["label"] = None + + (line,) = axi.plot( 
+ fpr_split, tpr_split, **line_kwargs_validated + ) + lines.append(line) + + if plot_chance_level: + _add_chance_level( + axi, + chance_level_kwargs, + self._default_chance_level_kwargs, + ) + + axi.set_title(f"Class: {class_label}") + _set_axis_labels(axi, None) + _despine_matplotlib_axis(axi) + + # Set the first axis as the main axis for backward compatibility + self.ax_ = axes[0] if axes else None + return self.ax_, lines, None + + # Single plot case + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() + if self.ml_task == "binary-classification": pos_label = cast(PositiveLabel, self.pos_label) for split_idx in range(len(self.fpr[pos_label])): @@ -338,7 +576,7 @@ def _plot_cross_validated_estimator( f"Estimator of fold #{split_idx + 1} (AUC = {roc_auc_split:0.2f})" ) - (line,) = self.ax_.plot(fpr_split, tpr_split, **line_kwargs_validated) + (line,) = plot_ax.plot(fpr_split, tpr_split, **line_kwargs_validated) lines.append(line) info_pos_label = ( @@ -378,14 +616,14 @@ def _plot_cross_validated_estimator( else: line_kwargs_validated["label"] = None - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( fpr_split, tpr_split, **line_kwargs_validated ) lines.append(line) if plot_chance_level: self.chance_level_ = _add_chance_level( - self.ax_, + plot_ax, chance_level_kwargs, self._default_chance_level_kwargs, ) @@ -396,9 +634,9 @@ def _plot_cross_validated_estimator( title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set" else: title = f"{estimator_name} on $\\bf{{external}}$ set" - self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title) + plot_ax.legend(bbox_to_anchor=(1.02, 1), title=title) - return self.ax_, lines, info_pos_label + return plot_ax, lines, info_pos_label def _plot_comparison_estimator( self, @@ -407,6 +645,11 @@ def _plot_comparison_estimator( roc_curve_kwargs: list[dict[str, Any]], plot_chance_level: bool = True, chance_level_kwargs: Optional[dict[str, Any]] = None, + ax: Optional[Axes] = None, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, ) -> tuple[Axes, list[Line2D], Union[str, None]]: """Plot ROC curve of several estimators. @@ -426,6 +669,24 @@ def _plot_comparison_estimator( Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. + ax : matplotlib.axes.Axes, default=None + The axes to plot on. If None, self.ax_ is used. + + subplots : bool, default=False + If True, plot each estimator on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + Returns ------- ax : matplotlib.axes.Axes @@ -441,6 +702,115 @@ def _plot_comparison_estimator( lines: list[Line2D] = [] line_kwargs: dict[str, Any] = {} + if ax is not None and subplots: + raise ValueError( + "Cannot specify both 'ax' and 'subplots=True'. " + "Either provide an axes object or use subplots, but not both." 
+ ) + + if subplots: + num_plots = len(estimator_names) + + # Calculate grid dimensions + if nrows is None and ncols is None: + if num_plots == 1: + ncols = 1 + nrows = 1 + else: + ncols = min(2, num_plots) + nrows = (num_plots + ncols - 1) // ncols + elif nrows is None: + nrows = (num_plots + (ncols or 1) - 1) // (ncols or 1) + elif ncols is None: + ncols = (num_plots + (nrows or 1) - 1) // (nrows or 1) + + # Create figure and subplots + self.figure_ = plt.figure(figsize=figsize) + axes: list[Axes] = [] + for i in range(num_plots): + if i == 0: + ax = self.figure_.add_subplot(nrows, ncols, i + 1) + else: + ax = self.figure_.add_subplot( + nrows, ncols, i + 1, sharex=axes[0], sharey=axes[0] + ) + axes.append(ax) + + # Plot each estimator in its own subplot + for idx, (est_name, axi) in enumerate(zip(estimator_names, axes)): + if self.ml_task == "binary-classification": + pos_label = cast(PositiveLabel, self.pos_label) + fpr_est = self.fpr[pos_label][idx] + tpr_est = self.tpr[pos_label][idx] + roc_auc_est = self.roc_auc[pos_label][idx] + + line_kwargs_validated = _validate_style_kwargs( + line_kwargs, roc_curve_kwargs[idx] + ) + line_kwargs_validated["label"] = f"AUC = {roc_auc_est:0.2f}" + + (line,) = axi.plot(fpr_est, tpr_est, **line_kwargs_validated) + lines.append(line) + + if plot_chance_level: + _add_chance_level( + axi, + chance_level_kwargs, + self._default_chance_level_kwargs, + ) + + axi.set_title(f"Model: {est_name}") + _set_axis_labels(axi, f"\n(Positive label: {pos_label})") + _despine_matplotlib_axis(axi) + + else: # multiclass + class_colors = sample_mpl_colormap( + colormaps.get_cmap("tab10"), + 10 if len(self.fpr) < 10 else len(self.fpr), + ) + + for class_idx, class_ in enumerate(self.fpr): + fpr_est_class = self.fpr[class_][idx] + tpr_est_class = self.tpr[class_][idx] + roc_auc_mean = self.roc_auc[class_][idx] + class_linestyle = LINESTYLE[(class_idx % len(LINESTYLE))][1] + + line_kwargs["color"] = class_colors[class_idx] + line_kwargs["alpha"] = 0.6 + line_kwargs["linestyle"] = class_linestyle + + line_kwargs_validated = _validate_style_kwargs( + line_kwargs, roc_curve_kwargs[idx] + ) + line_kwargs_validated["label"] = ( + f"{str(class_).title()} (AUC = {roc_auc_mean:0.2f})" + ) + + (line,) = axi.plot( + fpr_est_class, tpr_est_class, **line_kwargs_validated + ) + lines.append(line) + + if plot_chance_level: + _add_chance_level( + axi, + chance_level_kwargs, + self._default_chance_level_kwargs, + ) + + axi.set_title(f"Model: {est_name}") + _set_axis_labels(axi, None) + _despine_matplotlib_axis(axi) + + # Set the first axis as the main axis for backward compatibility + self.ax_ = axes[0] if axes else None + return self.ax_, lines, None + + # Single plot case + plot_ax = ax if ax is not None else self.ax_ + if plot_ax is None: + _, plot_ax = plt.subplots() + if self.ml_task == "binary-classification": pos_label = cast(PositiveLabel, self.pos_label) for est_idx, est_name in enumerate(estimator_names): @@ -454,7 +824,7 @@ def _plot_comparison_estimator( line_kwargs_validated["label"] = ( f"{est_name} (AUC = {roc_auc_est:0.2f})" ) - (line,) = self.ax_.plot(fpr_est, tpr_est, **line_kwargs_validated) + (line,) = plot_ax.plot(fpr_est, tpr_est, **line_kwargs_validated) lines.append(line) info_pos_label = ( @@ -487,26 +857,26 @@ def _plot_comparison_estimator( f"(AUC = {roc_auc_mean:0.2f})" ) - (line,) = self.ax_.plot( + (line,) = plot_ax.plot( fpr_est_class, tpr_est_class, **line_kwargs_validated ) lines.append(line) if plot_chance_level: self.chance_level_ = _add_chance_level( - 
self.ax_, + plot_ax, chance_level_kwargs, self._default_chance_level_kwargs, ) else: self.chance_level_ = None - self.ax_.legend( + plot_ax.legend( bbox_to_anchor=(1.02, 1), title=f"{self.ml_task.title()} on $\\bf{{{self.data_source}}}$ set", ) - return self.ax_, lines, info_pos_label + return plot_ax, lines, info_pos_label @StyleDisplayMixin.style_plot def plot( @@ -518,6 +888,10 @@ def plot( plot_chance_level: bool = True, chance_level_kwargs: Optional[dict[str, Any]] = None, despine: bool = True, + subplots: bool = False, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, ) -> None: """Plot visualization. @@ -547,6 +921,21 @@ def plot( despine : bool, default=True Whether to remove the top and right spines from the plot. + subplots : bool, default=False + If True, plot each estimator or fold on a separate subplot. + + nrows : int, default=None + Number of rows in the subplot grid. Only used when subplots=True. + If None, it will be computed based on ncols. + + ncols : int, default=None + Number of columns in the subplot grid. Only used when subplots=True. + If None, defaults to 2 for multiple plots, 1 for a single plot. + + figsize : tuple of float, default=None + Figure size (width, height) in inches. Only used when subplots=True. + If None, a default size will be determined based on the number of subplots. + Examples -------- >>> from sklearn.datasets import load_breast_cancer @@ -559,8 +948,16 @@ def plot( >>> report = EstimatorReport(classifier, **split_data) >>> display = report.metrics.roc() >>> display.plot(roc_curve_kwargs={"color": "tab:red"}) + + With subplots: + + >>> display.plot(subplots=True) """ - self.figure_, self.ax_ = (ax.figure, ax) if ax is not None else plt.subplots() + if ax is not None and subplots: + raise ValueError( + "Cannot specify both 'ax' and 'subplots=True'. " + "Either provide an axes object or use subplots, but not both." + ) if roc_curve_kwargs is None: roc_curve_kwargs = self._default_roc_curve_kwargs @@ -571,6 +968,12 @@ def plot( report_type=self.report_type, ) + # Create figure and axes if not using subplots + if not subplots: + self.figure_, self.ax_ = ( + (ax.figure, ax) if ax is not None else plt.subplots() + ) + if self.report_type == "estimator": self.ax_, self.lines_, info_pos_label = self._plot_single_estimator( estimator_name=( @@ -581,6 +984,11 @@ def plot( roc_curve_kwargs=roc_curve_kwargs, plot_chance_level=plot_chance_level, chance_level_kwargs=chance_level_kwargs, + ax=ax if not subplots else None, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, ) elif self.report_type == "cross-validation": self.ax_, self.lines_, info_pos_label = ( @@ -593,6 +1001,11 @@ def plot( roc_curve_kwargs=roc_curve_kwargs, plot_chance_level=plot_chance_level, chance_level_kwargs=chance_level_kwargs, + ax=ax if not subplots else None, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, ) ) elif self.report_type == "comparison-estimator": @@ -601,6 +1014,11 @@ def plot( roc_curve_kwargs=roc_curve_kwargs, plot_chance_level=plot_chance_level, chance_level_kwargs=chance_level_kwargs, + ax=ax if not subplots else None, + subplots=subplots, + nrows=nrows, + ncols=ncols, + figsize=figsize, ) else: raise ValueError( @@ -608,10 +1026,13 @@ def plot( f"or 'comparison-estimator'. Got '{self.report_type}' instead." 
) - _set_axis_labels(self.ax_, info_pos_label) + if not subplots: + _set_axis_labels(self.ax_, info_pos_label) + + if despine: + _despine_matplotlib_axis(self.ax_) - if despine: - _despine_matplotlib_axis(self.ax_) + return self.figure_ @classmethod def _compute_data_for_display( diff --git a/skore/src/skore/sklearn/_plot/utils.py b/skore/src/skore/sklearn/_plot/utils.py index 09c2808d6e..45be2bd9fe 100644 --- a/skore/src/skore/sklearn/_plot/utils.py +++ b/skore/src/skore/sklearn/_plot/utils.py @@ -3,6 +3,7 @@ from io import StringIO from typing import Any, Literal, Optional, Union, cast +import matplotlib.pyplot as plt import numpy as np from matplotlib.axes import Axes from matplotlib.colors import Colormap @@ -354,3 +355,56 @@ def sample_mpl_colormap( """ indices = np.linspace(0, 1, n) return [cmap(i) for i in indices] + + +def get_subplots( + nplots: int, + nrows: Optional[int] = None, + ncols: Optional[int] = None, + figsize: Optional[tuple[float, float]] = None, +) -> tuple[plt.Figure, np.ndarray]: + """Create subplots for displaying metrics. + + Parameters + ---------- + nplots : int + Number of plots to create. + nrows : int, default=None + Number of rows in the subplot grid. If None, it will be computed based on ncols. + ncols : int, default=None + Number of columns in the subplot grid. If None, defaults to 2 for + multiple plots, 1 for a single plot. + figsize : tuple of float, default=None + Figure size (width, height) in inches. If None, a default size will be + determined based on the number of subplots. + + Returns + ------- + fig : Figure + The matplotlib Figure object. + axes : ndarray of Axes + Array of Axes objects. When nplots > 1, the array is 2D. + """ + # Default ncols=2 if more than 1 plot, else set to 1 + if nplots <= 1: + return plt.subplots(figsize=figsize or (6, 4), squeeze=True) + + ncols = ncols or 2 + if nrows is None: + quot, resid = divmod(nplots, ncols) + nrows = quot + (1 if resid else 0) + + # Setting a reasonable default figsize if none provided + if figsize is None: + figsize = (6 * ncols, 4 * nrows) + + fig, axes = plt.subplots( + nrows, ncols, figsize=figsize, squeeze=False, constrained_layout=True + ) + + # Hide unused subplots + for i in range(nplots, nrows * ncols): + row_idx, col_idx = divmod(i, ncols) + axes[row_idx, col_idx].set_visible(False) + + return fig, axes diff --git a/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py b/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py index dfeb90c30c..d9b85aa9bd 100644 --- a/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py +++ b/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py @@ -631,3 +631,228 @@ def test_precision_recall_curve_display_wrong_report_type( ) with pytest.raises(ValueError, match=err_msg): display.plot() + + +def test_precision_recall_curve_display_subplots_basic_binary( + pyplot, binary_classification_data +): + """Test that subplots=True creates multiple subplots with default parameters + for binary classification.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + + # Create a comparison report with multiple estimators + est1 = clone(estimator) + est2 = clone(estimator) + est1.fit(X_train, y_train) + est2.fit(X_train, y_train) + + report = ComparisonReport( + reports={ + "estimator 1": EstimatorReport( + est1, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 2": EstimatorReport( + est2, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + }, + ) 
+ display = report.metrics.precision_recall() + display.plot(subplots=True) + + assert hasattr(display, "figure_") + + # Check correct number of subplots + axes = display.figure_.get_axes() + assert len(axes) == 2 + + # Check titles were set correctly + assert "Model: estimator 1" in axes[0].get_title() + assert "Model: estimator 2" in axes[1].get_title() + + # Each subplot should have correct labels + for ax in axes: + assert "Recall" in ax.get_xlabel() + assert "Precision" in ax.get_ylabel() + assert ax.get_aspect() in ("equal", 1.0) + + +def test_precision_recall_curve_display_subplots_basic_multiclass( + pyplot, multiclass_classification_data +): + """Test that subplots=True creates multiple subplots with default parameters + for multiclass classification.""" + estimator, X_train, X_test, y_train, y_test = multiclass_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + # In multiclass case, we get one subplot per class + display = report.metrics.precision_recall() + display.plot(subplots=True) + + assert hasattr(display, "figure_") + + # Check correct number of subplots (one per class) + axes = display.figure_.get_axes() + assert len(axes) == len(estimator.classes_) + + # Check titles were set correctly + for i, class_label in enumerate(estimator.classes_): + assert f"Class: {class_label}" in axes[i].get_title() + + # Each subplot should have correct labels + for ax in axes: + assert "Recall" in ax.get_xlabel() + assert "Precision" in ax.get_ylabel() + assert ax.get_aspect() in ("equal", 1.0) + + +def test_precision_recall_curve_display_subplots_cv_binary( + pyplot, binary_classification_data_no_split +): + """Test subplots with cross-validation for binary classification.""" + (estimator, X, y), cv = binary_classification_data_no_split, 3 + report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=cv) + display = report.metrics.precision_recall() + display.plot(subplots=True) + + assert hasattr(display, "figure_") + + # Check number of subplots matches number of CV folds + axes = display.figure_.get_axes() + assert len(axes) == cv + + # Check titles for each fold + for i, ax in enumerate(axes): + assert f"Fold #{i + 1}" in ax.get_title() + + # Each subplot should have correct labels + for ax in axes: + assert "Recall" in ax.get_xlabel() + assert "Precision" in ax.get_ylabel() + assert ax.get_aspect() in ("equal", 1.0) + + +def test_precision_recall_curve_display_subplots_custom_layout( + pyplot, binary_classification_data +): + """Test subplots with custom layout parameters.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + + # Create a comparison report with multiple estimators + est1 = clone(estimator) + est2 = clone(estimator) + est3 = clone(estimator) + est4 = clone(estimator) + est1.fit(X_train, y_train) + est2.fit(X_train, y_train) + est3.fit(X_train, y_train) + est4.fit(X_train, y_train) + + report = ComparisonReport( + reports={ + "estimator 1": EstimatorReport( + est1, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 2": EstimatorReport( + est2, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 3": EstimatorReport( + est3, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 4": EstimatorReport( + est4, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + }, + ) + display = report.metrics.precision_recall() + + # Test with 
custom nrows and ncols + figsize = (12, 10) + display.plot(subplots=True, nrows=2, ncols=2, figsize=figsize) + + # Check figure was created with correct size + assert hasattr(display, "figure_") + assert display.figure_.get_size_inches()[0] == figsize[0] + assert display.figure_.get_size_inches()[1] == figsize[1] + + # Check layout is correct + axes = display.figure_.get_axes() + assert len(axes) == 4 + + # Check subplot arrangement (2 rows, 2 columns) + pos1 = axes[0].get_position() + pos2 = axes[1].get_position() + pos3 = axes[2].get_position() + pos4 = axes[3].get_position() + + # First row: similar y positions for axes 0 and 1 + assert abs(pos1.y0 - pos2.y0) < 0.1 + # Second row: similar y positions for axes 2 and 3 + assert abs(pos3.y0 - pos4.y0) < 0.1 + # First column: similar x positions for axes 0 and 2 + assert abs(pos1.x0 - pos3.x0) < 0.1 + # Second column: similar x positions for axes 1 and 3 + assert abs(pos2.x0 - pos4.x0) < 0.1 + + +def test_precision_recall_curve_display_ax_and_subplots_error( + pyplot, binary_classification_data +): + """Test that an error is raised when both ax and subplots=True are specified.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.precision_recall() + + # Create a figure and axis to pass + fig, ax = pyplot.subplots() + + # Test that error is raised when both ax and subplots=True are specified + with pytest.raises( + ValueError, match="Cannot specify both 'ax' and 'subplots=True'" + ): + display.plot(ax=ax, subplots=True) + + +def test_precision_recall_curve_display_subplots_estimator_report( + pyplot, binary_classification_data +): + """Test subplots with simple estimator report (should be a single plot).""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.precision_recall() + display.plot(subplots=True) + + # For a single estimator, we should get a single plot + assert hasattr(display, "figure_") + axes = display.figure_.get_axes() + assert len(axes) == 1 + assert "Model: LogisticRegression" in axes[0].get_title() diff --git a/skore/tests/unit/sklearn/plot/test_prediction_error.py b/skore/tests/unit/sklearn/plot/test_prediction_error.py index e5f51b508b..deb8cb3db1 100644 --- a/skore/tests/unit/sklearn/plot/test_prediction_error.py +++ b/skore/tests/unit/sklearn/plot/test_prediction_error.py @@ -577,3 +577,143 @@ def test_prediction_error_display_wrong_report_type(pyplot, regression_data): ) with pytest.raises(ValueError, match=err_msg): display.plot() + + +def test_prediction_error_display_subplots_basic(pyplot, regression_data): + """Test that subplots=True creates multiple subplots with default parameters.""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = ComparisonReport( + reports={ + "estimator 1": EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 2": EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + }, + ) + display = report.metrics.prediction_error() + display.plot(subplots=True) + + assert hasattr(display, "figure_") + assert len(display.scatter_) == 2 + + axes = display.figure_.get_axes() + assert len(axes) == 2 + assert "Model: estimator 1" in 
axes[0].get_title() + assert "Model: estimator 2" in axes[1].get_title() + + +def test_prediction_error_display_subplots_custom_layout(pyplot, regression_data): + """Test subplots with custom layout parameters.""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = ComparisonReport( + reports={ + "estimator 1": EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 2": EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 3": EstimatorReport( + estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + }, + ) + display = report.metrics.prediction_error() + + figsize = (10, 8) + display.plot(subplots=True, nrows=3, ncols=1, figsize=figsize) + + assert hasattr(display, "figure_") + assert display.figure_.get_size_inches()[0] == figsize[0] + assert display.figure_.get_size_inches()[1] == figsize[1] + + axes = display.figure_.get_axes() + assert len(axes) == 3 + + pos1 = axes[0].get_position() + pos2 = axes[1].get_position() + pos3 = axes[2].get_position() + + # Same column (similar x positions) + assert abs(pos1.x0 - pos2.x0) < 0.1 + assert abs(pos2.x0 - pos3.x0) < 0.1 + + # Different rows (decreasing y positions) + assert pos1.y0 > pos2.y0 + assert pos2.y0 > pos3.y0 + + +def test_prediction_error_display_subplots_cross_validation( + pyplot, regression_data_no_split +): + """Test subplots with cross-validation data.""" + (estimator, X, y), cv = regression_data_no_split, 3 + report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=cv) + display = report.metrics.prediction_error() + display.plot(subplots=True) + + assert hasattr(display, "figure_") + + # Check number of subplots matches number of CV folds + axes = display.figure_.get_axes() + assert len(axes) == cv + + # Check titles for each fold + for i, ax in enumerate(axes): + assert f"Fold #{i + 1}" in ax.get_title() + + +def test_prediction_error_display_ax_and_subplots_error(pyplot, regression_data): + """Test that an error is raised when both ax and subplots=True are specified.""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.prediction_error() + + # Create a figure and axis to pass + fig, ax = pyplot.subplots() + + # Test that error is raised when both ax and subplots=True are specified + with pytest.raises( + ValueError, match="Cannot specify both 'ax' and 'subplots=True'" + ): + display.plot(ax=ax, subplots=True) + + +def test_prediction_error_display_subplots_estimator_report(pyplot, regression_data): + """Test subplots with simple estimator report (should be a single plot).""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.prediction_error() + display.plot(subplots=True) + + # For a single estimator, we should get a single plot + assert hasattr(display, "figure_") + axes = display.figure_.get_axes() + assert len(axes) == 1 + assert "Model: LinearRegression" in axes[0].get_title() diff --git a/skore/tests/unit/sklearn/plot/test_roc_curve.py b/skore/tests/unit/sklearn/plot/test_roc_curve.py index a60ee50a4a..552bfd773b 100644 --- a/skore/tests/unit/sklearn/plot/test_roc_curve.py +++ b/skore/tests/unit/sklearn/plot/test_roc_curve.py @@ -671,3 
+671,190 @@ def test_roc_curve_display_wrong_report_type(pyplot, binary_classification_data) ) with pytest.raises(ValueError, match=err_msg): display.plot() + + +def test_roc_curve_display_subplots_basic_binary(pyplot, binary_classification_data): + """Test that subplots=True creates multiple subplots with default parameters + for binary classification.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + + # Create a comparison report with multiple estimators + est1 = clone(estimator) + est2 = clone(estimator) + est1.fit(X_train, y_train) + est2.fit(X_train, y_train) + + report = ComparisonReport( + reports={ + "estimator 1": EstimatorReport( + est1, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 2": EstimatorReport( + est2, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + }, + ) + display = report.metrics.roc() + display.plot(subplots=True) + + assert hasattr(display, "figure_") + + axes = display.figure_.get_axes() + assert len(axes) == 2 + + assert "Model: estimator 1" in axes[0].get_title() + assert "Model: estimator 2" in axes[1].get_title() + + # Each subplot should have correct labels + for ax in axes: + assert "False Positive Rate" in ax.get_xlabel() + assert "True Positive Rate" in ax.get_ylabel() + assert ax.get_aspect() in ("equal", 1.0) + + +def test_roc_curve_display_subplots_basic_multiclass( + pyplot, multiclass_classification_data +): + """Test that subplots=True creates multiple subplots with default parameters + for multiclass classification.""" + estimator, X_train, X_test, y_train, y_test = multiclass_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + # In multiclass case, we should get one subplot per class + display = report.metrics.roc() + display.plot(subplots=True) + + assert hasattr(display, "figure_") + + # Check correct number of subplots (one per class) + axes = display.figure_.get_axes() + assert len(axes) == len(estimator.classes_) + + for i, class_label in enumerate(estimator.classes_): + assert f"Class: {class_label}" in axes[i].get_title() + + # Each subplot should have correct labels + for ax in axes: + assert "False Positive Rate" in ax.get_xlabel() + assert "True Positive Rate" in ax.get_ylabel() + assert ax.get_aspect() in ("equal", 1.0) + + +def test_roc_curve_display_subplots_cv_binary( + pyplot, binary_classification_data_no_split +): + """Test subplots with cross-validation for binary classification.""" + (estimator, X, y), cv = binary_classification_data_no_split, 3 + report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=cv) + display = report.metrics.roc() + display.plot(subplots=True) + + assert hasattr(display, "figure_") + + # Check number of subplots matches number of CV folds + axes = display.figure_.get_axes() + assert len(axes) == cv + + # Check titles for each fold + for i, ax in enumerate(axes): + assert f"Fold #{i + 1}" in ax.get_title() + + # Each subplot should have correct labels + for ax in axes: + assert "False Positive Rate" in ax.get_xlabel() + assert "True Positive Rate" in ax.get_ylabel() + assert ax.get_aspect() in ("equal", 1.0) + + +def test_roc_curve_display_subplots_custom_layout(pyplot, binary_classification_data): + """Test subplots with custom layout parameters.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + + # Create a comparison report with multiple estimators + est1 = clone(estimator) + est2 = 
clone(estimator) + est3 = clone(estimator) + est1.fit(X_train, y_train) + est2.fit(X_train, y_train) + est3.fit(X_train, y_train) + + report = ComparisonReport( + reports={ + "estimator 1": EstimatorReport( + est1, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 2": EstimatorReport( + est2, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + "estimator 3": EstimatorReport( + est3, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ), + }, + ) + display = report.metrics.roc() + + # Test with custom nrows and ncols + figsize = (10, 8) + display.plot(subplots=True, nrows=1, ncols=3, figsize=figsize) + + # Check figure was created with correct size + assert hasattr(display, "figure_") + assert display.figure_.get_size_inches()[0] == figsize[0] + assert display.figure_.get_size_inches()[1] == figsize[1] + + # Check layout is correct + axes = display.figure_.get_axes() + assert len(axes) == 3 + + # Check subplot arrangement (1 row, 3 columns) + pos1 = axes[0].get_position() + pos2 = axes[1].get_position() + pos3 = axes[2].get_position() + + # Same row (similar y positions) + assert abs(pos1.y0 - pos2.y0) < 0.1 + assert abs(pos2.y0 - pos3.y0) < 0.1 + + # Different columns (increasing x positions) + assert pos1.x0 < pos2.x0 + assert pos2.x0 < pos3.x0 + + +def test_roc_curve_display_ax_and_subplots_error(pyplot, binary_classification_data): + """Test that an error is raised when both ax and subplots=True are specified.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.roc() + + # Create a figure and axis to pass + fig, ax = pyplot.subplots() + + # Test that error is raised when both ax and subplots=True are specified + with pytest.raises( + ValueError, match="Cannot specify both 'ax' and 'subplots=True'" + ): + display.plot(ax=ax, subplots=True)
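
Note on the grid sizing: the nrows/ncols resolution that the patch repeats in `_plot_cross_validated_estimator` and `_plot_comparison_estimator` reduces to a ceiling division with a default of at most two columns. A minimal standalone sketch of that logic follows; the function name `resolve_grid` is illustrative only and is not part of the patch.

    from typing import Optional


    def resolve_grid(
        num_plots: int, nrows: Optional[int] = None, ncols: Optional[int] = None
    ) -> tuple[int, int]:
        """Fill in the missing grid dimension, defaulting to at most two columns."""
        if nrows is None and ncols is None:
            ncols = 1 if num_plots == 1 else min(2, num_plots)
            nrows = (num_plots + ncols - 1) // ncols  # ceiling division
        elif nrows is None:
            nrows = (num_plots + ncols - 1) // ncols
        elif ncols is None:
            ncols = (num_plots + nrows - 1) // nrows
        return nrows, ncols


    # One curve -> 1x1; three curves -> 2x2 with one unused cell;
    # five curves constrained to three columns -> two rows.
    assert resolve_grid(1) == (1, 1)
    assert resolve_grid(3) == (2, 2)
    assert resolve_grid(5, ncols=3) == (2, 3)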
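
The `get_subplots` helper added to `skore/src/skore/sklearn/_plot/utils.py` bundles the same sizing with figure creation and hides any unused grid cells, although the display methods in this patch still build their grids with `plt.figure()`/`add_subplot()` directly. A quick check of what the helper returns, as a sketch; the import path is inferred from the file location and assumes the code from this branch is installed.

    import matplotlib

    matplotlib.use("Agg")  # headless backend, an assumption for the example
    import matplotlib.pyplot as plt

    from skore.sklearn._plot.utils import get_subplots

    # Five plots with the default two columns: a 3x2 grid whose last cell is hidden,
    # with the default figsize scaling as (6 * ncols, 4 * nrows).
    fig, axes = get_subplots(nplots=5)
    assert axes.shape == (3, 2)
    assert not axes[2, 1].get_visible()
    assert fig.get_size_inches()[0] == 12

    # A single plot returns a bare Axes rather than an array.
    fig, ax = get_subplots(nplots=1)
    plt.close("all")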
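
From the user side, the new keywords mirror the docstring example and the cross-validation tests above. The following sketch assumes skore's public `CrossValidationReport` entry point; the breast-cancer dataset and the `max_iter` value are arbitrary illustrative choices, not taken from the patch.

    from sklearn.datasets import load_breast_cancer
    from sklearn.linear_model import LogisticRegression
    from skore import CrossValidationReport

    X, y = load_breast_cancer(return_X_y=True)
    report = CrossValidationReport(
        LogisticRegression(max_iter=10_000), X=X, y=y, cv_splitter=3
    )

    display = report.metrics.roc()
    display.plot(subplots=True)                    # one panel per CV fold, two columns by default
    display.plot(subplots=True, nrows=1, ncols=3)  # explicit 1x3 grid
    display.plot(subplots=True, figsize=(15, 4))   # custom overall figure size

    # Passing an existing Axes together with subplots=True raises a ValueError,
    # as exercised by the *_ax_and_subplots_error tests above.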