
Commit 776684d

fix: Use a more flexible approach to check favorability of metrics (probabl-ai#1063)
Closes probabl-ai#1061

![image](https://github.com/user-attachments/assets/98d002c1-874d-4de2-bb26-5fc16838b2f1)

We now use a regular expression to check the score names, which is more flexible. We also take care to check the `neg_` prefix first: a `neg_` score is a negated loss, so it follows the "higher is better" convention.

---------

Co-authored-by: Auguste Baum <[email protected]>
1 parent f8c0456 commit 776684d
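
The idea behind the fix, as a minimal standalone sketch (the real change is the `_metric_favorability` diff below; the names `GREATER_IS_BETTER` and `is_greater_is_better` here are illustrative only):

```python
import re

# Hypothetical, trimmed-down version of the commit's approach: substring
# matching plus scikit-learn naming conventions decide whether a metric
# follows the "higher is better" convention.
GREATER_IS_BETTER = ("accuracy", "f1", "precision", "recall", "roc_auc", "r2")


def is_greater_is_better(metric: str) -> bool:
    # A known "higher is better" name appearing anywhere in the metric name
    # (e.g. "weighted_f1", "test_roc_auc") counts as a match.
    if any(re.search(re.escape(pattern), metric) for pattern in GREATER_IS_BETTER):
        return True
    # scikit-learn conventions: "*_score" scorers are higher-is-better, and
    # "neg_*" scorers are negated losses, so higher is also better for them.
    # The "neg_" check has to run before any "_error"/"_loss" suffix check,
    # otherwise "neg_mean_squared_error" would be classified as a loss.
    return metric.endswith("_score") or metric.startswith("neg_")


print(is_greater_is_better("neg_mean_squared_error"))  # True
print(is_greater_is_better("weighted_f1"))             # True
print(is_greater_is_better("mean_squared_error"))      # False
```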

File tree

2 files changed: +76 −13 lines


skore/src/skore/item/cross_validation_item.py

Lines changed: 20 additions & 10 deletions
@@ -70,19 +70,29 @@ def _metric_favorability(
     metric: str,
 ) -> Literal["greater_is_better", "lower_is_better", "unknown"]:
     greater_is_better_metrics = (
-        "r2",
-        "test_r2",
-        "roc_auc",
-        "recall",
-        "recall_weighted",
+        "accuracy",
+        "balanced_accuracy",
+        "top_k_accuracy",
+        "average_precision",
+        "f1",
         "precision",
-        "precision_weighted",
-        "roc_auc_ovr_weighted",
+        "recall",
+        "jaccard",
+        "roc_auc",
+        "r2",
     )
-    lower_is_better_metrics = ("fit_time", "score_time")
-
-    if metric.endswith("_score") or metric in greater_is_better_metrics:
+    any_match_greater_is_better = any(
+        re.search(re.escape(pattern), metric) for pattern in greater_is_better_metrics
+    )
+    if (
+        any_match_greater_is_better
+        # other scikit-learn conventions
+        or metric.endswith("_score")  # score: higher is better
+        or metric.startswith("neg_")  # negative loss: negative of lower is better
+    ):
         return "greater_is_better"
+
+    lower_is_better_metrics = ("fit_time", "score_time")
     if (
         metric.endswith("_error")
         or metric.endswith("_loss")
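
For context, a usage sketch of the helper touched above; the import path is assumed from the file location shown in this diff and may differ:

```python
# Assumed import path, based on skore/src/skore/item/cross_validation_item.py.
from skore.item.cross_validation_item import _metric_favorability

# Expected labels, taken from the parametrized test added below.
print(_metric_favorability("neg_log_loss"))   # "greater_is_better" (neg_ prefix wins)
print(_metric_favorability("log_loss"))       # "lower_is_better"   (_loss suffix)
print(_metric_favorability("weighted_f1"))    # "greater_is_better" (regex match on "f1")
print(_metric_favorability("custom_metric"))  # "unknown"
```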

skore/tests/unit/item/test_cross_validation_item.py

Lines changed: 56 additions & 3 deletions
@@ -8,6 +8,7 @@
     CrossValidationItem,
     ItemTypeError,
     _hash_numpy,
+    _metric_favorability,
 )
 from skore.sklearn.cross_validation import CrossValidationReporter
 from skore.sklearn.cross_validation.cross_validation_reporter import (
@@ -86,9 +87,11 @@ def test_factory(self, mock_nowstr, reporter):
         assert item.cv_results_serialized == {"test_score": [1, 2, 3]}
         assert item.estimator_info == {
             "name": reporter.estimator.__class__.__name__,
-            "params": {}
-            if isinstance(reporter.estimator, FakeEstimatorNoGetParams)
-            else {"alpha": {"value": "3", "default": True}},
+            "params": (
+                {}
+                if isinstance(reporter.estimator, FakeEstimatorNoGetParams)
+                else {"alpha": {"value": "3", "default": True}}
+            ),
             "module": "tests.unit.item.test_cross_validation_item",
         }
         assert item.X_info == {
@@ -137,3 +140,53 @@ def test_get_serializable_dict(self, monkeypatch, mock_nowstr):
             ],
         }
     ]
+
+    @pytest.mark.parametrize(
+        "metric,expected",
+        [
+            # greater_is_better metrics (exact matches)
+            ("accuracy", "greater_is_better"),
+            ("balanced_accuracy", "greater_is_better"),
+            ("top_k_accuracy", "greater_is_better"),
+            ("average_precision", "greater_is_better"),
+            ("f1", "greater_is_better"),
+            ("precision", "greater_is_better"),
+            ("recall", "greater_is_better"),
+            ("jaccard", "greater_is_better"),
+            ("roc_auc", "greater_is_better"),
+            ("r2", "greater_is_better"),
+            # greater_is_better metrics (pattern matches)
+            ("weighted_f1", "greater_is_better"),
+            ("macro_precision", "greater_is_better"),
+            ("micro_recall", "greater_is_better"),
+            # greater_is_better by convention (_score suffix)
+            ("custom_score", "greater_is_better"),
+            ("validation_score", "greater_is_better"),
+            # greater_is_better by convention (neg_ prefix)
+            ("neg_mean_squared_error", "greater_is_better"),
+            ("neg_log_loss", "greater_is_better"),
+            # the same one but without the neg_ prefix
+            ("mean_squared_error", "lower_is_better"),
+            ("log_loss", "lower_is_better"),
+            # lower_is_better metrics (exact matches)
+            ("fit_time", "lower_is_better"),
+            ("score_time", "lower_is_better"),
+            # lower_is_better by convention (suffixes)
+            ("mean_squared_error", "lower_is_better"),
+            ("mean_absolute_error", "lower_is_better"),
+            ("binary_crossentropy_loss", "lower_is_better"),
+            ("hinge_loss", "lower_is_better"),
+            ("entropy_deviance", "lower_is_better"),
+            # unknown metrics
+            ("custom_metric", "unknown"),
+            ("undefined", "unknown"),
+            ("", "unknown"),
+        ],
+    )
+    def test_metric_favorability(self, metric, expected):
+        """Test the _metric_favorability function with various metric names.
+
+        Non-regression test for:
+        https://github.com/probabl-ai/skore/issues/1061
+        """
+        assert _metric_favorability(metric) == expected
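
To exercise just the new non-regression test, pytest can be invoked on the file shown above (the path comes from this diff; running via `pytest.main` is just one option, the command-line equivalent works the same way):

```python
# Run only the new parametrized test from Python; equivalent to passing the
# same arguments to pytest on the command line.
import pytest

pytest.main([
    "skore/tests/unit/item/test_cross_validation_item.py",
    "-k", "test_metric_favorability",
    "-q",
])
```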
