Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 265 additions & 1 deletion skore/src/skore/_sklearn/_plot/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
import inspect
from collections.abc import Callable
from functools import wraps
from io import StringIO
from typing import Any, Protocol, runtime_checkable

import matplotlib.pyplot as plt
import pandas as pd
from rich.console import Console
from rich.panel import Panel
from rich.tree import Tree

from skore._config import get_config
from skore._sklearn.types import PlotBackend

########################################################################################
# Display protocol
########################################################################################


@runtime_checkable
class Display(Protocol):
"""Protocol specifying the common API for all `skore` displays."""
"""Protocol specifying the common API for all `skore` displays.

.. note::
This class is a Python protocol and it is not intended to be inherited from.
"""

def plot(self, **kwargs: Any) -> None:
"""Display a figure containing the information of the display."""
Expand All @@ -21,3 +40,248 @@ def frame(self, **kwargs: Any) -> pd.DataFrame:
DataFrame
A DataFrame containing the data used to create the display.
"""

def help(self) -> None:
"""Display available attributes and methods using rich."""


########################################################################################
# Plotting related mixins
########################################################################################


class PlotBackendMixin:
"""Mixin class for Displays to dispatch plotting to the configured backend."""

def _plot(self, **kwargs):
"""Dispatch plotting to the configured backend."""
plot_backend = get_config()["plot_backend"]
if plot_backend == "matplotlib":
return self._plot_matplotlib(**kwargs)
elif plot_backend == "plotly":
return self._plot_plotly(**kwargs)
else:
raise NotImplementedError(
f"Plotting backend {plot_backend} not available. "
f"Available options are {PlotBackend.__args__}."
)

def _plot_plotly(self, **kwargs):
raise NotImplementedError(
"Plotting with plotly is not supported for this Display."
)


DEFAULT_STYLE = {
"font.size": 14,
"axes.labelsize": 14,
"axes.titlesize": 14,
"xtick.labelsize": 13,
"ytick.labelsize": 13,
"legend.fontsize": 10,
"legend.title_fontsize": 11,
"axes.linewidth": 1.25,
"grid.linewidth": 1.25,
"lines.linewidth": 1.75,
"lines.markersize": 6,
"patch.linewidth": 1.25,
"xtick.major.width": 1.5,
"ytick.major.width": 1.5,
"xtick.minor.width": 1.25,
"ytick.minor.width": 1.25,
"xtick.major.size": 7,
"ytick.major.size": 7,
"xtick.minor.size": 5,
"ytick.minor.size": 5,
"legend.loc": "upper left",
"legend.borderaxespad": 0,
}


class StyleDisplayMixin:
"""Mixin to control the style plot of a display."""

@property
def _style_params(self) -> list[str]:
"""Get the list of available style parameters.

Returns
-------
list
List of style parameter names (without '_default_' prefix).
"""
prefix = "_default_"
suffix = "_kwargs"
return [
attr[len(prefix) :]
for attr in dir(self)
if attr.startswith(prefix) and attr.endswith(suffix)
]

def set_style(self, **kwargs: Any):
"""Set the style parameters for the display.

Parameters
----------
**kwargs : dict
Style parameters to set. Each parameter name should correspond to a
a style attribute passed to the plot method of the display.

Returns
-------
self : object
Returns the instance itself.

Raises
------
ValueError
If a style parameter is unknown.
"""
for param_name, param_value in kwargs.items():
default_attr = f"_default_{param_name}"
if not hasattr(self, default_attr):
raise ValueError(
f"Unknown style parameter: {param_name}. "
f"The parameter name should be one of {self._style_params}."
)
setattr(self, default_attr, param_value)
return self

@staticmethod
def style_plot(plot_func: Callable) -> Callable:
"""Apply consistent style to skore displays.

This decorator:
1. Applies default style settings
2. Executes `plot_func`
3. Applies `tight_layout`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it apply tight layout?


Parameters
----------
plot_func : callable
The plot function to be decorated.

