13 changes: 9 additions & 4 deletions examples/model_evaluation/plot_estimator_report.py
@@ -291,8 +291,11 @@ def operational_decision_cost(y_true, y_pred, amount):
# that we can compute some additional metrics without having to recompute the
# the predictions.
report.metrics.summarize(
scoring=["precision", "recall", operational_decision_cost],
scoring_names=["Precision", "Recall", "Operational Decision Cost"],
scoring={
"Precision": "precision",
"Recall": "recall",
"Operational Decision Cost": operational_decision_cost,
},
scoring_kwargs={"amount": amount, "response_method": "predict"},
).frame()

@@ -309,8 +312,10 @@ def operational_decision_cost(y_true, y_pred, amount):
operational_decision_cost, response_method="predict", amount=amount
)
report.metrics.summarize(
scoring=[f1_scorer, operational_decision_cost_scorer],
scoring_names=["F1 Score", "Operational Decision Cost"],
scoring={
"F1 Score": f1_scorer,
"Operational Decision Cost": operational_decision_cost_scorer,
},
).frame()

# %%
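For readers skimming the diff, here is a self-contained sketch of the call pattern this example now uses: a single `scoring` dict whose keys become the metric names in the output. The dataset, estimator, cost function, and the `EstimatorReport(..., X_train=..., y_train=..., X_test=..., y_test=...)` construction are illustrative assumptions, not taken from this PR.

```python
# Illustrative sketch only: the data, model, cost function, and report
# construction are assumptions; the summarize(scoring={...}) call is the
# pattern adopted in the hunk above.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from skore import EstimatorReport

X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# Per-sample "amount" associated with the test set, consumed by the custom metric.
amount = np.random.default_rng(0).uniform(50, 500, size=y_test.size)


def operational_decision_cost(y_true, y_pred, amount):
    """Toy business cost: penalize errors proportionally to the amount at stake."""
    false_positive = (y_pred == 1) & (y_true == 0)
    false_negative = (y_pred == 0) & (y_true == 1)
    return -(0.1 * amount[false_positive].sum() + amount[false_negative].sum())


report = EstimatorReport(
    LogisticRegression(max_iter=1_000),
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
)

# The dict keys ("Precision", ...) become the metric names in the frame;
# the values are built-in metric names or custom callables.
report.metrics.summarize(
    scoring={
        "Precision": "precision",
        "Recall": "recall",
        "Operational Decision Cost": operational_decision_cost,
    },
    scoring_kwargs={"amount": amount, "response_method": "predict"},
).frame()
```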
9 changes: 2 additions & 7 deletions examples/use_cases/plot_employee_salaries.py
@@ -308,15 +308,10 @@ def periodic_spline_transformer(period, n_splines=None, degree=3):
# %%
from sklearn.metrics import get_scorer

scoring = ["r2", "rmse", get_scorer("neg_mean_absolute_error")]
scoring = {"R²": "r2", "RMSE": "rmse", "MAE": get_scorer("neg_mean_absolute_error")}
scoring_kwargs = {"response_method": "predict"}
scoring_names = ["R²", "RMSE", "MAE"]

comparator.metrics.summarize(
scoring=scoring,
scoring_kwargs=scoring_kwargs,
scoring_names=scoring_names,
).frame()
comparator.metrics.summarize(scoring=scoring, scoring_kwargs=scoring_kwargs).frame()

# %%
# Finally, we can even get a deeper understanding by analyzing each split in the
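A small follow-up note on the hunk above: the dict values can mix built-in metric names with scikit-learn scorers, and a plain list remains valid when the default names are fine. The sketch below assumes `comparator` is the `ComparisonReport` built earlier in this example.

```python
# `comparator` is assumed to be the ComparisonReport built earlier in
# plot_employee_salaries.py; get_scorer comes from scikit-learn.
from sklearn.metrics import get_scorer

# Dict form: keys override the display names, values mix strings and scorers.
comparator.metrics.summarize(
    scoring={"R²": "r2", "RMSE": "rmse", "MAE": get_scorer("neg_mean_absolute_error")},
    scoring_kwargs={"response_method": "predict"},
).frame()

# List form: still supported, the default metric names are kept.
comparator.metrics.summarize(
    scoring=["r2", "rmse", get_scorer("neg_mean_absolute_error")],
    scoring_kwargs={"response_method": "predict"},
).frame()
```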
16 changes: 5 additions & 11 deletions skore/src/skore/_sklearn/_comparison/metrics_accessor.py
@@ -26,7 +26,6 @@
Aggregate,
PositiveLabel,
Scoring,
ScoringName,
YPlotData,
)
from skore._utils._accessor import (
@@ -57,8 +56,7 @@ def summarize(
data_source: DataSource = "test",
X: ArrayLike | None = None,
y: ArrayLike | None = None,
scoring: Scoring | list[Scoring] | None = None,
scoring_names: ScoringName | list[ScoringName] | None = None,
scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
scoring_kwargs: dict[str, Any] | None = None,
pos_label: PositiveLabel | None = _DEFAULT,
indicator_favorability: bool = False,
@@ -84,7 +82,8 @@
New target on which to compute the metric. By default, we use the target
provided when creating the report.

scoring : str, callable, scorer or list of such instances, default=None
scoring : str, callable, scorer, or list of such instances or dict of such \
instances, default=None
The metrics to report. The possible values (whether or not in a list) are:

- if a string, either one of the built-in metrics or a scikit-learn scorer
@@ -101,10 +100,6 @@
the metric favorability will only be displayed if it is given explicitly
via `make_scorer`'s `greater_is_better` parameter.

scoring_names : str, None or list of such instances, default=None
Used to overwrite the default scoring names in the report. It should be of
the same length as the ``scoring`` parameter.

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.

@@ -164,7 +159,6 @@ class is set to the one provided when creating the report. If `None`,
scoring=scoring,
pos_label=pos_label,
scoring_kwargs=scoring_kwargs,
scoring_names=scoring_names,
indicator_favorability=indicator_favorability,
aggregate=aggregate,
)
@@ -1125,12 +1119,12 @@ def custom_metric(
response_method=response_method,
**kwargs,
)
scoring = {metric_name: scorer} if metric_name is not None else [scorer]
return self.summarize(
scoring=[scorer],
scoring=scoring,
data_source=data_source,
X=X,
y=y,
scoring_names=[metric_name] if metric_name is not None else None,
aggregate=aggregate,
).frame()

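The `custom_metric` change above boils down to a one-line dispatch: a named metric is wrapped in a single-entry dict so that `summarize` reads the display name from the key, and an unnamed one falls back to a one-element list. A standalone illustration (the MAE scorer here is an arbitrary example, not taken from this PR):

```python
# Standalone illustration of the dispatch used by custom_metric above; the
# metric choice (MAE) is arbitrary.
from sklearn.metrics import make_scorer, mean_absolute_error

scorer = make_scorer(
    mean_absolute_error, response_method="predict", greater_is_better=False
)


def build_scoring(scorer, metric_name=None):
    # Mirrors: {metric_name: scorer} if metric_name is not None else [scorer]
    return {metric_name: scorer} if metric_name is not None else [scorer]


build_scoring(scorer, "MAE")  # {"MAE": <scorer>}: summarize uses "MAE" as the row name
build_scoring(scorer)         # [<scorer>]: summarize falls back to the default name
```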
22 changes: 10 additions & 12 deletions skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py
@@ -26,7 +26,6 @@
Aggregate,
PositiveLabel,
Scoring,
ScoringName,
YPlotData,
)
from skore._utils._accessor import _check_estimator_report_has_method
@@ -55,8 +54,7 @@ def summarize(
data_source: DataSource = "test",
X: ArrayLike | None = None,
y: ArrayLike | None = None,
scoring: Scoring | list[Scoring] | None = None,
scoring_names: ScoringName | list[ScoringName] | None = None,
scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
scoring_kwargs: dict[str, Any] | None = None,
pos_label: PositiveLabel | None = _DEFAULT,
indicator_favorability: bool = False,
@@ -83,8 +81,9 @@
New target on which to compute the metric. By default, we use the target
provided when creating the report.

scoring : str, callable, scorer or list of such instances, default=None
The metrics to report. The possible values (whether or not in a list) are:
scoring : str, callable, scorer, or list of such instances or dict of such \
instances, default=None
The metrics to report. The possible values are:

- if a string, either one of the built-in metrics or a scikit-learn scorer
name. You can get the possible list of string using
@@ -99,10 +98,10 @@
scorers as provided by :func:`sklearn.metrics.make_scorer`. In this case,
the metric favorability will only be displayed if it is given explicitly
via `make_scorer`'s `greater_is_better` parameter.

scoring_names : str, None or list of such instances, default=None
Used to overwrite the default scoring names in the report. If a list,
it should be of the same length as the `scoring` parameter.
- if a dict, the keys are used as metric names and the values are the
scoring functions (strings, callables, or scorers as described above).
- if a list, each element can be any of the above types (strings, callables,
scorers).

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.
@@ -161,7 +160,6 @@ class is set to the one provided when creating the report. If `None`,
scoring=scoring,
pos_label=pos_label,
scoring_kwargs=scoring_kwargs,
scoring_names=scoring_names,
indicator_favorability=indicator_favorability,
)
if flat_index:
@@ -1035,13 +1033,13 @@ def custom_metric(
response_method=response_method,
**kwargs,
)
scoring = {metric_name: scorer} if metric_name is not None else [scorer]
return self.summarize(
scoring=[scorer],
scoring=scoring,
data_source=data_source,
X=X,
y=y,
aggregate=aggregate,
scoring_names=[metric_name] if metric_name is not None else None,
pos_label=pos_label,
).frame()

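Hedged usage sketch for the cross-validation variant above: with the new code path, the name passed to `custom_metric` travels through the dict form of `scoring`. The keyword `metric_function` for the metric callable is an assumption here (only `metric_name`, `response_method`, the data-source arguments, `aggregate`, and `pos_label` are visible in this diff), so adjust to the actual signature if it differs; also note that `custom_metric` already returns a DataFrame, so no extra `.frame()` call is needed.

```python
# Assumes `report` is a CrossValidationReport and that the metric callable is
# passed as `metric_function` (not confirmed by this diff).
import numpy as np


def misclassification_count(y_true, y_pred):
    # Simple custom metric: number of misclassified samples.
    return float(np.sum(np.asarray(y_true) != np.asarray(y_pred)))


frame = report.metrics.custom_metric(
    metric_function=misclassification_count,
    metric_name="Misclassification Count",  # becomes the dict key internally
    response_method="predict",
)
```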
47 changes: 16 additions & 31 deletions skore/src/skore/_sklearn/_estimator/metrics_accessor.py
@@ -30,7 +30,6 @@
_DEFAULT,
PositiveLabel,
Scoring,
ScoringName,
YPlotData,
)
from skore._utils._accessor import (
@@ -66,8 +65,7 @@ def summarize(
data_source: DataSource = "test",
X: ArrayLike | None = None,
y: ArrayLike | None = None,
scoring: Scoring | list[Scoring] | None = None,
scoring_names: ScoringName | list[ScoringName] | None = None,
scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
scoring_kwargs: dict[str, Any] | None = None,
pos_label: PositiveLabel | None = _DEFAULT,
indicator_favorability: bool = False,
@@ -92,8 +90,9 @@
New target on which to compute the metric. By default, we use the target
provided when creating the report.

scoring : str, callable, scorer or list of such instances, default=None
The metrics to report. The possible values (whether or not in a list) are:
scoring : str, callable, scorer, or list of such instances or dict of such \
instances, default=None
The metrics to report. The possible values are:

- if a string, either one of the built-in metrics or a scikit-learn scorer
name. You can get the possible list of string using
@@ -108,10 +107,10 @@
scorers as provided by :func:`sklearn.metrics.make_scorer`. In this case,
the metric favorability will only be displayed if it is given explicitly
via `make_scorer`'s `greater_is_better` parameter.

scoring_names : str, None or list of such instances, default=None
Used to overwrite the default scoring names in the report. It should be of
the same length as the `scoring` parameter.
- if a dict, the keys are used as metric names and the values are the
scoring functions (strings, callables, or scorers as described above).
- if a list, each element can be any of the above types (strings, callables,
scorers).

scoring_kwargs : dict, default=None
The keyword arguments to pass to the scoring functions.
@@ -164,12 +163,14 @@ class is set to the one provided when creating the report. If `None`,
if pos_label is _DEFAULT:
pos_label = self._parent.pos_label

if scoring is not None and not isinstance(scoring, list):
# Handle dictionary scoring
scoring_names = None
if isinstance(scoring, dict):
scoring_names = list(scoring.keys())
scoring = list(scoring.values())
elif scoring is not None and not isinstance(scoring, list):
scoring = [scoring]

if scoring_names is not None and not isinstance(scoring_names, list):
scoring_names = [scoring_names]

if data_source == "X_y":
# optimization of the hash computation to avoid recomputing it
# FIXME: we are still recomputing the hash for all the metrics that we
@@ -181,7 +182,6 @@
else:
data_source_hash = None

scoring_was_none = scoring is None
if scoring is None:
# Equivalent to _get_scorers_to_add
if self._parent._ml_task == "binary-classification":
@@ -196,23 +196,8 @@ class is set to the one provided when creating the report. If `None`,
scoring = ["_r2", "_rmse"]
scoring += ["_fit_time", "_predict_time"]

if scoring_names is not None and len(scoring_names) != len(scoring):
if scoring_was_none:
# we raise a better error message since we decide the default scores
raise ValueError(
"The `scoring_names` parameter should be of the same length as "
"the `scoring` parameter. In your case, `scoring` was set to None "
f"and you are using our default scores that are {len(scoring)}. "
f"The list is the following: {scoring}."
)
else:
raise ValueError(
"The `scoring_names` parameter should be of the same length as "
f"the `scoring` parameter. Got {len(scoring_names)} names for "
f"{len(scoring)} scoring functions."
)
elif scoring_names is None:
scoring_names = [None] * len(scoring)
if scoring_names is None:
scoring_names = [None] * len(scoring) # type: ignore

scores = []
favorability_indicator = []
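The new branching above is compact but easy to misread, so here is a standalone sketch of the same normalisation (simplified: it skips the default-metric selection performed when `scoring` is `None`):

```python
# Simplified, standalone mirror of the normalisation above; not the library
# code itself.
def normalize_scoring(scoring):
    scoring_names = None
    if isinstance(scoring, dict):
        # Dict: keys become display names, values become the actual scoring.
        scoring_names = list(scoring.keys())
        scoring = list(scoring.values())
    elif scoring is not None and not isinstance(scoring, list):
        # Single string/callable/scorer: wrap it in a list.
        scoring = [scoring]
    if scoring is not None and scoring_names is None:
        # No explicit names: downstream code keeps the default metric names.
        scoring_names = [None] * len(scoring)
    return scoring, scoring_names


normalize_scoring({"Precision": "precision", "Recall": "recall"})
# (['precision', 'recall'], ['Precision', 'Recall'])
normalize_scoring("r2")
# (['r2'], [None])
```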
50 changes: 45 additions & 5 deletions skore/tests/unit/displays/metrics_summary/test_cross_validation.py
@@ -313,16 +313,28 @@ def test_scoring_kwargs_multi_class(


@pytest.mark.parametrize(
"fixture_name, scoring_names, expected_index",
"fixture_name, scoring, expected_index",
[
(
"linear_regression_data",
["R2", "RMSE", "FIT_TIME", "PREDICT_TIME"],
{
"R2": "r2",
"RMSE": "rmse",
"FIT_TIME": "_fit_time",
"PREDICT_TIME": "_predict_time",
},
["R2", "RMSE", "FIT_TIME", "PREDICT_TIME"],
),
(
"forest_multiclass_classification_data",
["Precision", "Recall", "ROC AUC", "Log Loss", "Fit Time", "Predict Time"],
{
"Precision": "_precision",
"Recall": "_recall",
"ROC AUC": "_roc_auc",
"Log Loss": "_log_loss",
"Fit Time": "_fit_time",
"Predict Time": "_predict_time",
},
[
"Precision",
"Precision",
Expand All @@ -340,11 +352,11 @@ def test_scoring_kwargs_multi_class(
),
],
)
def test_overwrite_scoring_names(request, fixture_name, scoring_names, expected_index):
def test_overwrite_scoring_names(request, fixture_name, scoring, expected_index):
"""Test that we can overwrite the scoring names in `MetricsSummaryDisplay`."""
estimator, X, y = request.getfixturevalue(fixture_name)
report = CrossValidationReport(estimator, X, y, splitter=2)
result = report.metrics.summarize(scoring_names=scoring_names).frame()
result = report.metrics.summarize(scoring=scoring).frame()
assert result.shape == (len(expected_index), 2)

# Get level 0 names if MultiIndex, otherwise get column names
@@ -480,3 +492,31 @@ def test_indicator_favorability(forest_binary_classification_data, aggregate):
assert indicator["Brier score"].tolist() == ["(↘︎)"]
assert indicator["Fit time (s)"].tolist() == ["(↘︎)"]
assert indicator["Predict time (s)"].tolist() == ["(↘︎)"]


def test_overwrite_scoring_names_with_dict_cross_validation(
forest_multiclass_classification_data,
):
"""Test that we can overwrite the scoring names using dict scoring in
CrossValidationReport."""
estimator, X, y = forest_multiclass_classification_data
report = CrossValidationReport(estimator, X, y, splitter=2)

scoring_dict = {
"Custom Precision": "_precision",
"Custom Recall": "_recall",
"Custom ROC AUC": "_roc_auc",
}

result = report.metrics.summarize(scoring=scoring_dict).frame()

# Check that custom names are used
result_index = (
result.index.get_level_values(0).tolist()
if isinstance(result.index, pd.MultiIndex)
else result.index.tolist()
)

assert "Custom Precision" in result_index
assert "Custom Recall" in result_index
assert "Custom ROC AUC" in result_index