From 5f88492ebec734d2a58e5c5eefd7676835548976 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 8 Jan 2025 22:56:36 +0100 Subject: [PATCH 1/2] fix: Make _add_scorers compatible with sklearn 1.4 --- .../sklearn/cross_validation/cross_validation_helpers.py | 9 +++++++-- .../cross_validation/cross_validation_reporter.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py b/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py index 97a46d61ea..7bead8bb9e 100644 --- a/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py +++ b/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py @@ -58,7 +58,7 @@ def _get_scorers_to_add(estimator, y) -> dict[str, Any]: return {} -def _add_scorers(scorers, scorers_to_add): +def _add_scorers(scorers, scorers_to_add, estimator): """Expand `scorers` with more scorers. The type of the resulting scorers object is dependent on the type of the input @@ -79,6 +79,8 @@ def _add_scorers(scorers, scorers_to_add): The scorer(s) to expand. scorers_to_add : dict[str, str] The scorers to be added. + estimator : estimator + A scikit-learn estimator. Returns ------- @@ -103,8 +105,11 @@ def _add_scorers(scorers, scorers_to_add): from sklearn.metrics._scorer import _MultimetricScorer internal_scorer = _MultimetricScorer( + # NOTE: we pass `estimator` to `check_scoring` for compatibility with + # scikit-learn 1.4. However, because `scoring` is never `None`, this + # estimator will not have any effect. scorers={ - name: check_scoring(estimator=None, scoring=scoring) + name: check_scoring(estimator=estimator, scoring=scoring) if isinstance(scoring, str) else scoring for name, scoring in scorers_to_add.items() diff --git a/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py b/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py index 47d9b7d092..60c8ec9e89 100644 --- a/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py +++ b/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py @@ -134,7 +134,9 @@ def __init__(self, *args, **kwargs): # Extend scorers with other relevant scorers scorers_to_add = _get_scorers_to_add(self.estimator, self.y) - self._scorers, added_scorers = _add_scorers(self.scorers, scorers_to_add) + self._scorers, added_scorers = _add_scorers( + self.scorers, scorers_to_add, self.estimator + ) self._cv_results = sklearn.model_selection.cross_validate( *args, From bf371d7e23d087b85d8abb8e3977c60f1cb02493 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 8 Jan 2025 23:03:48 +0100 Subject: [PATCH 2/2] uhm a recursion --- .../sklearn/cross_validation/cross_validation_helpers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py b/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py index 7bead8bb9e..ba3f397f8a 100644 --- a/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py +++ b/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py @@ -91,10 +91,12 @@ def _add_scorers(scorers, scorers_to_add, estimator): in `scorers`). """ if scorers is None or isinstance(scorers, str): - new_scorers, added_scorers = _add_scorers({"score": scorers}, scorers_to_add) + new_scorers, added_scorers = _add_scorers( + {"score": scorers}, scorers_to_add, estimator + ) elif isinstance(scorers, (list, tuple)): new_scorers, added_scorers = _add_scorers( - {s: s for s in scorers}, scorers_to_add + {s: s for s in scorers}, scorers_to_add, estimator ) elif isinstance(scorers, dict): # User-defined metrics have priority