Returns
-------
callable
The decorated plot function.
"""

@wraps(plot_func)
def wrapper(self, *args: Any, **kwargs: Any) -> Any:
# We need to manually handle setting the style of the parameters because
# `plt.style.context` has a side effect with the interactive mode.
# See https://github.com/matplotlib/matplotlib/issues/25041
original_params = {key: plt.rcParams[key] for key in DEFAULT_STYLE}
plt.rcParams.update(DEFAULT_STYLE)
try:
result = plot_func(self, *args, **kwargs)
finally:
plt.rcParams.update(original_params)
return result

return wrapper


########################################################################################
# General purpose mixins
########################################################################################


class HelpDisplayMixin:
"""Mixin class to add help functionality to a class."""

estimator_name: str # defined in the concrete display class

def _get_attributes_for_help(self) -> list[str]:
"""Get the attributes ending with '_' to display in help."""
attributes = []
for name in dir(self):
if name.endswith("_") and not name.startswith("_"):
attributes.append(f".{name}")
return sorted(attributes)

def _get_methods_for_help(self) -> list[tuple[str, Any]]:
"""Get the public methods to display in help."""
methods = inspect.getmembers(self, predicate=inspect.ismethod)
filtered_methods = []
for name, method in methods:
is_private = name.startswith("_")
is_class_method = inspect.ismethod(method) and method.__self__ is type(self)
is_help_method = name == "help"
if not (is_private or is_class_method or is_help_method):
filtered_methods.append((f".{name}(...)", method))
return sorted(filtered_methods)

def _create_help_tree(self) -> Tree:
"""Create a rich Tree with attributes and methods."""
tree = Tree("display")

attributes = self._get_attributes_for_help()
attr_branch = tree.add("[bold cyan] Attributes[/bold cyan]")
# Ensure figure_ and ax_ are first
sorted_attrs = sorted(attributes)
if ("figure_" in sorted_attrs) and ("ax_" in sorted_attrs):
sorted_attrs.remove(".ax_")
sorted_attrs.remove(".figure_")
sorted_attrs = [".figure_", ".ax_"] + [
attr for attr in sorted_attrs if attr not in [".figure_", ".ax_"]
]
for attr in sorted_attrs:
attr_branch.add(attr)

methods = self._get_methods_for_help()
method_branch = tree.add("[bold cyan]Methods[/bold cyan]")
for name, method in methods:
description = (
method.__doc__.split("\n")[0]
if method.__doc__
else "No description available"
)
method_branch.add(f"{name} - {description}")

return tree

def _create_help_panel(self) -> Panel:
return Panel(
self._create_help_tree(),
title=f"[bold cyan]{self.__class__.__name__} [/bold cyan]",
border_style="orange1",
expand=False,
)

def help(self) -> None:
"""Display available attributes and methods using rich."""
from skore import console # avoid circular import

console.print(self._create_help_panel())

def __str__(self) -> str:
"""Return a string representation using rich."""
string_buffer = StringIO()
console = Console(file=string_buffer, force_terminal=False)
console.print(
Panel(
"Get guidance using the help() method",
title=f"[cyan]{self.__class__.__name__}[/cyan]",
border_style="orange1",
expand=False,
)
)
return string_buffer.getvalue()

def __repr__(self) -> str:
"""Return a string representation using rich."""
string_buffer = StringIO()
console = Console(file=string_buffer, force_terminal=False)
console.print(f"[cyan]skore.{self.__class__.__name__}(...)[/cyan]")
return string_buffer.getvalue()


########################################################################################
# Display mixin inheriting from the different mixins
########################################################################################


class DisplayMixin(HelpDisplayMixin, PlotBackendMixin, StyleDisplayMixin):
"""Mixin inheriting help, plotting, and style functionality."""
40 changes: 32 additions & 8 deletions skore/src/skore/_sklearn/_plot/data/table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
)

from skore._externals._skrub_compat import sbd
from skore._sklearn._plot.style import StyleDisplayMixin
from skore._sklearn._plot.base import DisplayMixin
from skore._sklearn._plot.utils import (
HelpDisplayMixin,
PlotBackendMixin,
_adjust_fig_size,
_rotate_ticklabels,
_validate_style_kwargs,
Expand Down Expand Up @@ -162,9 +160,7 @@ def _resize_categorical_axis(
_adjust_fig_size(figure, ax, target_width, target_height)


class TableReportDisplay(
StyleDisplayMixin, HelpDisplayMixin, ReprHTMLMixin, PlotBackendMixin
):
class TableReportDisplay(ReprHTMLMixin, DisplayMixin):
"""Display reporting information about a given dataset.

