diff --git a/examples/use_cases/plot_employee_salaries.py b/examples/use_cases/plot_employee_salaries.py
index 260cbbbeae..83fc2e4a16 100644
--- a/examples/use_cases/plot_employee_salaries.py
+++ b/examples/use_cases/plot_employee_salaries.py
@@ -136,6 +136,25 @@
 # %%
 hgbt_model_report.metrics.summarize().frame()
 
+# %%
+# Similarly to what we saw in the previous section, the
+# :class:`skore.CrossValidationReport` also stores some information about the dataset
+# used.
+
+# %%
+data_display = hgbt_model_report.data.analyze()
+data_display
+
+# %%
+# The display obtained allows for a quick overview with the same HTML-based view
+# as the :class:`skrub.TableReport` we have seen earlier. In addition, you can access
+# a :meth:`skore.TableReportDisplay.plot` method to have a particular focus on one
+# potential analysis. For instance, we can get a figure representing the correlation
+# matrix of the dataset.
+
+# %%
+data_display.plot(kind="corr")
+
 # %%
 # We get the results from some statistical metrics aggregated over the cross-validation
 # splits as well as some performance metrics related to the time it took to train and
@@ -153,25 +172,6 @@
 # The favorability of each metric indicates whether the metric is better
 # when higher or lower.
 
-# %%
-# We also have access to some additional information regarding the dataset used for
-# training and testing of the model, similar to what we have seen in the previous
-# section. For example, let's check some information about the training dataset.
-
-# %%
-train_data_display = hgbt_split_1.data.analyze(data_source="train")
-train_data_display
-
-# %%
-# The display obtained allows for a quick overview with the same HTML-based view
-# as the :class:`skrub.TableReport` we have seen earlier. In addition, you can access
-# a :meth:`skore.TableReportDisplay.plot` method to have a particular focus on one
-# potential analysis. For instance, we can get a figure representing the correlation
-# matrix of the training dataset.
-
-# %%
-train_data_display.plot(kind="corr")
-
 # %%
 # Linear model
 # ============
diff --git a/skore/src/skore/_externals/_skrub_compat.py b/skore/src/skore/_externals/_skrub_compat.py
index 96af5d61ac..c5f7570fd1 100644
--- a/skore/src/skore/_externals/_skrub_compat.py
+++ b/skore/src/skore/_externals/_skrub_compat.py
@@ -10,6 +10,7 @@
 
 
 if skrub_version < parse_version("0.6.0"):
+    tabular_pipeline = skrub.tabular_learner
 
     @dispatch
     def concat(*dataframes, axis=0):
@@ -40,6 +41,9 @@ def _concat_polars(*dataframes, axis=0):
 
     sbd.concat = concat
 
+else:
+    tabular_pipeline = skrub.tabular_pipeline
+
 
 @dispatch
 def to_frame(col):
diff --git a/skore/src/skore/_sklearn/_cross_validation/__init__.py b/skore/src/skore/_sklearn/_cross_validation/__init__.py
index 397110767e..7b9c954b51 100644
--- a/skore/src/skore/_sklearn/_cross_validation/__init__.py
+++ b/skore/src/skore/_sklearn/_cross_validation/__init__.py
@@ -1,4 +1,5 @@
 from skore._externals._pandas_accessors import _register_accessor
+from skore._sklearn._cross_validation.data_accessor import _DataAccessor
 from skore._sklearn._cross_validation.feature_importance_accessor import (
     _FeatureImportanceAccessor,
 )
@@ -8,6 +9,7 @@
 )
 
 _register_accessor("metrics", CrossValidationReport)(_MetricsAccessor)
