|
8 | 8 | CrossValidationItem, |
9 | 9 | ItemTypeError, |
10 | 10 | _hash_numpy, |
| 11 | + _metric_favorability, |
11 | 12 | ) |
12 | 13 | from skore.sklearn.cross_validation import CrossValidationReporter |
13 | 14 | from skore.sklearn.cross_validation.cross_validation_reporter import ( |
@@ -86,9 +87,11 @@ def test_factory(self, mock_nowstr, reporter): |
86 | 87 | assert item.cv_results_serialized == {"test_score": [1, 2, 3]} |
87 | 88 | assert item.estimator_info == { |
88 | 89 | "name": reporter.estimator.__class__.__name__, |
89 | | - "params": {} |
90 | | - if isinstance(reporter.estimator, FakeEstimatorNoGetParams) |
91 | | - else {"alpha": {"value": "3", "default": True}}, |
| 90 | + "params": ( |
| 91 | + {} |
| 92 | + if isinstance(reporter.estimator, FakeEstimatorNoGetParams) |
| 93 | + else {"alpha": {"value": "3", "default": True}} |
| 94 | + ), |
92 | 95 | "module": "tests.unit.item.test_cross_validation_item", |
93 | 96 | } |
94 | 97 | assert item.X_info == { |
@@ -137,3 +140,53 @@ def test_get_serializable_dict(self, monkeypatch, mock_nowstr): |
137 | 140 | ], |
138 | 141 | } |
139 | 142 | ] |
| 143 | + |
@pytest.mark.parametrize(
    "metric,expected",
    [
        # greater_is_better metrics (exact matches)
        ("accuracy", "greater_is_better"),
        ("balanced_accuracy", "greater_is_better"),
        ("top_k_accuracy", "greater_is_better"),
        ("average_precision", "greater_is_better"),
        ("f1", "greater_is_better"),
        ("precision", "greater_is_better"),
        ("recall", "greater_is_better"),
        ("jaccard", "greater_is_better"),
        ("roc_auc", "greater_is_better"),
        ("r2", "greater_is_better"),
        # greater_is_better metrics (pattern matches)
        ("weighted_f1", "greater_is_better"),
        ("macro_precision", "greater_is_better"),
        ("micro_recall", "greater_is_better"),
        # greater_is_better by convention (_score suffix)
        ("custom_score", "greater_is_better"),
        ("validation_score", "greater_is_better"),
        # greater_is_better by convention (neg_ prefix)
        ("neg_mean_squared_error", "greater_is_better"),
        ("neg_log_loss", "greater_is_better"),
        # the same metrics without the neg_ prefix
        ("mean_squared_error", "lower_is_better"),
        ("log_loss", "lower_is_better"),
        # lower_is_better metrics (exact matches)
        ("fit_time", "lower_is_better"),
        ("score_time", "lower_is_better"),
        # lower_is_better by convention (suffixes);
        # "mean_squared_error" is already covered above, so it is not
        # repeated here to avoid collecting the same case twice
        ("mean_absolute_error", "lower_is_better"),
        ("binary_crossentropy_loss", "lower_is_better"),
        ("hinge_loss", "lower_is_better"),
        ("entropy_deviance", "lower_is_better"),
        # unknown metrics
        ("custom_metric", "unknown"),
        ("undefined", "unknown"),
        ("", "unknown"),
    ],
)
def test_metric_favorability(self, metric, expected):
    """Test the _metric_favorability function with various metric names.

    Non-regression test for:
    https://github.com/probabl-ai/skore/issues/1061

    Parameters
    ----------
    metric : str
        Metric name passed to ``_metric_favorability``.
    expected : str
        Value the call must return for ``metric``; the cases above use
        ``"greater_is_better"``, ``"lower_is_better"`` and ``"unknown"``.
    """
    assert _metric_favorability(metric) == expected
0 commit comments