Commit 3c759f9

fix: Leverage get_scorer and the fact it is a _BaseScorer (#1723)
closes #1715 closes #1714 closes #1716 closes #1686

This PR:
- fixes the way we handle scikit-learn scorer names
- updates the API documentation to explain the difference between built-in metrics and scikit-learn scorer names
- fixes the existing tests
- adds more tests for metrics that take additional parameters

---------
Co-authored-by: Auguste Baum <[email protected]>
1 parent 7652cc1 commit 3c759f9
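
In practice, the change lets scikit-learn scorer names be passed to `report_metrics` alongside skore's built-in metric names. A minimal sketch of the new behaviour, assuming a fitted classifier and the top-level `skore.EstimatorReport` import used in the tests below; the synthetic dataset is only illustrative:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from skore import EstimatorReport

X, y = make_classification(random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
classifier = RandomForestClassifier(random_state=42).fit(X_train, y_train)

report = EstimatorReport(classifier, X_test=X_test, y_test=y_test)

# Mix a built-in metric name with scikit-learn scorer names; the scorer names
# are resolved through sklearn.metrics.get_scorer and displayed with
# title-cased labels ("Log Loss", "ROC AUC", ...).
result = report.metrics.report_metrics(
    scoring=["accuracy", "neg_log_loss", "roc_auc"],
    indicator_favorability=True,
)
print(result)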

File tree

5 files changed, +111 -117 lines


skore/src/skore/sklearn/_comparison/metrics_accessor.py

Lines changed: 12 additions & 7 deletions
@@ -83,13 +83,18 @@ def report_metrics(
             provided when creating the report.
 
         scoring : list of str, callable, or scorer, default=None
-            The metrics to report. You can get the possible list of strings by calling
-            `report.metrics.help()`. When passing a callable, it should take as
-            arguments ``y_true``, ``y_pred`` as the two first arguments. Additional
-            arguments can be passed as keyword arguments and will be forwarded with
-            `scoring_kwargs`. If the callable API is too restrictive (e.g. need to pass
-            same parameter name with different values), you can use scikit-learn scorers
-            as provided by :func:`sklearn.metrics.make_scorer`.
+            The metrics to report. The possible values in the list are:
+
+            - if a string, either one of the built-in metrics or a scikit-learn scorer
+              name. You can get the possible list of string using
+              `report.metrics.help()` or :func:`sklearn.metrics.get_scorer_names` for
+              the built-in metrics or the scikit-learn scorers, respectively.
+            - if a callable, it should take as arguments `y_true`, `y_pred` as the two
+              first arguments. Additional arguments can be passed as keyword arguments
+              and will be forwarded with `scoring_kwargs`.
+            - if the callable API is too restrictive (e.g. need to pass
+              same parameter name with different values), you can use scikit-learn
+              scorers as provided by :func:`sklearn.metrics.make_scorer`.
 
         scoring_names : list of str, default=None
             Used to overwrite the default scoring names in the report. It should be of
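
The three forms enumerated by the updated docstring can be mixed freely. A short sketch of the callable and `make_scorer` variants, reusing the `report` object built in the earlier sketch; the `weighted_fbeta` helper is hypothetical and only illustrates how extra keyword arguments travel through `scoring_kwargs`:

from sklearn.metrics import fbeta_score, make_scorer

# A plain callable taking y_true and y_pred; extra keyword arguments are
# forwarded via `scoring_kwargs`.
def weighted_fbeta(y_true, y_pred, beta=2.0):
    return fbeta_score(y_true, y_pred, beta=beta)

report.metrics.report_metrics(scoring=[weighted_fbeta], scoring_kwargs={"beta": 0.5})

# When the callable API is too restrictive (e.g. the same parameter name with
# different values), wrap the metric with make_scorer instead.
report.metrics.report_metrics(
    scoring=[
        make_scorer(fbeta_score, beta=0.5, response_method="predict"),
        make_scorer(fbeta_score, beta=2.0, response_method="predict"),
    ]
)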

skore/src/skore/sklearn/_cross_validation/metrics_accessor.py

Lines changed: 12 additions & 7 deletions
@@ -83,13 +83,18 @@ def report_metrics(
             provided when creating the report.
 
         scoring : list of str, callable, or scorer, default=None
-            The metrics to report. You can get the possible list of string by calling
-            `report.metrics.help()`. When passing a callable, it should take as
-            arguments `y_true`, `y_pred` as the two first arguments. Additional
-            arguments can be passed as keyword arguments and will be forwarded with
-            `scoring_kwargs`. If the callable API is too restrictive (e.g. need to pass
-            same parameter name with different values), you can use scikit-learn scorers
-            as provided by :func:`sklearn.metrics.make_scorer`.
+            The metrics to report. The possible values in the list are:
+
+            - if a string, either one of the built-in metrics or a scikit-learn scorer
+              name. You can get the possible list of string using
+              `report.metrics.help()` or :func:`sklearn.metrics.get_scorer_names` for
+              the built-in metrics or the scikit-learn scorers, respectively.
+            - if a callable, it should take as arguments `y_true`, `y_pred` as the two
+              first arguments. Additional arguments can be passed as keyword arguments
+              and will be forwarded with `scoring_kwargs`.
+            - if the callable API is too restrictive (e.g. need to pass
+              same parameter name with different values), you can use scikit-learn
+              scorers as provided by :func:`sklearn.metrics.make_scorer`.
 
         scoring_names : list of str, default=None
             Used to overwrite the default scoring names in the report. It should be of
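
The same docstring applies to the cross-validation report. To discover which strings are valid, `report.metrics.help()` lists skore's built-in metric names, while `sklearn.metrics.get_scorer_names` (referenced above) lists the scikit-learn scorer names; a quick way to check a candidate string, shown as a hedged sketch:

from sklearn.metrics import get_scorer_names

scorer_names = set(get_scorer_names())
print("neg_log_loss" in scorer_names)  # True: usable as a scoring string
print("log_loss" in scorer_names)      # False: not a registered scorer name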

skore/src/skore/sklearn/_estimator/metrics_accessor.py

Lines changed: 41 additions & 60 deletions
@@ -90,13 +90,18 @@ def report_metrics(
             provided when creating the report.
 
         scoring : list of str, callable, or scorer, default=None
-            The metrics to report. You can get the possible list of string by calling
-            `report.metrics.help()`. When passing a callable, it should take as
-            arguments `y_true`, `y_pred` as the two first arguments. Additional
-            arguments can be passed as keyword arguments and will be forwarded with
-            `scoring_kwargs`. If the callable API is too restrictive (e.g. need to pass
-            same parameter name with different values), you can use scikit-learn scorers
-            as provided by :func:`sklearn.metrics.make_scorer`.
+            The metrics to report. The possible values in the list are:
+
+            - if a string, either one of the built-in metrics or a scikit-learn scorer
+              name. You can get the possible list of string using
+              `report.metrics.help()` or :func:`sklearn.metrics.get_scorer_names` for
+              the built-in metrics or the scikit-learn scorers, respectively.
+            - if a callable, it should take as arguments `y_true`, `y_pred` as the two
+              first arguments. Additional arguments can be passed as keyword arguments
+              and will be forwarded with `scoring_kwargs`.
+            - if the callable API is too restrictive (e.g. need to pass
+              same parameter name with different values), you can use scikit-learn
+              scorers as provided by :func:`sklearn.metrics.make_scorer`.
 
         scoring_names : list of str, default=None
             Used to overwrite the default scoring names in the report. It should be of
@@ -138,14 +143,13 @@ def report_metrics(
         Recall              0.93... (↗︎)
         ROC AUC             0.99... (↗︎)
         Brier score         0.03... (↘︎)
-
         >>> # Using scikit-learn metrics
         >>> report.metrics.report_metrics(
-        scoring=["neg_log_loss"],
-        indicator_favorability=True)
-                            LogisticRegression Favorability
-        Metric
-        Negative Log Loss   -0.10... (↘︎)
+        ...     scoring=["f1"], pos_label=1, indicator_favorability=True
+        ... )
+                                  LogisticRegression Favorability
+        Metric   Label / Average
+        F1 Score 1                0.96... (↗︎)
         """
         if data_source == "X_y":
             # optimization of the hash computation to avoid recomputing it
@@ -194,6 +198,28 @@ def report_metrics(
         scores = []
         favorability_indicator = []
         for metric_name, metric in zip(scoring_names, scoring):
+            if isinstance(metric, str) and not (
+                (metric.startswith("_") and metric[1:] in self._SCORE_OR_LOSS_INFO)
+                or metric in self._SCORE_OR_LOSS_INFO
+            ):
+                try:
+                    metric = metrics.get_scorer(metric)
+                except ValueError as err:
+                    raise ValueError(
+                        f"Invalid metric: {metric!r}. "
+                        f"Please use a valid metric from the "
+                        f"list of supported metrics: "
+                        f"{list(self._SCORE_OR_LOSS_INFO.keys())} "
+                        "or a valid scikit-learn scoring string."
+                    ) from err
+                if scoring_kwargs is not None:
+                    raise ValueError(
+                        "The `scoring_kwargs` parameter is not supported when "
+                        "`scoring` is a scikit-learn scorer name. Use the function "
+                        "`sklearn.metrics.make_scorer` to create a scorer with "
+                        "additional parameters."
+                    )
+
             # NOTE: we have to check specifically for `_BaseScorer` first because this
             # is also a callable but it has a special private API that we can leverage
             if isinstance(metric, _BaseScorer):
@@ -221,8 +247,8 @@
                 elif pos_label is not None:
                     metrics_kwargs["pos_label"] = pos_label
                 if metric_name is None:
-                    metric_name = metric._score_func.__name__
-                metric_favorability = "↗︎" if metric._sign == 1 else "↘︎"
+                    metric_name = metric._score_func.__name__.replace("_", " ").title()
+                metric_favorability = "(↗︎)" if metric._sign == 1 else "(↘︎)"
                 favorability_indicator.append(metric_favorability)
             elif isinstance(metric, str) or callable(metric):
                 if isinstance(metric, str):
@@ -248,51 +274,6 @@ def report_metrics(
                     if metric_name is None:
                         metric_name = f"{self._SCORE_OR_LOSS_INFO[metric]['name']}"
                     metric_favorability = self._SCORE_OR_LOSS_INFO[metric]["icon"]
-
-                    # Handle scikit-learn metrics by trying get_scorer
-                    else:
-                        from sklearn.metrics import get_scorer
-
-                        try:
-                            scorer = get_scorer(metric)
-                            metric_function = scorer._score_func
-                            response_method = scorer._response_method
-
-                            display_name = metric
-                            if metric.startswith("neg_"):
-                                display_name = metric[4:].replace("_", " ")
-                                metric_fn = partial(
-                                    self._custom_metric,
-                                    metric_function=metric_function,
-                                    response_method=response_method,
-                                )
-                                metrics_kwargs = {**scorer._kwargs}
-                                metrics_kwargs["data_source_hash"] = data_source_hash
-                                metric_favorability = "↘︎"
-                                favorability_indicator.append(metric_favorability)
-
-                            if metric_name is None:
-                                metric_name = display_name.title()
-
-                            metric_fn = partial(
-                                self._custom_metric,
-                                metric_function=metric_function,
-                                response_method=response_method,
-                            )
-                            metrics_kwargs = {**scorer._kwargs}
-                            metrics_kwargs["data_source_hash"] = data_source_hash
-                            metric_favorability = (
-                                "(↘︎)" if metric.startswith("neg_") else "(↗︎)"
-                            )
-                        except ValueError as err:
-                            raise ValueError(
-                                f"Invalid metric: {metric!r}. "
-                                f"Please use a valid metric from the "
-                                f"list of supported metrics: "
-                                f"{list(self._SCORE_OR_LOSS_INFO.keys())} "
-                                "or a valid scikit-learn scoring string."
-                            ) from err
-                        favorability_indicator.append(metric_favorability)
             else:
                 # Handle callable metrics
                 metric_fn = partial(self._custom_metric, metric_function=metric)
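
The new branch above routes any string that is not a built-in metric through `sklearn.metrics.get_scorer`, then reuses the returned scorer's private attributes (the `_BaseScorer` API the commit title refers to) to build the display name and favorability arrow. A standalone sketch of that derivation outside skore; note that `_score_func` and `_sign` are private scorer attributes:

from sklearn import metrics

# Resolve a scorer name, then derive the label and arrow the same way the
# accessor does.
scorer = metrics.get_scorer("neg_log_loss")
display_name = scorer._score_func.__name__.replace("_", " ").title()
favorability = "(↗︎)" if scorer._sign == 1 else "(↘︎)"
print(display_name, favorability)  # Log Loss (↘︎)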

skore/tests/unit/sklearn/cross_validation/test_cross_validation.py

Lines changed: 2 additions & 2 deletions
@@ -885,7 +885,7 @@ def test_cross_validation_report_custom_metric(binary_classification_data):
         response_method="predict",
     )
     assert result.shape == (1, 2)
-    assert result.index == ["accuracy_score"]
+    assert result.index == ["Accuracy Score"]
 
 
 @pytest.mark.parametrize(
@@ -936,7 +936,7 @@ def predict(self, X):
         response_method="predict",
     )
     assert result.shape == (1, 2)
-    assert result.index == ["accuracy_score"]
+    assert result.index == ["Accuracy Score"]
 
 
 def test_cross_validation_report_brier_score_requires_probabilities():
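
The expected index changes because custom metric names are now pretty-printed from the underlying function name, using the same `replace("_", " ").title()` transformation shown in the estimator accessor hunk above; for example:

from sklearn.metrics import accuracy_score

# "accuracy_score" becomes "Accuracy Score" in the report index.
assert accuracy_score.__name__.replace("_", " ").title() == "Accuracy Score"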

skore/tests/unit/sklearn/estimator/test_estimator.py

Lines changed: 44 additions & 41 deletions
@@ -15,7 +15,6 @@
 from sklearn.metrics import (
     accuracy_score,
     f1_score,
-    get_scorer,
     make_scorer,
     median_absolute_error,
     r2_score,
@@ -1349,49 +1348,30 @@ def test_estimator_report_average_return_float(binary_classification_data):
 def test_estimator_report_metric_with_neg_metrics(binary_classification_data):
     """Check that scikit-learn metrics with 'neg_' prefix are handled correctly."""
     classifier, X_test, y_test = binary_classification_data
-    report = EstimatorReport(
-        classifier,
-        X_test=X_test,
-        y_test=y_test,
-    )
-
-    # Use scikit-learn's get_scorer to handle neg_log_loss
-    scorer = get_scorer("neg_log_loss")
-    result = report.metrics.report_metrics(scoring=[scorer])
+    report = EstimatorReport(classifier, X_test=X_test, y_test=y_test)
 
-    # Check that the metric name is displayed properly (as 'log_loss')
-    assert "log_loss" in result.index
-
-    # Get the neg_log_loss score directly - use the fitted model from the report
-    neg_log_loss_value = get_scorer("neg_log_loss")(report.estimator_, X_test, y_test)
-
-    # Check that the reported log_loss matches the absolute value of neg_log_loss
-    log_loss_value = result.loc["log_loss", classifier.__class__.__name__]
-    assert np.isclose(log_loss_value, abs(neg_log_loss_value))
+    result = report.metrics.report_metrics(scoring=["neg_log_loss"])
+    assert "Log Loss" in result.index
+    assert result.loc["Log Loss", "RandomForestClassifier"] == pytest.approx(
+        report.metrics.log_loss()
+    )
 
 
 def test_estimator_report_with_sklearn_scoring_strings(binary_classification_data):
     """Test that scikit-learn metric strings can be passed to report_metrics."""
     classifier, X_test, y_test = binary_classification_data
-    class_report = EstimatorReport(
-        classifier,
-        X_test=X_test,
-        y_test=y_test,
-    )
+    class_report = EstimatorReport(classifier, X_test=X_test, y_test=y_test)
 
-    # Test single scikit-learn metric string
     result = class_report.metrics.report_metrics(scoring=["neg_log_loss"])
     assert "Log Loss" in result.index.get_level_values(0)
 
-    # Test with multiple scikit-learn metrics
     result_multi = class_report.metrics.report_metrics(
         scoring=["accuracy", "neg_log_loss", "roc_auc"], indicator_favorability=True
     )
     assert "Accuracy" in result_multi.index.get_level_values(0)
     assert "Log Loss" in result_multi.index.get_level_values(0)
     assert "ROC AUC" in result_multi.index.get_level_values(0)
 
-    # Test favorability indicators
     favorability = result_multi.loc["Accuracy"]["Favorability"]
     assert favorability == "(↗︎)"
     favorability = result_multi.loc["Log Loss"]["Favorability"]
@@ -1401,13 +1381,8 @@ def test_estimator_report_with_sklearn_scoring_strings(binary_classification_dat
 def test_estimator_report_with_sklearn_scoring_strings_regression(regression_data):
     """Test scikit-learn regression metric strings in report_metrics."""
     regressor, X_test, y_test = regression_data
-    reg_report = EstimatorReport(
-        regressor,
-        X_test=X_test,
-        y_test=y_test,
-    )
+    reg_report = EstimatorReport(regressor, X_test=X_test, y_test=y_test)
 
-    # Test regression metrics
     reg_result = reg_report.metrics.report_metrics(
         scoring=["neg_mean_squared_error", "neg_mean_absolute_error", "r2"],
         indicator_favorability=True,
@@ -1417,21 +1392,15 @@ def test_estimator_report_with_sklearn_scoring_strings_regression(regression_dat
     assert "Mean Absolute Error" in reg_result.index.get_level_values(0)
     assert "R²" in reg_result.index.get_level_values(0)
 
-    # Check favorability
     assert reg_result.loc["Mean Squared Error"]["Favorability"] == "(↘︎)"
     assert reg_result.loc["R²"]["Favorability"] == "(↗︎)"
 
 
 def test_estimator_report_with_scoring_strings_regression(regression_data):
     """Test scikit-learn regression metric strings in report_metrics."""
     regressor, X_test, y_test = regression_data
-    reg_report = EstimatorReport(
-        regressor,
-        X_test=X_test,
-        y_test=y_test,
-    )
+    reg_report = EstimatorReport(regressor, X_test=X_test, y_test=y_test)
 
-    # Test regression metrics
     reg_result = reg_report.metrics.report_metrics(
         scoring=["neg_mean_squared_error", "neg_mean_absolute_error", "r2"],
         indicator_favorability=True,
@@ -1441,6 +1410,40 @@ def test_estimator_report_with_scoring_strings_regression(regression_data):
     assert "Mean Absolute Error" in reg_result.index.get_level_values(0)
     assert "R²" in reg_result.index.get_level_values(0)
 
-    # Check favorability
     assert reg_result.loc["Mean Squared Error"]["Favorability"] == "(↘︎)"
     assert reg_result.loc["R²"]["Favorability"] == "(↗︎)"
+
+
+def test_estimator_report_sklearn_scorer_names_pos_label(binary_classification_data):
+    """Check that `pos_label` is dispatched with scikit-learn scorer names."""
+    classifier, X_test, y_test = binary_classification_data
+    report = EstimatorReport(classifier, X_test=X_test, y_test=y_test)
+
+    result = report.metrics.report_metrics(scoring=["f1"], pos_label=0)
+    assert "F1 Score" in result.index.get_level_values(0)
+    assert 0 in result.index.get_level_values(1)
+    f1_scorer = make_scorer(
+        f1_score, response_method="predict", average="binary", pos_label=0
+    )
+    assert result.loc[("F1 Score", 0), "RandomForestClassifier"] == pytest.approx(
+        f1_scorer(classifier, X_test, y_test)
+    )
+
+
+def test_estimator_report_sklearn_scorer_names_scoring_kwargs(
+    binary_classification_data,
+):
+    """Check that `scoring_kwargs` is not supported when `scoring` is a scikit-learn
+    scorer name.
+    """
+    classifier, X_test, y_test = binary_classification_data
+    report = EstimatorReport(classifier, X_test=X_test, y_test=y_test)
+
+    err_msg = (
+        "The `scoring_kwargs` parameter is not supported when `scoring` is a "
+        "scikit-learn scorer name."
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        report.metrics.report_metrics(
+            scoring=["f1"], scoring_kwargs={"average": "macro"}
+        )
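
The two new tests pin down the scorer-name behaviour: `pos_label` is forwarded to scorer names such as "f1", while `scoring_kwargs` is rejected for them. A small sketch of the workaround suggested by the error message, assuming `report` is an `EstimatorReport` built on a fitted classifier as in the tests above:

from sklearn.metrics import f1_score, make_scorer

# Parameters such as `average` cannot go through `scoring_kwargs` for a scorer
# name, so build an explicit scorer with make_scorer instead.
macro_f1 = make_scorer(f1_score, response_method="predict", average="macro")
report.metrics.report_metrics(scoring=[macro_f1])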