+_register_accessor("data", CrossValidationReport)(_DataAccessor)
 
 _register_accessor("feature_importance", CrossValidationReport)(
     _FeatureImportanceAccessor
diff --git a/skore/src/skore/_sklearn/_cross_validation/data_accessor.py b/skore/src/skore/_sklearn/_cross_validation/data_accessor.py
new file mode 100644
index 0000000000..4d3bcb7c94
--- /dev/null
+++ b/skore/src/skore/_sklearn/_cross_validation/data_accessor.py
@@ -0,0 +1,138 @@
+from typing import Literal
+
+import pandas as pd
+from skrub import _dataframe as sbd
+
+from skore._externals._pandas_accessors import DirNamesMixin
+from skore._sklearn._base import _BaseAccessor
+from skore._sklearn._cross_validation.report import CrossValidationReport
+from skore._sklearn._plot import TableReportDisplay
+
+
+class _DataAccessor(_BaseAccessor[CrossValidationReport], DirNamesMixin):
+    def __init__(self, parent: CrossValidationReport) -> None:
+        super().__init__(parent)
+
+    def _retrieve_data_as_frame(
+        self,
+        with_y: bool,
+    ):
+        """Retrieve data as DataFrame.
+
+        Parameters
+        ----------
+        with_y : bool
+            Whether to also process and return `y`. If True, `y` must not
+            be None.
+
+        Returns
+        -------
+        X : DataFrame
+            The input data.
+
+        y : DataFrame or None
+            The target data.
+        """
+        X = self._parent.X
+        y = self._parent.y
+
+        if not sbd.is_dataframe(X):
+            X = pd.DataFrame(X, columns=[f"Feature {i}" for i in range(X.shape[1])])
+
+        if with_y:
+            if y is None:
+                raise ValueError("y is required when `with_y=True`.")
+
+            if isinstance(y, pd.Series):
+                name = y.name if y.name is not None else "Target"
+                y = y.to_frame(name=name)
+            elif not sbd.is_dataframe(y):
+                if y.ndim == 1:
+                    columns = ["Target"]
+                else:
+                    columns = [f"Target {i}" for i in range(y.shape[1])]
+                y = pd.DataFrame(y, columns=columns)
+
+        return X, y
+
+    def analyze(
+        self,
+        with_y: bool = True,
+        subsample: int | None = None,
+        subsample_strategy: Literal["head", "random"] = "head",
+        seed: int | None = None,
+    ) -> TableReportDisplay:
+        """Analyze the dataset and return a display of its statistics.
+
+        Parameters
+        ----------
+        with_y : bool, default=True
+            Whether to include the target variable in the analysis. If True, the target
+            variable is concatenated horizontally to the features.
+
+        subsample : int, default=None
+            The number of points to keep when subsampling the dataframe held by the
+            display, using the strategy set by ``subsample_strategy``. It must be a
+            strictly positive integer. If ``None``, no subsampling is applied.
+
+        subsample_strategy : {'head', 'random'}, default='head'
+            The strategy used to subsample the dataframe held by the display. It only
+            has an effect when ``subsample`` is not None.
+
+            - If ``'head'``: subsample by taking the ``subsample`` first points of the
+              dataframe, similar to Pandas: ``df.head(n)``.
+            - If ``'random'``: randomly subsample the dataframe by using a uniform
+              distribution. The random seed is controlled by ``seed``.
+
+        seed : int, default=None
+            The random seed to use when randomly subsampling. It only has an effect
+            when ``subsample`` is not None and ``subsample_strategy='random'``.
+
+        Returns
+        -------
+        TableReportDisplay
+            A display object containing the dataset statistics and plots.
+
+        Examples
+        --------
+        >>> from sklearn.datasets import load_breast_cancer
+        >>> from sklearn.linear_model import LogisticRegression
+        >>> from skore import CrossValidationReport
+        >>> X, y = load_breast_cancer(return_X_y=True)
+        >>> classifier = LogisticRegression(max_iter=10_000)
+        >>> report = CrossValidationReport(classifier, X=X, y=y, pos_label=1)
+        >>> report.data.analyze().frame()
+        """
+        if subsample_strategy not in (
+            subsample_strategy_options := ("head", "random")
+        ):
+            raise ValueError(
+                f"'subsample_strategy' options are {subsample_strategy_options!r}, got "
+                f"{subsample_strategy}."
+            )
+
+        X, y = self._retrieve_data_as_frame(with_y)
+
+        df = sbd.concat(X, y, axis=1) if with_y else X
+
+        if subsample:
+            if subsample_strategy == "head":
+                df = sbd.head(df, subsample)
+            else:  # subsample_strategy == "random"
+                df = sbd.sample(df, subsample, seed=seed)
+
+        return TableReportDisplay._compute_data_for_display(df)
+
+    ################################################################################
+    # Methods related to the help tree
+    ################################################################################
+
+    def _format_method_name(self, name: str) -> str:
+        return f"{name}(...)".ljust(29)
+
+    def _get_help_panel_title(self) -> str:
+        return "[bold cyan]Available data methods[/bold cyan]"
+
+    def _get_help_tree_title(self) -> str:
+        return "[bold cyan]report.data[/bold cyan]"
+
+    def __repr__(self) -> str:
+        """Return a string representation using rich."""
+        return self._rich_repr(class_name="skore.CrossValidationReport.data")
diff --git a/skore/src/skore/project/project.py b/skore/src/skore/project/project.py
index 765cf60583..f64fad07f1 100644
--- a/skore/src/skore/project/project.py
+++ b/skore/src/skore/project/project.py
@@ -6,7 +6,7 @@
 from importlib.metadata import entry_points
 from typing import Any, Literal
 
-from skore._sklearn._estimator.report import EstimatorReport
+from skore import EstimatorReport
 from skore.project.summary import Summary
 
 
@@ -95,7 +95,7 @@ class Project:
     >>> from sklearn.datasets import make_classification, make_regression
     >>> from sklearn.linear_model import LinearRegression, LogisticRegression
     >>> from sklearn.model_selection import train_test_split
-    >>> from skore._sklearn import EstimatorReport
+    >>> from skore import EstimatorReport
     >>>
     >>> X, y = make_classification(random_state=42)
     >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
diff --git a/skore/tests/unit/displays/table_report/test_common.py b/skore/tests/unit/displays/table_report/test_common.py
new file mode 100644
index 0000000000..159a7a7a32
--- /dev/null
+++ b/skore/tests/unit/displays/table_report/test_common.py
@@ -0,0 +1,163 @@
+import numpy as np
+import pandas as pd
+import pytest
+from matplotlib.collections import QuadMesh
+from sklearn.datasets import make_regression
+from skore import CrossValidationReport, EstimatorReport, train_test_split
+from skore._externals._skrub_compat import tabular_pipeline
+from skore._sklearn._plot.data.table_report import (
+    _compute_contingency_table,
+    _resize_categorical_axis,
+    _truncate_top_k_categories,
+)
+
+
+@pytest.fixture
+def X_y():
+    X, y = make_regression(n_samples=100, n_features=5, random_state=42)
+    X = pd.DataFrame(X, columns=[f"Feature_{i}" for i in range(5)])
+    y = pd.Series(y, name="Target_")
+    return X, y
+
+
+@pytest.fixture
+def estimator_report(X_y):
+    X, y = X_y
+    split_data = train_test_split(X, y, random_state=0, as_dict=True)
+    return EstimatorReport(tabular_pipeline("regressor"), **split_data)
+
+
+@pytest.fixture
+def cross_validation_report(X_y):
+    X, y = X_y
+    return CrossValidationReport(tabular_pipeline("regressor"), X=X, y=y)
+
+
+@pytest.fixture(params=["estimator_report", "cross_validation_report"])
+def display(request):
+    report = request.getfixturevalue(request.param)
+    return report.data.analyze()
+
+
+@pytest.mark.parametrize("dtype", ["category", "object"])
+@pytest.mark.parametrize("other_label", ["other", "xxx"])
+def test_truncate_top_k_categories(dtype, other_label):
+    """Check the behaviour of `_truncate_top_k_categories` when `col` is a categorical
+    column."""
+    col = pd.Series(
+        ["a", "a", "b", "b", "b", "c", "c", "c", "c", "c", "d", "e", np.nan, np.nan],
+        dtype=dtype,
+    )
+    expected_col = pd.Series(
+        [
+            "a",
+            "a",
+            "b",
+            "b",
+            "b",
+            "c",
+            "c",
+            "c",
+            "c",
+            "c",
+            other_label,
+            other_label,
+            np.nan,
+            np.nan,
+        ],
+        dtype=dtype,
+    )
+    truncated_col = _truncate_top_k_categories(col, k=3, other_label=other_label)
+    pd.testing.assert_series_equal(truncated_col, expected_col)
+
+
+@pytest.mark.parametrize("is_x_axis", [True, False])
+def test_resize_categorical_axis(pyplot, is_x_axis):
+    """Check the behaviour of the `_resize_categorical_axis` function."""
+    figure, ax = pyplot.subplots(figsize=(10, 10))
+    _resize_categorical_axis(
+        figure=figure,
+        ax=ax,
+        n_categories=1,
+        is_x_axis=is_x_axis,
+        size_per_category=0.5,
+    )
+
+    fig_width, fig_height = figure.get_size_inches()
+    if is_x_axis:
+        assert 0.5 < fig_width < 1.0
+        assert 10.0 < fig_height < 13.0
+    else:
+        assert 0.5 < fig_height < 1.0
+        assert 10.0 < fig_width < 13.0
+
+
+@pytest.mark.parametrize("col", [None, pd.Series(range(10))])
+def test_truncate_top_k_categories_return_as_is(col):
+    """Check that `_truncate_top_k_categories` makes no changes when `col` is None
+    or numeric."""
+    assert _truncate_top_k_categories(col, k=3) is col
+
+
+def test_corr_plot(pyplot, estimator_report):
+    display = estimator_report.data.analyze(data_source="train")
+    display.plot(kind="corr")
+    assert isinstance(display.ax_.collections[0], QuadMesh)
+    assert len(display.ax_.get_xticklabels()) == 6
+    assert len(display.ax_.get_yticklabels()) == 6
+    assert display.ax_.title.get_text() == "Cramer's V Correlation"
+
+
+def test_repr(display):
+    assert repr(display) == ""
+
+
+def test_compute_contingency_table_error():
+    """Check that we raise an error when the series x and y don't have a name."""
+    series = pd.Series(["a", "a", "b", "b", "b", "c", "c", "c", "c", "c", "d", "e"])
+    err_msg = "The series x and y must have a name."
+    with pytest.raises(ValueError, match=err_msg):
+        _compute_contingency_table(x=series, y=series, hue=None, k=1)
+
+
+@pytest.mark.parametrize("dtype", ["category", "object"])
+def test_compute_contingency_table(dtype):
+    """Check the behaviour of the `_compute_contingency_table` function."""
+    x = pd.Series(
+        ["a", "a", "b", "b", "b", "c", "c", "c", "c", "c", "d", "e"],
+        name="x",
+        dtype=dtype,
+    )
+    y = pd.Series(
+        ["a", "a", "b", "b", "b", "c", "c", "c", "c", "c", "d", "e"],
+        name="y",
+        dtype=dtype,
+    )
+    contingency_table = _compute_contingency_table(x, y, hue=None, k=100)
+    assert contingency_table.sum().sum() == len(x)
+    assert sorted(contingency_table.columns.tolist()) == sorted(x.unique().tolist())
+    assert sorted(contingency_table.index.tolist()) == sorted(y.unique().tolist())
+
+    hue = pd.Series(np.ones_like(x) * 2.0)
+    contingency_table = _compute_contingency_table(x, y, hue, k=100)
+    assert contingency_table.sum().sum() == pytest.approx(x.unique().size * 2)
+    assert sorted(contingency_table.columns.tolist()) == sorted(x.unique().tolist())
+    assert sorted(contingency_table.index.tolist()) == sorted(y.unique().tolist())
+
+    contingency_table = _compute_contingency_table(x, y, hue=None, k=2)
+    assert contingency_table.index.tolist() == ["b", "c"]
+    assert contingency_table.columns.tolist() == ["b", "c"]
+    assert contingency_table.sum().sum() == 8
+
+    contingency_table = _compute_contingency_table(x, y, hue=hue, k=2)
+    assert contingency_table.index.tolist() == ["b", "c"]
+    assert contingency_table.columns.tolist() == ["b", "c"]
+    assert contingency_table.sum().sum() == 4
+
+
+def test_html_repr(display, X_y):
+    """Check the HTML representation of the `TableReportDisplay`."""
+    str_html = display._repr_html_()
+    X, _ = X_y
+    assert all(col in str_html for col in X.columns)
+    assert ""
-
-
-def test_html_repr(estimator_report):
-    """Check the HTML representation of the `TableReportDisplay`."""
-    display = estimator_report.data.analyze(data_source="train")
-    str_html = display._repr_html_()
-    for col in estimator_report.X_train.columns:
-        assert col in str_html
-
-    assert "