22 changes: 13 additions & 9 deletions skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py
@@ -26,7 +26,6 @@
Aggregate,
PositiveLabel,
Scoring,
ScoringName,
YPlotData,
)
from skore._utils._accessor import _check_estimator_report_has_method
@@ -55,8 +54,7 @@ def summarize(
data_source: DataSource = "test",
X: ArrayLike | None = None,
y: ArrayLike | None = None,
scoring: Scoring | list[Scoring] | None = None,
scoring_names: ScoringName | list[ScoringName] | None = None,
scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
scoring_kwargs: dict[str, Any] | None = None,
pos_label: PositiveLabel | None = _DEFAULT,
indicator_favorability: bool = False,
@@ -83,8 +81,8 @@
New target on which to compute the metric. By default, we use the target
provided when creating the report.

scoring : str, callable, scorer or list of such instances, default=None
The metrics to report. The possible values (whether or not in a list) are:
scoring : str, callable, scorer, dict, or list of such instances, default=None
The metrics to report. The possible values are:

- if a string, either one of the built-in metrics or a scikit-learn scorer
name. You can get the possible list of string using
@@ -99,10 +97,10 @@
scorers as provided by :func:`sklearn.metrics.make_scorer`. In this case,
the metric favorability will only be displayed if it is given explicitly
via `make_scorer`'s `greater_is_better` parameter.

scoring_names : str, None or list of such instances, default=None
Used to overwrite the default scoring names in the report. If a list,
it should be of the same length as the `scoring` parameter.
- if a dict, the keys are used as metric names and the values are the
scoring functions (strings, callables, or scorers as described above).
- if a list, each element can be any of the above types (strings, callables,
scorers).

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.
@@ -152,6 +150,12 @@ class is set to the one provided when creating the report. If `None`,
if pos_label == _DEFAULT:
pos_label = self._parent.pos_label

# Handle dictionary scoring
scoring_names = None
if isinstance(scoring, dict):
scoring_names = list(scoring.keys())
scoring = list(scoring.values())

results = self._compute_metric_scores(
report_metric_name="summarize",
data_source=data_source,
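For readers skimming the diff, here is a minimal usage sketch of the dict-based `scoring` added above (the toy data and estimator are illustrative; `CrossValidationReport`, `summarize`, and `.frame()` are the skore APIs exercised by the tests further down):

```python
# Hypothetical usage sketch, not part of this PR: the dict keys become the
# displayed metric names, the values are the metrics to compute.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from skore import CrossValidationReport

X, y = make_classification(n_samples=200, random_state=0)
report = CrossValidationReport(LogisticRegression(), X, y, splitter=3)

summary = report.metrics.summarize(
    scoring={"Custom accuracy": "accuracy", "Custom precision": "precision"},
).frame()
# The summary index now shows "Custom accuracy" and "Custom precision"
# instead of the default metric names.
```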
44 changes: 14 additions & 30 deletions skore/src/skore/_sklearn/_estimator/metrics_accessor.py
@@ -30,7 +30,6 @@
_DEFAULT,
PositiveLabel,
Scoring,
ScoringName,
YPlotData,
)
from skore._utils._accessor import (
@@ -66,8 +65,7 @@ def summarize(
data_source: DataSource = "test",
X: ArrayLike | None = None,
y: ArrayLike | None = None,
scoring: Scoring | list[Scoring] | None = None,
scoring_names: ScoringName | list[ScoringName] | None = None,
scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
scoring_kwargs: dict[str, Any] | None = None,
pos_label: PositiveLabel | None = _DEFAULT,
indicator_favorability: bool = False,
@@ -92,8 +90,8 @@
New target on which to compute the metric. By default, we use the target
provided when creating the report.

scoring : str, callable, scorer or list of such instances, default=None
The metrics to report. The possible values (whether or not in a list) are:
scoring : str, callable, scorer, dict, or list of such instances, default=None
The metrics to report. The possible values are:

- if a string, either one of the built-in metrics or a scikit-learn scorer
name. You can get the possible list of string using
@@ -108,10 +106,10 @@
scorers as provided by :func:`sklearn.metrics.make_scorer`. In this case,
the metric favorability will only be displayed if it is given explicitly
via `make_scorer`'s `greater_is_better` parameter.

scoring_names : str, None or list of such instances, default=None
Used to overwrite the default scoring names in the report. It should be of
the same length as the `scoring` parameter.
- if a dict, the keys are used as metric names and the values are the
scoring functions (strings, callables, or scorers as described above).
- if a list, each element can be any of the above types (strings, callables,
scorers).

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.
@@ -164,12 +162,14 @@ class is set to the one provided when creating the report. If `None`,
if pos_label is _DEFAULT:
pos_label = self._parent.pos_label

if scoring is not None and not isinstance(scoring, list):
# Handle dictionary scoring
scoring_names = None
if isinstance(scoring, dict):
scoring_names = list(scoring.keys())
scoring = list(scoring.values())
elif scoring is not None and not isinstance(scoring, list):
scoring = [scoring]

if scoring_names is not None and not isinstance(scoring_names, list):
scoring_names = [scoring_names]

if data_source == "X_y":
# optimization of the hash computation to avoid recomputing it
# FIXME: we are still recomputing the hash for all the metrics that we
@@ -181,7 +181,6 @@ class is set to the one provided when creating the report. If `None`,
else:
data_source_hash = None

scoring_was_none = scoring is None
if scoring is None:
# Equivalent to _get_scorers_to_add
if self._parent._ml_task == "binary-classification":
@@ -196,22 +195,7 @@ class is set to the one provided when creating the report. If `None`,
scoring = ["_r2", "_rmse"]
scoring += ["_fit_time", "_predict_time"]

if scoring_names is not None and len(scoring_names) != len(scoring):
if scoring_was_none:
# we raise a better error message since we decide the default scores
raise ValueError(
"The `scoring_names` parameter should be of the same length as "
"the `scoring` parameter. In your case, `scoring` was set to None "
f"and you are using our default scores that are {len(scoring)}. "
f"The list is the following: {scoring}."
)
else:
raise ValueError(
"The `scoring_names` parameter should be of the same length as "
f"the `scoring` parameter. Got {len(scoring_names)} names for "
f"{len(scoring)} scoring functions."
)
elif scoring_names is None:
if scoring_names is None:
scoring_names = [None] * len(scoring)

scores = []
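Since this file also drops the `scoring_names` parameter, a migration sketch may help; the "before" call shape is reconstructed from the removed docstring, and the estimator and data are illustrative:

```python
# Before this PR (reconstructed, no longer supported): two parallel arguments
# that had to be the same length.
# report.metrics.summarize(
#     scoring=["precision", "recall"],
#     scoring_names=["My precision", "My recall"],
# )

# After this PR: a single dict whose keys are the display names and whose
# values are the metrics.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from skore import EstimatorReport

X, y = make_classification(random_state=0)
report = EstimatorReport(RandomForestClassifier().fit(X, y), X_test=X, y_test=y)

summary = report.metrics.summarize(
    scoring={"My precision": "precision", "My recall": "recall"},
).frame()
```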
25 changes: 25 additions & 0 deletions skore/tests/unit/displays/metrics_summary/test_cross_validation.py
@@ -480,3 +480,28 @@ def test_indicator_favorability(forest_binary_classification_data, aggregate):
assert indicator["Brier score"].tolist() == ["(↘︎)"]
assert indicator["Fit time (s)"].tolist() == ["(↘︎)"]
assert indicator["Predict time (s)"].tolist() == ["(↘︎)"]


def test_overwrite_scoring_names_with_dict_cross_validation(
    forest_multiclass_classification_data,
):
    """Test that we can overwrite the scoring names using dict scoring in
    CrossValidationReport."""
estimator, X, y = forest_multiclass_classification_data
report = CrossValidationReport(estimator, X, y, splitter=2)

scoring_dict = {
"Custom Precision": "_precision",
"Custom Recall": "_recall",
"Custom ROC AUC": "_roc_auc"
}

result = report.metrics.summarize(scoring=scoring_dict).frame()

# Check that custom names are used
result_index = (
result.index.get_level_values(0).tolist()
if isinstance(result.index, pd.MultiIndex)
else result.index.tolist()
)

assert "Custom Precision" in result_index
assert "Custom Recall" in result_index
assert "Custom ROC AUC" in result_index
21 changes: 14 additions & 7 deletions skore/tests/unit/displays/metrics_summary/test_estimator.py
@@ -192,16 +192,23 @@ def test_scoring_kwargs(


@pytest.mark.parametrize(
"fixture_name, scoring_names, expected_columns",
"fixture_name, scoring_dict, expected_columns",
[
(
"linear_regression_with_test",
["R2", "RMSE", "FIT_TIME", "PREDICT_TIME"],
{"R2": "r2", "RMSE": "rmse", "FIT_TIME": "_fit_time", "PREDICT_TIME": "_predict_time"},
["R2", "RMSE", "FIT_TIME", "PREDICT_TIME"],
),
(
"forest_multiclass_classification_with_test",
["Precision", "Recall", "ROC AUC", "Log Loss", "Fit Time", "Predict Time"],
{
"Precision": "_precision",
"Recall": "_recall",
"ROC AUC": "_roc_auc",
"Log Loss": "_log_loss",
"Fit Time": "_fit_time",
"Predict Time": "_predict_time"
},
[
"Precision",
"Precision",
@@ -219,13 +226,13 @@
),
],
)
def test_overwrite_scoring_names(
request, fixture_name, scoring_names, expected_columns
def test_overwrite_scoring_names_with_dict(
request, fixture_name, scoring_dict, expected_columns
):
"""Test that we can overwrite the scoring names."""
"""Test that we can overwrite the scoring names using dict scoring."""
estimator, X_test, y_test = request.getfixturevalue(fixture_name)
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
result = report.metrics.summarize(scoring_names=scoring_names).frame()
result = report.metrics.summarize(scoring=scoring_dict).frame()
assert result.shape == (len(expected_columns), 1)

# Get level 0 names if MultiIndex, otherwise get column names
53 changes: 53 additions & 0 deletions skore/tests/unit/reports/estimator/metrics/test_numeric.py
@@ -3,12 +3,14 @@

import joblib
import numpy as np
import pandas as pd
import pytest
from sklearn.base import BaseEstimator
from sklearn.cluster import KMeans
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
get_scorer,
precision_score,
rand_score,
recall_score,
@@ -515,3 +517,54 @@ def test_roc_multiclass_requires_predict_proba(
report = EstimatorReport(classifier, X_test=X_test, y_test=y_test)
assert hasattr(report.metrics, "roc_auc")
report.metrics.roc_auc()


def test_summarize_scoring_dict(forest_binary_classification_with_test):
"""Test that scoring can be passed as a dictionary with custom names."""
estimator, X_test, y_test = forest_binary_classification_with_test
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)

# Test with dictionary scoring
scoring_dict = {
"Custom Accuracy": "accuracy",
"Custom Precision": "precision",
"Custom R2": get_scorer("neg_mean_absolute_error")
}

result = report.metrics.summarize(scoring=scoring_dict).frame()

# Check that custom names are used
assert "Custom Accuracy" in result.index
assert "Custom Precision" in result.index
assert "Custom R2" in result.index

# Verify the result structure
assert isinstance(result, pd.DataFrame)
assert len(result.index) >= 3 # At least our 3 custom metrics


def test_summarize_scoring_dict_with_callables(linear_regression_with_test):
"""Test that scoring dict works with callable functions."""
estimator, X_test, y_test = linear_regression_with_test
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)

def custom_metric(y_true, y_pred):
return np.mean(np.abs(y_true - y_pred))

scoring_dict = {
"R Squared": "r2",
"Custom MAE": custom_metric
}

result = report.metrics.summarize(
scoring=scoring_dict,
scoring_kwargs={"response_method": "predict"}
).frame()

# Check that custom names are used
assert "R Squared" in result.index
assert "Custom MAE" in result.index

# Verify the result structure
assert isinstance(result, pd.DataFrame)
assert len(result.index) == 2
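Finally, a sketch of the favorability note from the updated docstrings: a dict value can also be a scorer built with `make_scorer`, and the favorability indicator is only shown when `greater_is_better` is passed explicitly (again assuming skore's public API; the data and estimator are illustrative):

```python
# Sketch of dict scoring combined with make_scorer and indicator_favorability.
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import make_scorer, mean_absolute_error

from skore import EstimatorReport

X, y = make_regression(n_samples=200, random_state=0)
report = EstimatorReport(LinearRegression().fit(X, y), X_test=X, y_test=y)

summary = report.metrics.summarize(
    scoring={
        "R2": "r2",
        "Neg MAE": make_scorer(mean_absolute_error, greater_is_better=False),
    },
    indicator_favorability=True,
).frame()
# "Neg MAE" gets a favorability indicator because greater_is_better was passed
# explicitly to make_scorer, as described in the docstring above.
```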