diff --git a/skore/src/skore/_sklearn/_plot/base.py b/skore/src/skore/_sklearn/_plot/base.py index 821e1311a1..44cef6856c 100644 --- a/skore/src/skore/_sklearn/_plot/base.py +++ b/skore/src/skore/_sklearn/_plot/base.py @@ -1,11 +1,30 @@ +import inspect +from collections.abc import Callable +from functools import wraps +from io import StringIO from typing import Any, Protocol, runtime_checkable +import matplotlib.pyplot as plt import pandas as pd +from rich.console import Console +from rich.panel import Panel +from rich.tree import Tree + +from skore._config import get_config +from skore._sklearn.types import PlotBackend + +######################################################################################## +# Display protocol +######################################################################################## @runtime_checkable class Display(Protocol): - """Protocol specifying the common API for all `skore` displays.""" + """Protocol specifying the common API for all `skore` displays. + + .. note:: + This class is a Python protocol and it is not intended to be inherited from. + """ def plot(self, **kwargs: Any) -> None: """Display a figure containing the information of the display.""" @@ -21,3 +40,248 @@ def frame(self, **kwargs: Any) -> pd.DataFrame: DataFrame A DataFrame containing the data used to create the display. """ + + def help(self) -> None: + """Display available attributes and methods using rich.""" + + +######################################################################################## +# Plotting related mixins +######################################################################################## + + +class PlotBackendMixin: + """Mixin class for Displays to dispatch plotting to the configured backend.""" + + def _plot(self, **kwargs): + """Dispatch plotting to the configured backend.""" + plot_backend = get_config()["plot_backend"] + if plot_backend == "matplotlib": + return self._plot_matplotlib(**kwargs) + elif plot_backend == "plotly": + return self._plot_plotly(**kwargs) + else: + raise NotImplementedError( + f"Plotting backend {plot_backend} not available. " + f"Available options are {PlotBackend.__args__}." + ) + + def _plot_plotly(self, **kwargs): + raise NotImplementedError( + "Plotting with plotly is not supported for this Display." + ) + + +DEFAULT_STYLE = { + "font.size": 14, + "axes.labelsize": 14, + "axes.titlesize": 14, + "xtick.labelsize": 13, + "ytick.labelsize": 13, + "legend.fontsize": 10, + "legend.title_fontsize": 11, + "axes.linewidth": 1.25, + "grid.linewidth": 1.25, + "lines.linewidth": 1.75, + "lines.markersize": 6, + "patch.linewidth": 1.25, + "xtick.major.width": 1.5, + "ytick.major.width": 1.5, + "xtick.minor.width": 1.25, + "ytick.minor.width": 1.25, + "xtick.major.size": 7, + "ytick.major.size": 7, + "xtick.minor.size": 5, + "ytick.minor.size": 5, + "legend.loc": "upper left", + "legend.borderaxespad": 0, +} + + +class StyleDisplayMixin: + """Mixin to control the style plot of a display.""" + + @property + def _style_params(self) -> list[str]: + """Get the list of available style parameters. + + Returns + ------- + list + List of style parameter names (without '_default_' prefix). + """ + prefix = "_default_" + suffix = "_kwargs" + return [ + attr[len(prefix) :] + for attr in dir(self) + if attr.startswith(prefix) and attr.endswith(suffix) + ] + + def set_style(self, **kwargs: Any): + """Set the style parameters for the display. + + Parameters + ---------- + **kwargs : dict + Style parameters to set. Each parameter name should correspond to a + a style attribute passed to the plot method of the display. + + Returns + ------- + self : object + Returns the instance itself. + + Raises + ------ + ValueError + If a style parameter is unknown. + """ + for param_name, param_value in kwargs.items(): + default_attr = f"_default_{param_name}" + if not hasattr(self, default_attr): + raise ValueError( + f"Unknown style parameter: {param_name}. " + f"The parameter name should be one of {self._style_params}." + ) + setattr(self, default_attr, param_value) + return self + + @staticmethod + def style_plot(plot_func: Callable) -> Callable: + """Apply consistent style to skore displays. + + This decorator: + 1. Applies default style settings + 2. Executes `plot_func` + 3. Applies `tight_layout` + + Parameters + ---------- + plot_func : callable + The plot function to be decorated. + + Returns + ------- + callable + The decorated plot function. + """ + + @wraps(plot_func) + def wrapper(self, *args: Any, **kwargs: Any) -> Any: + # We need to manually handle setting the style of the parameters because + # `plt.style.context` has a side effect with the interactive mode. + # See https://github.com/matplotlib/matplotlib/issues/25041 + original_params = {key: plt.rcParams[key] for key in DEFAULT_STYLE} + plt.rcParams.update(DEFAULT_STYLE) + try: + result = plot_func(self, *args, **kwargs) + finally: + plt.rcParams.update(original_params) + return result + + return wrapper + + +######################################################################################## +# General purpose mixins +######################################################################################## + + +class HelpDisplayMixin: + """Mixin class to add help functionality to a class.""" + + estimator_name: str # defined in the concrete display class + + def _get_attributes_for_help(self) -> list[str]: + """Get the attributes ending with '_' to display in help.""" + attributes = [] + for name in dir(self): + if name.endswith("_") and not name.startswith("_"): + attributes.append(f".{name}") + return sorted(attributes) + + def _get_methods_for_help(self) -> list[tuple[str, Any]]: + """Get the public methods to display in help.""" + methods = inspect.getmembers(self, predicate=inspect.ismethod) + filtered_methods = [] + for name, method in methods: + is_private = name.startswith("_") + is_class_method = inspect.ismethod(method) and method.__self__ is type(self) + is_help_method = name == "help" + if not (is_private or is_class_method or is_help_method): + filtered_methods.append((f".{name}(...)", method)) + return sorted(filtered_methods) + + def _create_help_tree(self) -> Tree: + """Create a rich Tree with attributes and methods.""" + tree = Tree("display") + + attributes = self._get_attributes_for_help() + attr_branch = tree.add("[bold cyan] Attributes[/bold cyan]") + # Ensure figure_ and ax_ are first + sorted_attrs = sorted(attributes) + if ("figure_" in sorted_attrs) and ("ax_" in sorted_attrs): + sorted_attrs.remove(".ax_") + sorted_attrs.remove(".figure_") + sorted_attrs = [".figure_", ".ax_"] + [ + attr for attr in sorted_attrs if attr not in [".figure_", ".ax_"] + ] + for attr in sorted_attrs: + attr_branch.add(attr) + + methods = self._get_methods_for_help() + method_branch = tree.add("[bold cyan]Methods[/bold cyan]") + for name, method in methods: + description = ( + method.__doc__.split("\n")[0] + if method.__doc__ + else "No description available" + ) + method_branch.add(f"{name} - {description}") + + return tree + + def _create_help_panel(self) -> Panel: + return Panel( + self._create_help_tree(), + title=f"[bold cyan]{self.__class__.__name__} [/bold cyan]", + border_style="orange1", + expand=False, + ) + + def help(self) -> None: + """Display available attributes and methods using rich.""" + from skore import console # avoid circular import + + console.print(self._create_help_panel()) + + def __str__(self) -> str: + """Return a string representation using rich.""" + string_buffer = StringIO() + console = Console(file=string_buffer, force_terminal=False) + console.print( + Panel( + "Get guidance using the help() method", + title=f"[cyan]{self.__class__.__name__}[/cyan]", + border_style="orange1", + expand=False, + ) + ) + return string_buffer.getvalue() + + def __repr__(self) -> str: + """Return a string representation using rich.""" + string_buffer = StringIO() + console = Console(file=string_buffer, force_terminal=False) + console.print(f"[cyan]skore.{self.__class__.__name__}(...)[/cyan]") + return string_buffer.getvalue() + + +######################################################################################## +# Display mixin inheriting from the different mixins +######################################################################################## + + +class DisplayMixin(HelpDisplayMixin, PlotBackendMixin, StyleDisplayMixin): + """Mixin inheriting help, plotting, and style functionality.""" diff --git a/skore/src/skore/_sklearn/_plot/data/table_report.py b/skore/src/skore/_sklearn/_plot/data/table_report.py index e062675a8a..7db9b41299 100644 --- a/skore/src/skore/_sklearn/_plot/data/table_report.py +++ b/skore/src/skore/_sklearn/_plot/data/table_report.py @@ -16,10 +16,8 @@ ) from skore._externals._skrub_compat import sbd -from skore._sklearn._plot.style import StyleDisplayMixin +from skore._sklearn._plot.base import DisplayMixin from skore._sklearn._plot.utils import ( - HelpDisplayMixin, - PlotBackendMixin, _adjust_fig_size, _rotate_ticklabels, _validate_style_kwargs, @@ -162,9 +160,7 @@ def _resize_categorical_axis( _adjust_fig_size(figure, ax, target_width, target_height) -class TableReportDisplay( - StyleDisplayMixin, HelpDisplayMixin, ReprHTMLMixin, PlotBackendMixin -): +class TableReportDisplay(ReprHTMLMixin, DisplayMixin): """Display reporting information about a given dataset. This display summarizes the dataset and provides a way to visualize @@ -222,8 +218,8 @@ def _compute_data_for_display(cls, dataset: pd.DataFrame) -> "TableReportDisplay """ return cls(summarize_dataframe(dataset, with_plots=True, title=None, verbose=0)) - @StyleDisplayMixin.style_plot - def _plot_matplotlib( + @DisplayMixin.style_plot + def plot( self, *, x: str | None = None, @@ -305,6 +301,34 @@ def _plot_matplotlib( >>> display = report.data.analyze() >>> display.plot(kind="corr") """ + return self._plot( + x=x, + y=y, + hue=hue, + kind=kind, + top_k_categories=top_k_categories, + scatterplot_kwargs=scatterplot_kwargs, + stripplot_kwargs=stripplot_kwargs, + boxplot_kwargs=boxplot_kwargs, + heatmap_kwargs=heatmap_kwargs, + histplot_kwargs=histplot_kwargs, + ) + + def _plot_matplotlib( + self, + *, + x: str | None = None, + y: str | None = None, + hue: str | None = None, + kind: Literal["dist", "corr"] = "dist", + top_k_categories: int = 20, + scatterplot_kwargs: dict[str, Any] | None = None, + stripplot_kwargs: dict[str, Any] | None = None, + boxplot_kwargs: dict[str, Any] | None = None, + heatmap_kwargs: dict[str, Any] | None = None, + histplot_kwargs: dict[str, Any] | None = None, + ) -> None: + """Matplotlib implementation of the `plot` method.""" self.figure_, self.ax_ = plt.subplots() if kind == "dist": match (x is None, y is None, hue is None): diff --git a/skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py b/skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py index 64dfba8df4..973d97a789 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py +++ b/skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py @@ -2,12 +2,10 @@ import numpy as np from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix -from skore._sklearn._plot.base import Display -from skore._sklearn._plot.style import StyleDisplayMixin -from skore._sklearn._plot.utils import PlotBackendMixin +from skore._sklearn._plot.base import DisplayMixin -class ConfusionMatrixDisplay(PlotBackendMixin, Display): +class ConfusionMatrixDisplay(DisplayMixin): """Display for confusion matrix. Parameters @@ -45,7 +43,6 @@ class ConfusionMatrixDisplay(PlotBackendMixin, Display): confusion matrix. """ - @StyleDisplayMixin.style_plot def __init__( self, confusion_matrix, @@ -64,7 +61,8 @@ def __init__( self.ax_ = None self.text_ = None - def _plot_matplotlib(self, ax=None, *, cmap="Blues", colorbar=True, **kwargs): + @DisplayMixin.style_plot + def plot(self, ax=None, *, cmap="Blues", colorbar=True, **kwargs): """Plot the confusion matrix. Parameters @@ -87,6 +85,9 @@ def _plot_matplotlib(self, ax=None, *, cmap="Blues", colorbar=True, **kwargs): self : ConfusionMatrixDisplay Configured with the confusion matrix. """ + return self._plot(ax=ax, cmap=cmap, colorbar=colorbar, **kwargs) + + def _plot_matplotlib(self, ax=None, *, cmap="Blues", colorbar=True, **kwargs): if self.normalize not in (None, "true", "pred", "all"): raise ValueError( "normalize must be one of None, 'true', 'pred', 'all'; " diff --git a/skore/src/skore/_sklearn/_plot/metrics/feature_importance_display.py b/skore/src/skore/_sklearn/_plot/metrics/feature_importance_display.py index 10a145e474..9b9d5cb0fe 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/feature_importance_display.py +++ b/skore/src/skore/_sklearn/_plot/metrics/feature_importance_display.py @@ -1,13 +1,9 @@ import matplotlib.pyplot as plt -from skore._sklearn._plot.base import Display -from skore._sklearn._plot.style import StyleDisplayMixin -from skore._sklearn._plot.utils import HelpDisplayMixin, PlotBackendMixin +from skore._sklearn._plot.base import DisplayMixin -class FeatureImportanceDisplay( - HelpDisplayMixin, StyleDisplayMixin, PlotBackendMixin, Display -): +class FeatureImportanceDisplay(DisplayMixin): """Feature importance display. Each report type produces its own output frame and plot. @@ -102,7 +98,17 @@ def _frame_comparison_report(self): return pd.concat(self._coefficient_data, axis=1) - @StyleDisplayMixin.style_plot + @DisplayMixin.style_plot + def plot(self, **kwargs) -> None: + """Plot the coefficients of linear models. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments to be passed to the plot method. + """ + return self._plot(**kwargs) + def _plot_matplotlib(self, **kwargs): from skore._sklearn._comparison import ComparisonReport from skore._sklearn._cross_validation import CrossValidationReport diff --git a/skore/src/skore/_sklearn/_plot/metrics/metrics_summary_display.py b/skore/src/skore/_sklearn/_plot/metrics/metrics_summary_display.py index a6846294ff..9220960f0d 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/metrics_summary_display.py +++ b/skore/src/skore/_sklearn/_plot/metrics/metrics_summary_display.py @@ -1,8 +1,7 @@ -from skore._sklearn._plot.style import StyleDisplayMixin -from skore._sklearn._plot.utils import HelpDisplayMixin +from skore._sklearn._plot.base import DisplayMixin -class MetricsSummaryDisplay(HelpDisplayMixin, StyleDisplayMixin): +class MetricsSummaryDisplay(DisplayMixin): """Display for summarize. An instance of this class will be created by `Report.metrics.summarize()`. @@ -22,6 +21,7 @@ def frame(self): """ return self.summarize_data - @StyleDisplayMixin.style_plot + @DisplayMixin.style_plot def plot(self): + """Not yet implemented.""" raise NotImplementedError diff --git a/skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py b/skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py index 2115d6c3fc..b2aba04adb 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py +++ b/skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py @@ -12,11 +12,9 @@ from sklearn.metrics import average_precision_score, precision_recall_curve from sklearn.preprocessing import LabelBinarizer -from skore._sklearn._plot.style import StyleDisplayMixin +from skore._sklearn._plot.base import DisplayMixin from skore._sklearn._plot.utils import ( LINESTYLE, - HelpDisplayMixin, - PlotBackendMixin, _ClassifierCurveDisplayMixin, _despine_matplotlib_axis, _validate_style_kwargs, @@ -45,9 +43,7 @@ def _set_axis_labels(ax: Axes, info_pos_label: str | None) -> None: MAX_N_LABELS = 5 -class PrecisionRecallCurveDisplay( - StyleDisplayMixin, HelpDisplayMixin, _ClassifierCurveDisplayMixin, PlotBackendMixin -): +class PrecisionRecallCurveDisplay(_ClassifierCurveDisplayMixin, DisplayMixin): """Precision Recall visualization. An instance of this class should be created by @@ -630,8 +626,8 @@ def _plot_comparison_cross_validation( return self.ax_, lines, info_pos_label - @StyleDisplayMixin.style_plot - def _plot_matplotlib( + @DisplayMixin.style_plot + def plot( self, *, estimator_name: str | None = None, @@ -640,8 +636,6 @@ def _plot_matplotlib( ) -> None: """Plot visualization. - Extra keyword arguments will be passed to matplotlib's `plot`. - Parameters ---------- estimator_name : str, default=None @@ -679,6 +673,20 @@ def _plot_matplotlib( >>> display = report.metrics.precision_recall() >>> display.plot(pr_curve_kwargs={"color": "tab:red"}) """ + return self._plot( + estimator_name=estimator_name, + pr_curve_kwargs=pr_curve_kwargs, + despine=despine, + ) + + def _plot_matplotlib( + self, + *, + estimator_name: str | None = None, + pr_curve_kwargs: dict[str, Any] | list[dict[str, Any]] | None = None, + despine: bool = True, + ) -> None: + """Matplotlib implementation of the `plot` method.""" if ( self.report_type == "comparison-cross-validation" and self.ml_task == "multiclass-classification" diff --git a/skore/src/skore/_sklearn/_plot/metrics/prediction_error.py b/skore/src/skore/_sklearn/_plot/metrics/prediction_error.py index d489c16062..546e09cc8e 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/prediction_error.py +++ b/skore/src/skore/_sklearn/_plot/metrics/prediction_error.py @@ -10,10 +10,8 @@ from sklearn.utils.validation import _num_samples, check_array from skore._externals._sklearn_compat import _safe_indexing -from skore._sklearn._plot.style import StyleDisplayMixin +from skore._sklearn._plot.base import DisplayMixin from skore._sklearn._plot.utils import ( - HelpDisplayMixin, - PlotBackendMixin, _despine_matplotlib_axis, _validate_style_kwargs, sample_mpl_colormap, @@ -25,7 +23,7 @@ MAX_N_LABELS = 6 # 5 + 1 for the perfect model line -class PredictionErrorDisplay(StyleDisplayMixin, HelpDisplayMixin, PlotBackendMixin): +class PredictionErrorDisplay(DisplayMixin): """Visualization of the prediction error of a regression model. This tool can display "residuals vs predicted" or "actual vs predicted" @@ -522,8 +520,8 @@ def _plot_comparison_cross_validation( return scatter - @StyleDisplayMixin.style_plot - def _plot_matplotlib( + @DisplayMixin.style_plot + def plot( self, *, estimator_name: str | None = None, @@ -578,6 +576,26 @@ def _plot_matplotlib( >>> display = report.metrics.prediction_error() >>> display.plot(kind="actual_vs_predicted") """ + return self._plot( + estimator_name=estimator_name, + kind=kind, + data_points_kwargs=data_points_kwargs, + perfect_model_kwargs=perfect_model_kwargs, + despine=despine, + ) + + def _plot_matplotlib( + self, + *, + estimator_name: str | None = None, + kind: Literal[ + "actual_vs_predicted", "residual_vs_predicted" + ] = "residual_vs_predicted", + data_points_kwargs: dict[str, Any] | list[dict[str, Any]] | None = None, + perfect_model_kwargs: dict[str, Any] | None = None, + despine: bool = True, + ) -> None: + """Matplolib implementation of the `plot` method.""" expected_kind = ("actual_vs_predicted", "residual_vs_predicted") if kind not in expected_kind: raise ValueError( diff --git a/skore/src/skore/_sklearn/_plot/metrics/roc_curve.py b/skore/src/skore/_sklearn/_plot/metrics/roc_curve.py index 33a8288a3a..96dc0031ef 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/roc_curve.py +++ b/skore/src/skore/_sklearn/_plot/metrics/roc_curve.py @@ -11,11 +11,9 @@ from sklearn.metrics import auc, roc_curve from sklearn.preprocessing import LabelBinarizer -from skore._sklearn._plot.style import StyleDisplayMixin +from skore._sklearn._plot.base import DisplayMixin from skore._sklearn._plot.utils import ( LINESTYLE, - HelpDisplayMixin, - PlotBackendMixin, _ClassifierCurveDisplayMixin, _despine_matplotlib_axis, _validate_style_kwargs, @@ -63,9 +61,7 @@ def _add_chance_level( return cast(Line2D, chance_level) -class RocCurveDisplay( - StyleDisplayMixin, HelpDisplayMixin, _ClassifierCurveDisplayMixin, PlotBackendMixin -): +class RocCurveDisplay(_ClassifierCurveDisplayMixin, DisplayMixin): """ROC Curve visualization. An instance of this class should be created by `EstimatorReport.metrics.roc()`. @@ -703,8 +699,8 @@ def _plot_comparison_cross_validation( return self.ax_, lines, info_pos_label - @StyleDisplayMixin.style_plot - def _plot_matplotlib( + @DisplayMixin.style_plot + def plot( self, *, estimator_name: str | None = None, @@ -750,6 +746,24 @@ def _plot_matplotlib( >>> display = report.metrics.roc() >>> display.plot(roc_curve_kwargs={"color": "tab:red"}) """ + return self._plot( + estimator_name=estimator_name, + roc_curve_kwargs=roc_curve_kwargs, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + despine=despine, + ) + + def _plot_matplotlib( + self, + *, + estimator_name: str | None = None, + roc_curve_kwargs: dict[str, Any] | list[dict[str, Any]] | None = None, + plot_chance_level: bool = True, + chance_level_kwargs: dict[str, Any] | None = None, + despine: bool = True, + ) -> None: + """Matplotlib implementation of the `plot` method.""" if ( self.report_type == "comparison-cross-validation" and self.ml_task == "multiclass-classification" diff --git a/skore/src/skore/_sklearn/_plot/style.py b/skore/src/skore/_sklearn/_plot/style.py deleted file mode 100644 index cdb2a4646e..0000000000 --- a/skore/src/skore/_sklearn/_plot/style.py +++ /dev/null @@ -1,115 +0,0 @@ -from collections.abc import Callable -from functools import wraps -from typing import Any - -import matplotlib.pyplot as plt - -DEFAULT_STYLE = { - "font.size": 14, - "axes.labelsize": 14, - "axes.titlesize": 14, - "xtick.labelsize": 13, - "ytick.labelsize": 13, - "legend.fontsize": 10, - "legend.title_fontsize": 11, - "axes.linewidth": 1.25, - "grid.linewidth": 1.25, - "lines.linewidth": 1.75, - "lines.markersize": 6, - "patch.linewidth": 1.25, - "xtick.major.width": 1.5, - "ytick.major.width": 1.5, - "xtick.minor.width": 1.25, - "ytick.minor.width": 1.25, - "xtick.major.size": 7, - "ytick.major.size": 7, - "xtick.minor.size": 5, - "ytick.minor.size": 5, - "legend.loc": "upper left", - "legend.borderaxespad": 0, -} - - -class StyleDisplayMixin: - """Mixin to control the style plot of a display.""" - - @property - def _style_params(self) -> list[str]: - """Get the list of available style parameters. - - Returns - ------- - list - List of style parameter names (without '_default_' prefix). - """ - prefix = "_default_" - suffix = "_kwargs" - return [ - attr[len(prefix) :] - for attr in dir(self) - if attr.startswith(prefix) and attr.endswith(suffix) - ] - - def set_style(self, **kwargs: Any): - """Set the style parameters for the display. - - Parameters - ---------- - **kwargs : dict - Style parameters to set. Each parameter name should correspond to a - a style attribute passed to the plot method of the display. - - Returns - ------- - self : object - Returns the instance itself. - - Raises - ------ - ValueError - If a style parameter is unknown. - """ - for param_name, param_value in kwargs.items(): - default_attr = f"_default_{param_name}" - if not hasattr(self, default_attr): - raise ValueError( - f"Unknown style parameter: {param_name}. " - f"The parameter name should be one of {self._style_params}." - ) - setattr(self, default_attr, param_value) - return self - - @staticmethod - def style_plot(plot_func: Callable) -> Callable: - """Apply consistent style to skore displays. - - This decorator: - 1. Applies default style settings - 2. Executes `plot_func` - 3. Applies `tight_layout` - - Parameters - ---------- - plot_func : callable - The plot function to be decorated. - - Returns - ------- - callable - The decorated plot function. - """ - - @wraps(plot_func) - def wrapper(self, *args: Any, **kwargs: Any) -> Any: - # We need to manually handle setting the style of the parameters because - # `plt.style.context` has a side effect with the interactive mode. - # See https://github.com/matplotlib/matplotlib/issues/25041 - original_params = {key: plt.rcParams[key] for key in DEFAULT_STYLE} - plt.rcParams.update(DEFAULT_STYLE) - try: - result = plot_func(self, *args, **kwargs) - finally: - plt.rcParams.update(original_params) - return result - - return wrapper diff --git a/skore/src/skore/_sklearn/_plot/utils.py b/skore/src/skore/_sklearn/_plot/utils.py index 37b31cf034..95bb4e9e84 100644 --- a/skore/src/skore/_sklearn/_plot/utils.py +++ b/skore/src/skore/_sklearn/_plot/utils.py @@ -1,6 +1,4 @@ -import inspect from collections.abc import Sequence -from io import StringIO from typing import Any import numpy as np @@ -8,18 +6,13 @@ from matplotlib.axes import Axes from matplotlib.colors import Colormap from matplotlib.figure import Figure -from rich.console import Console -from rich.panel import Panel -from rich.tree import Tree from sklearn.utils.validation import ( _check_pos_label_consistency, check_consistent_length, ) -from skore._config import get_config from skore._sklearn.types import ( MLTask, - PlotBackend, PositiveLabel, ReportType, YPlotData, @@ -46,96 +39,6 @@ ] -class HelpDisplayMixin: - """Mixin class to add help functionality to a class.""" - - estimator_name: str # defined in the concrete display class - - def _get_attributes_for_help(self) -> list[str]: - """Get the attributes ending with '_' to display in help.""" - attributes = [] - for name in dir(self): - if name.endswith("_") and not name.startswith("_"): - attributes.append(f".{name}") - return sorted(attributes) - - def _get_methods_for_help(self) -> list[tuple[str, Any]]: - """Get the public methods to display in help.""" - methods = inspect.getmembers(self, predicate=inspect.ismethod) - filtered_methods = [] - for name, method in methods: - is_private = name.startswith("_") - is_class_method = inspect.ismethod(method) and method.__self__ is type(self) - is_help_method = name == "help" - if not (is_private or is_class_method or is_help_method): - filtered_methods.append((f".{name}(...)", method)) - return sorted(filtered_methods) - - def _create_help_tree(self) -> Tree: - """Create a rich Tree with attributes and methods.""" - tree = Tree("display") - - attributes = self._get_attributes_for_help() - attr_branch = tree.add("[bold cyan] Attributes[/bold cyan]") - # Ensure figure_ and ax_ are first - sorted_attrs = sorted(attributes) - if ("figure_" in sorted_attrs) and ("ax_" in sorted_attrs): - sorted_attrs.remove(".ax_") - sorted_attrs.remove(".figure_") - sorted_attrs = [".figure_", ".ax_"] + [ - attr for attr in sorted_attrs if attr not in [".figure_", ".ax_"] - ] - for attr in sorted_attrs: - attr_branch.add(attr) - - methods = self._get_methods_for_help() - method_branch = tree.add("[bold cyan]Methods[/bold cyan]") - for name, method in methods: - description = ( - method.__doc__.split("\n")[0] - if method.__doc__ - else "No description available" - ) - method_branch.add(f"{name} - {description}") - - return tree - - def _create_help_panel(self) -> Panel: - return Panel( - self._create_help_tree(), - title=f"[bold cyan]{self.__class__.__name__} [/bold cyan]", - border_style="orange1", - expand=False, - ) - - def help(self) -> None: - """Display available attributes and methods using rich.""" - from skore import console # avoid circular import - - console.print(self._create_help_panel()) - - def __str__(self) -> str: - """Return a string representation using rich.""" - string_buffer = StringIO() - console = Console(file=string_buffer, force_terminal=False) - console.print( - Panel( - "Get guidance using the help() method", - title=f"[cyan]{self.__class__.__name__}[/cyan]", - border_style="orange1", - expand=False, - ) - ) - return string_buffer.getvalue() - - def __repr__(self) -> str: - """Return a string representation using rich.""" - string_buffer = StringIO() - console = Console(file=string_buffer, force_terminal=False) - console.print(f"[cyan]skore.{self.__class__.__name__}(...)[/cyan]") - return string_buffer.getvalue() - - class _ClassifierCurveDisplayMixin: """Mixin class to be used in Displays requiring a binary classifier. @@ -333,28 +236,6 @@ def _adjust_fig_size( fig.set_size_inches((width, height)) -class PlotBackendMixin: - """Mixin class for Displays to dispatch plotting to the configured backend.""" - - def plot(self, **kwargs): - """Show as a plot.""" - plot_backend = get_config()["plot_backend"] - if plot_backend == "matplotlib": - return self._plot_matplotlib(**kwargs) - elif plot_backend == "plotly": - return self._plot_plotly(**kwargs) - else: - raise NotImplementedError( - f"Plotting backend {plot_backend} not available. " - f"Available options are {PlotBackend.__args__}." - ) - - def _plot_plotly(self, **kwargs): - raise NotImplementedError( - "Plotting with plotly is not supported for this Display." - ) - - def _despine_matplotlib_axis( ax: Axes, *, diff --git a/skore/tests/unit/displays/test_style.py b/skore/tests/unit/displays/test_style.py index 7eb675abda..5357f0d31d 100644 --- a/skore/tests/unit/displays/test_style.py +++ b/skore/tests/unit/displays/test_style.py @@ -1,5 +1,5 @@ import pytest -from skore._sklearn._plot.style import StyleDisplayMixin +from skore._sklearn._plot.base import StyleDisplayMixin class TestDisplay(StyleDisplayMixin):