Skip to content

Commit 3c05371

Browse files
authored
fix: Do not compute Brier score when predict_proba is not available (#1064)
closes #1050 As specified, do not try to compute Brier score with an estimator that does not provide probability estimate (i.e. does not have a `predict_proba` method). Added a non-regression test as well.
1 parent 29d3689 commit 3c05371

File tree

2 files changed

+80
-18
lines changed

2 files changed

+80
-18
lines changed

โ€Žskore/src/skore/sklearn/cross_validation/cross_validation_helpers.pyโ€Ž

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,27 @@ def _get_scorers_to_add(estimator, y) -> dict[str, Any]:
3333
),
3434
}
3535
if ml_task == "binary-classification":
36-
return {
37-
"roc_auc": "roc_auc",
38-
"brier_score_loss": metrics.make_scorer(
39-
metrics.brier_score_loss, response_method="predict_proba"
40-
),
36+
scorers_to_add = {
4137
"recall": "recall",
4238
"precision": "precision",
39+
"roc_auc": "roc_auc",
4340
}
44-
if ml_task == "multiclass-classification":
4541
if hasattr(estimator, "predict_proba"):
46-
return {
47-
"recall_weighted": "recall_weighted",
48-
"precision_weighted": "precision_weighted",
49-
"roc_auc_ovr_weighted": "roc_auc_ovr_weighted",
50-
"log_loss": metrics.make_scorer(
51-
metrics.log_loss, response_method="predict_proba"
52-
),
53-
}
54-
return {
42+
scorers_to_add["brier_score_loss"] = metrics.make_scorer(
43+
metrics.brier_score_loss, response_method="predict_proba"
44+
)
45+
return scorers_to_add
46+
if ml_task == "multiclass-classification":
47+
scorers_to_add = {
5548
"recall_weighted": "recall_weighted",
5649
"precision_weighted": "precision_weighted",
5750
}
51+
if hasattr(estimator, "predict_proba"):
52+
scorers_to_add["roc_auc_ovr_weighted"] = "roc_auc_ovr_weighted"
53+
scorers_to_add["log_loss"] = metrics.make_scorer(
54+
metrics.log_loss, response_method="predict_proba"
55+
)
56+
return scorers_to_add
5857
return {}
5958

6059

@@ -104,9 +103,11 @@ def _add_scorers(scorers, scorers_to_add):
104103

105104
internal_scorer = _MultimetricScorer(
106105
scorers={
107-
name: check_scoring(estimator=None, scoring=scoring)
108-
if isinstance(scoring, str)
109-
else scoring
106+
name: (
107+
check_scoring(estimator=None, scoring=scoring)
108+
if isinstance(scoring, str)
109+
else scoring
110+
)
110111
for name, scoring in scorers_to_add.items()
111112
}
112113
)

โ€Žskore/tests/unit/sklearn/test_cross_validate.pyโ€Ž

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
import pytest
2+
from sklearn.datasets import make_classification, make_regression
3+
from sklearn.linear_model import LinearRegression, LogisticRegression
4+
from sklearn.svm import SVC
15
from skore.sklearn.cross_validation import CrossValidationReporter
6+
from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add
27

38

49
def prepare_cv():
@@ -35,3 +40,59 @@ def test_cross_validate_return_estimator():
3540
assert "indices" in reporter.cv_results
3641
assert "estimator" in reporter._cv_results
3742
assert "indices" in reporter._cv_results
43+
44+
45+
@pytest.mark.parametrize(
46+
"estimator,dataset_func,dataset_kwargs,expected_keys",
47+
[
48+
pytest.param(
49+
LinearRegression(),
50+
make_regression,
51+
{"n_targets": 1},
52+
{"r2", "root_mean_squared_error"},
53+
id="regression",
54+
),
55+
pytest.param(
56+
LogisticRegression(),
57+
make_classification,
58+
{"n_classes": 2},
59+
{"recall", "precision", "roc_auc", "brier_score_loss"},
60+
id="binary_classification_with_proba",
61+
),
62+
pytest.param(
63+
SVC(probability=False),
64+
make_classification,
65+
{"n_classes": 2},
66+
{"recall", "precision", "roc_auc"},
67+
id="binary_classification_without_proba",
68+
),
69+
pytest.param(
70+
LogisticRegression(),
71+
make_classification,
72+
{"n_classes": 3, "n_clusters_per_class": 1},
73+
{
74+
"recall_weighted",
75+
"precision_weighted",
76+
"roc_auc_ovr_weighted",
77+
"log_loss",
78+
},
79+
id="multiclass_with_proba",
80+
),
81+
pytest.param(
82+
SVC(probability=False),
83+
make_classification,
84+
{"n_classes": 3, "n_clusters_per_class": 1},
85+
{"recall_weighted", "precision_weighted"},
86+
id="multiclass_without_proba",
87+
),
88+
],
89+
)
90+
def test_get_scorers_to_add(estimator, dataset_func, dataset_kwargs, expected_keys):
91+
"""Check that the scorers to add are correct.
92+
93+
Non-regression test for:
94+
https://github.com/probabl-ai/skore/issues/1050
95+
"""
96+
X, y = dataset_func(**dataset_kwargs)
97+
scorers = _get_scorers_to_add(estimator, y)
98+
assert set(scorers.keys()) == expected_keys

0 commit comments

Comments
ย (0)