Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
546c92a
start cross-val data accessor
MarieSacksick Aug 4, 2025
860fe58
add tests and adapt code
MarieSacksick Aug 5, 2025
b784c7f
update tabular learner to pipeline, new in skrub 0.6
MarieSacksick Aug 5, 2025
d336a23
fix doctest
MarieSacksick Aug 5, 2025
0279610
update tabular learner to pipeline, new in skrub 0.6
MarieSacksick Aug 5, 2025
30b5b87
fix doctest
MarieSacksick Aug 5, 2025
0566330
add tests
MarieSacksick Aug 5, 2025
ab07bcf
split tests common and estimator for display
MarieSacksick Aug 5, 2025
18b4ed9
shorten some tests
MarieSacksick Aug 6, 2025
ae03d63
add test missing y
MarieSacksick Aug 6, 2025
73af646
fix repr
MarieSacksick Aug 6, 2025
ca2d6e2
test repr of data accessor instead of display
MarieSacksick Aug 7, 2025
f2945a6
factorize fixtures
auguste-probabl Aug 13, 2025
1fb6880
remove find_estimators
auguste-probabl Aug 13, 2025
cb9f729
vendor tabular_pipeline for skrub compatibility
auguste-probabl Aug 13, 2025
16419ce
Update skore/tests/unit/displays/table_report/test_cross_validation.py
auguste-probabl Sep 1, 2025
040351d
parametrize tests
MarieSacksick Sep 2, 2025
0ad5d87
make test name more explicit
MarieSacksick Sep 2, 2025
29ae880
commonize html repr
MarieSacksick Sep 2, 2025
960adc4
add data accessor to rst file for cv
MarieSacksick Sep 5, 2025
ece0561
linting
MarieSacksick Sep 8, 2025
c7f848b
Adapt tests
MarieSacksick Sep 8, 2025
b3c529b
enhance tests
MarieSacksick Sep 8, 2025
e0eaec6
complete data acessor docs
MarieSacksick Sep 9, 2025
a0fb0c5
add data to accessor in docs
MarieSacksick Sep 9, 2025
bfb65f7
adapt example to use the new data accessor in cv
MarieSacksick Sep 9, 2025
19504ba
remove comment
glemaitre Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions examples/use_cases/plot_employee_salaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,25 @@
# %%
hgbt_model_report.metrics.summarize().frame()

# %%
# Similarly to what we saw in the previous section, the
# :class:`skore.CrossValidationReport` also stores some information about the dataset
# used.

# %%
data_display = hgbt_model_report.data.analyze()
data_display

# %%
# The display obtained allows for a quick overview with the same HTML-based view
# as the :class:`skrub.TableReport` we have seen earlier. In addition, you can access
# a :meth:`skore.TableReportDisplay.plot` method to have a particular focus on one
# potential analysis. For instance, we can get a figure representing the correlation
# matrix of the dataset.

# %%
data_display.plot(kind="corr")

# %%
# We get the results from some statistical metrics aggregated over the cross-validation
# splits as well as some performance metrics related to the time it took to train and
Expand All @@ -153,25 +172,6 @@
# The favorability of each metric indicates whether the metric is better
# when higher or lower.

# %%
# We also have access to some additional information regarding the dataset used for
# training and testing of the model, similar to what we have seen in the previous
# section. For example, let's check some information about the training dataset.

# %%
train_data_display = hgbt_split_1.data.analyze(data_source="train")
train_data_display

# %%
# The display obtained allows for a quick overview with the same HTML-based view
# as the :class:`skrub.TableReport` we have seen earlier. In addition, you can access
# a :meth:`skore.TableReportDisplay.plot` method to have a particular focus on one
# potential analysis. For instance, we can get a figure representing the correlation
# matrix of the training dataset.

# %%
train_data_display.plot(kind="corr")

# %%
# Linear model
# ============
Expand Down
4 changes: 4 additions & 0 deletions skore/src/skore/_externals/_skrub_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


if skrub_version < parse_version("0.6.0"):
tabular_pipeline = skrub.tabular_learner

@dispatch
def concat(*dataframes, axis=0):
Expand Down Expand Up @@ -40,6 +41,9 @@ def _concat_polars(*dataframes, axis=0):

sbd.concat = concat

else:
tabular_pipeline = skrub.tabular_pipeline


@dispatch
def to_frame(col):
Expand Down
2 changes: 2 additions & 0 deletions skore/src/skore/_sklearn/_cross_validation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from skore._externals._pandas_accessors import _register_accessor
from skore._sklearn._cross_validation.data_accessor import _DataAccessor
from skore._sklearn._cross_validation.feature_importance_accessor import (
_FeatureImportanceAccessor,
)
Expand All @@ -8,6 +9,7 @@
)

_register_accessor("metrics", CrossValidationReport)(_MetricsAccessor)
_register_accessor("data", CrossValidationReport)(_DataAccessor)

