49 changes: 47 additions & 2 deletions skore/src/skore/_sklearn/_estimator/metrics_accessor.py
@@ -62,7 +62,7 @@ def __init__(self, parent: EstimatorReport) -> None:
def summarize(
self,
*,
data_source: DataSource = "test",
data_source: DataSource | Literal["all"] = "test",
X: ArrayLike | None = None,
y: ArrayLike | None = None,
scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
@@ -75,12 +75,14 @@

Parameters
----------
data_source : {"test", "train", "X_y"}, default="test"
data_source : {"test", "train", "X_y", "all"}, default="test"
The data source to use.

- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
- "X_y" : use the provided `X` and `y` to compute the metric.
- "all" : use both the train and test sets to compute the metrics and
present them side-by-side.

X : array-like of shape (n_samples, n_features), default=None
New data on which to compute the metric. By default, we use the validation
@@ -159,7 +161,50 @@ class is set to the one provided when creating the report. If `None`,
LogisticRegression Favorability
Metric Label / Average
F1 Score 1 0.96... (↗︎)
>>> report.metrics.summarize(
... indicator_favorability=True,
... data_source="all"
... ).frame().drop(["Fit time (s)", "Predict time (s)"])
LogisticRegression (train) LogisticRegression (test) Favorability
Metric
Precision 0.96... 0.98... (↗︎)
Recall 0.97... 0.93... (↗︎)
ROC AUC 0.99... 0.99... (↗︎)
Brier score 0.02... 0.03... (↘︎)
>>> # Using scikit-learn metrics
>>> report.metrics.summarize(
... scoring=["f1"],
... indicator_favorability=True,
... ).frame()
LogisticRegression Favorability
Metric Label / Average
F1 Score 1 0.96... (↗︎)
"""
if data_source == "all":
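# Favorability is identical for both data sources, so it is requested only
# for the test summary; the rename below keeps a single "Favorability" column.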
train_summary = self.summarize(
data_source="train",
scoring=scoring,
scoring_kwargs=scoring_kwargs,
pos_label=pos_label,
indicator_favorability=False,
flat_index=flat_index,
)
test_summary = self.summarize(
data_source="test",
scoring=scoring,
scoring_kwargs=scoring_kwargs,
pos_label=pos_label,
indicator_favorability=indicator_favorability,
flat_index=flat_index,
)
# Add suffix to the dataframes to distinguish train and test.
train_df = train_summary.frame().add_suffix(" (train)")
test_df = test_summary.frame().add_suffix(" (test)")
combined = pd.concat([train_df, test_df], axis=1).rename(
columns={"Favorability (test)": "Favorability"}
)
return MetricsSummaryDisplay(summarize_data=combined)

if pos_label is _DEFAULT:
pos_label = self._parent.pos_label

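For context, a minimal usage sketch of the new option. Assumptions: `EstimatorReport` is importable from the top-level `skore` package as in the existing doctests, and the dataset and estimator below are illustrative only, not part of this change.

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from skore import EstimatorReport

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# EstimatorReport fits the estimator on the train split and keeps both splits.
report = EstimatorReport(
    LogisticRegression(max_iter=10_000),
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
)

# With data_source="all", the summary holds one column per data source,
# e.g. "LogisticRegression (train)" and "LogisticRegression (test)",
# plus a single "Favorability" column when requested.
summary = report.metrics.summarize(data_source="all", indicator_favorability=True)
print(summary.frame())

The new branch simply calls `summarize` once per data source and concatenates the two frames side by side, so `scoring`, `scoring_kwargs`, and `pos_label` apply identically to the train and test columns.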
40 changes: 40 additions & 0 deletions skore/tests/unit/reports/estimator/metrics/test_numeric.py
@@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
import pytest
from pandas.testing import assert_series_equal
from sklearn.base import BaseEstimator
from sklearn.cluster import KMeans
from sklearn.datasets import make_classification
@@ -127,6 +128,45 @@ def test_summarize_regression(linear_regression_with_test, metric):
assert report._cache != {}


def test_summarize_data_source_all(forest_binary_classification_data):
"""Check the behaviour of `summarize` with `data_source="all"`."""
estimator, X, y = forest_binary_classification_data
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

report = EstimatorReport(
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)

result_train = report.metrics.summarize(data_source="train").frame()
result_test = report.metrics.summarize(data_source="test").frame()
result_all = report.metrics.summarize(data_source="all").frame()

assert result_all.columns.tolist() == [
"RandomForestClassifier (train)",
"RandomForestClassifier (test)",
]
assert_series_equal(
result_all["RandomForestClassifier (train)"],
result_train["RandomForestClassifier"],
check_names=False,
)
assert_series_equal(
result_all["RandomForestClassifier (test)"],
result_test["RandomForestClassifier"],
check_names=False,
)

# When favorability is requested, it is appended once as a single column.
result_all = report.metrics.summarize(
data_source="all", indicator_favorability=True
).frame()
assert result_all.columns.tolist() == [
"RandomForestClassifier (train)",
"RandomForestClassifier (test)",
"Favorability",
]


def test_interaction_cache_metrics(
linear_regression_multioutput_with_test,
):