This display summarizes the dataset and provides a way to visualize
Expand Down Expand Up @@ -222,8 +218,8 @@ def _compute_data_for_display(cls, dataset: pd.DataFrame) -> "TableReportDisplay
"""
return cls(summarize_dataframe(dataset, with_plots=True, title=None, verbose=0))

@StyleDisplayMixin.style_plot
def _plot_matplotlib(
@DisplayMixin.style_plot
def plot(
self,
*,
x: str | None = None,
Expand Down Expand Up @@ -305,6 +301,34 @@ def _plot_matplotlib(
>>> display = report.data.analyze()
>>> display.plot(kind="corr")
"""
return self._plot(
x=x,
y=y,
hue=hue,
kind=kind,
top_k_categories=top_k_categories,
scatterplot_kwargs=scatterplot_kwargs,
stripplot_kwargs=stripplot_kwargs,
boxplot_kwargs=boxplot_kwargs,
heatmap_kwargs=heatmap_kwargs,
histplot_kwargs=histplot_kwargs,
)

def _plot_matplotlib(
self,
*,
x: str | None = None,
y: str | None = None,
hue: str | None = None,
kind: Literal["dist", "corr"] = "dist",
top_k_categories: int = 20,
scatterplot_kwargs: dict[str, Any] | None = None,
stripplot_kwargs: dict[str, Any] | None = None,
boxplot_kwargs: dict[str, Any] | None = None,
heatmap_kwargs: dict[str, Any] | None = None,
histplot_kwargs: dict[str, Any] | None = None,
) -> None:
"""Matplotlib implementation of the `plot` method."""
self.figure_, self.ax_ = plt.subplots()
if kind == "dist":
match (x is None, y is None, hue is None):
Expand Down
13 changes: 7 additions & 6 deletions skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
import numpy as np
from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix

from skore._sklearn._plot.base import Display
from skore._sklearn._plot.style import StyleDisplayMixin
from skore._sklearn._plot.utils import PlotBackendMixin
from skore._sklearn._plot.base import DisplayMixin


class ConfusionMatrixDisplay(PlotBackendMixin, Display):
class ConfusionMatrixDisplay(DisplayMixin):
"""Display for confusion matrix.

Parameters
Expand Down Expand Up @@ -45,7 +43,6 @@ class ConfusionMatrixDisplay(PlotBackendMixin, Display):
confusion matrix.
"""

@StyleDisplayMixin.style_plot
def __init__(
self,
confusion_matrix,
Expand All @@ -64,7 +61,8 @@ def __init__(
self.ax_ = None
self.text_ = None

def _plot_matplotlib(self, ax=None, *, cmap="Blues", colorbar=True, **kwargs):
@DisplayMixin.style_plot
def plot(self, ax=None, *, cmap="Blues", colorbar=True, **kwargs):
"""Plot the confusion matrix.

Parameters
Expand All @@ -87,6 +85,9 @@ def _plot_matplotlib(self, ax=None, *, cmap="Blues", colorbar=True, **kwargs):
self : ConfusionMatrixDisplay
Configured with the confusion matrix.
"""
return self._plot(ax=ax, cmap=cmap, colorbar=colorbar, **kwargs)

def _plot_matplotlib(self, ax=None, *, cmap="Blues", colorbar=True, **kwargs):
if self.normalize not in (None, "true", "pred", "all"):
raise ValueError(
"normalize must be one of None, 'true', 'pred', 'all'; "
Expand Down
Loading
Loading