_register_accessor("feature_importance", CrossValidationReport)(
_FeatureImportanceAccessor
Expand Down
138 changes: 138 additions & 0 deletions skore/src/skore/_sklearn/_cross_validation/data_accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import Literal

import pandas as pd
from skrub import _dataframe as sbd

from skore._externals._pandas_accessors import DirNamesMixin
from skore._sklearn._base import _BaseAccessor
from skore._sklearn._cross_validation.report import CrossValidationReport
from skore._sklearn._plot import TableReportDisplay


class _DataAccessor(_BaseAccessor[CrossValidationReport], DirNamesMixin):
    """Accessor giving a dataset-level analysis of a cross-validation report.

    It works on the ``X`` and ``y`` attributes stored on the parent
    :class:`CrossValidationReport`, i.e. on the full dataset rather than on the
    individual splits.
    """

    def __init__(self, parent: CrossValidationReport) -> None:
        super().__init__(parent)

    def _retrieve_data_as_frame(
        self,
        with_y: bool,
    ):
        """Retrieve the data of the parent report as dataframes.

        Parameters
        ----------
        with_y : bool
            Whether the target is required. If True, `y` must not be None and is
            converted to a dataframe when needed. If False, `y` is returned exactly
            as stored on the parent report.

        Returns
        -------
        X : DataFrame
            The input data.

        y : DataFrame or None
            The target data.

        Raises
        ------
        ValueError
            If `with_y=True` and the parent report has no `y`.
        """
        X = self._parent.X
        y = self._parent.y

        if not sbd.is_dataframe(X):
            # Non-dataframe input is assumed to be 2D array-like exposing `.shape`;
            # generic column names are generated for it.
            feature_names = [f"Feature {i}" for i in range(X.shape[1])]
            X = pd.DataFrame(X, columns=feature_names)

        if not with_y:
            return X, y

        if y is None:
            raise ValueError("y is required when `with_y=True`.")

        if isinstance(y, pd.Series):
            # Keep the user-provided name when there is one.
            name = y.name if y.name is not None else "Target"
            y = y.to_frame(name=name)
        elif not sbd.is_dataframe(y):
            if y.ndim == 1:
                columns = ["Target"]
            else:
                columns = [f"Target {i}" for i in range(y.shape[1])]
            y = pd.DataFrame(y, columns=columns)

        return X, y

    def analyze(
        self,
        with_y: bool = True,
        subsample: int | None = None,
        subsample_strategy: Literal["head", "random"] = "head",
        seed: int | None = None,
    ) -> TableReportDisplay:
        """Plot dataset statistics.

        Parameters
        ----------
        with_y : bool, default=True
            Whether to include the target variable in the analysis. If True, the target
            variable is concatenated horizontally to the features.

        subsample : int, default=None
            The number of points to subsample the dataframe held by the display, using
            the strategy set by ``subsample_strategy``. It must be a strictly positive
            integer. If ``None``, no subsampling is applied.

        subsample_strategy : {'head', 'random'}, default='head'
            The strategy used to subsample the dataframe held by the display. It only
            has an effect when ``subsample`` is not None.

            - If ``'head'``: subsample by taking the ``subsample`` first points of the
              dataframe, similar to Pandas: ``df.head(n)``.
            - If ``'random'``: randomly subsample the dataframe by using a uniform
              distribution. The random seed is controlled by ``seed``.

        seed : int, default=None
            The random seed to use when randomly subsampling. It only has an effect when
            ``subsample`` is not None and ``subsample_strategy='random'``.

        Returns
        -------
        TableReportDisplay
            A display object containing the dataset statistics and plots.

        Raises
        ------
        ValueError
            If ``subsample_strategy`` is not one of the supported options, if
            ``subsample`` is not a strictly positive integer, or if ``with_y=True``
            while the report has no target.

        Examples
        --------
        >>> from sklearn.datasets import load_breast_cancer
        >>> from sklearn.linear_model import LogisticRegression
        >>> from skore import CrossValidationReport
        >>> X, y = load_breast_cancer(return_X_y=True)
        >>> classifier = LogisticRegression(max_iter=10_000)
        >>> report = CrossValidationReport(classifier, X=X, y=y, pos_label=1)
        >>> frame = report.data.analyze().frame()
        """
        subsample_strategy_options = ("head", "random")
        if subsample_strategy not in subsample_strategy_options:
            raise ValueError(
                f"'subsample_strategy' options are {subsample_strategy_options!r}, got "
                f"{subsample_strategy}."
            )

        # Enforce the documented contract explicitly: a bare truthiness test would
        # silently ignore 0 and forward negative values to the dataframe backend.
        if subsample is not None and subsample <= 0:
            raise ValueError(
                "'subsample' must be a strictly positive integer or None, got "
                f"{subsample!r}."
            )

        X, y = self._retrieve_data_as_frame(with_y)

        df = sbd.concat(X, y, axis=1) if with_y else X

        if subsample is not None:
            if subsample_strategy == "head":
                df = sbd.head(df, subsample)
            else:  # subsample_strategy == "random"
                df = sbd.sample(df, subsample, seed=seed)

        return TableReportDisplay._compute_data_for_display(df)

    ####################################################################################
    # Methods related to the help tree
    ####################################################################################

    def _format_method_name(self, name: str) -> str:
        """Return the method name followed by ``(...)``, padded to 29 characters."""
        return f"{name}(...)".ljust(29)

    def _get_help_panel_title(self) -> str:
        """Return the rich-markup title of the help panel."""
        return "[bold cyan]Available data methods[/bold cyan]"

    def _get_help_tree_title(self) -> str:
        """Return the rich-markup title of the help tree."""
        return "[bold cyan]report.data[/bold cyan]"

    def __repr__(self) -> str:
        """Return a string representation using rich."""
        return self._rich_repr(class_name="skore.CrossValidationReport.data")
4 changes: 2 additions & 2 deletions skore/src/skore/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from importlib.metadata import entry_points
from typing import Any, Literal

from skore._sklearn._estimator.report import EstimatorReport
from skore import EstimatorReport
from skore.project.summary import Summary


Expand Down Expand Up @@ -95,7 +95,7 @@ class Project:
>>> from sklearn.datasets import make_classification, make_regression
>>> from sklearn.linear_model import LinearRegression, LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skore._sklearn import EstimatorReport
>>> from skore import EstimatorReport
>>>
>>> X, y = make_classification(random_state=42)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
Expand Down
Loading
Loading