diff --git a/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py b/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py index 97a46d61ea..ba3f397f8a 100644 --- a/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py +++ b/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py @@ -58,7 +58,7 @@ def _get_scorers_to_add(estimator, y) -> dict[str, Any]: return {} -def _add_scorers(scorers, scorers_to_add): +def _add_scorers(scorers, scorers_to_add, estimator): """Expand `scorers` with more scorers. The type of the resulting scorers object is dependent on the type of the input @@ -79,6 +79,8 @@ def _add_scorers(scorers, scorers_to_add): The scorer(s) to expand. scorers_to_add : dict[str, str] The scorers to be added. + estimator : estimator + A scikit-learn estimator. Returns ------- @@ -89,10 +91,12 @@ def _add_scorers(scorers, scorers_to_add): in `scorers`). """ if scorers is None or isinstance(scorers, str): - new_scorers, added_scorers = _add_scorers({"score": scorers}, scorers_to_add) + new_scorers, added_scorers = _add_scorers( + {"score": scorers}, scorers_to_add, estimator + ) elif isinstance(scorers, (list, tuple)): new_scorers, added_scorers = _add_scorers( - {s: s for s in scorers}, scorers_to_add + {s: s for s in scorers}, scorers_to_add, estimator ) elif isinstance(scorers, dict): # User-defined metrics have priority @@ -103,8 +107,11 @@ def _add_scorers(scorers, scorers_to_add): from sklearn.metrics._scorer import _MultimetricScorer internal_scorer = _MultimetricScorer( + # NOTE: we pass `estimator` to `check_scoring` for compatibility with + # scikit-learn 1.4. However, because `scoring` is never `None`, this + # estimator will not have any effect. scorers={ - name: check_scoring(estimator=None, scoring=scoring) + name: check_scoring(estimator=estimator, scoring=scoring) if isinstance(scoring, str) else scoring for name, scoring in scorers_to_add.items() diff --git a/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py b/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py index 47d9b7d092..60c8ec9e89 100644 --- a/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py +++ b/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py @@ -134,7 +134,9 @@ def __init__(self, *args, **kwargs): # Extend scorers with other relevant scorers scorers_to_add = _get_scorers_to_add(self.estimator, self.y) - self._scorers, added_scorers = _add_scorers(self.scorers, scorers_to_add) + self._scorers, added_scorers = _add_scorers( + self.scorers, scorers_to_add, self.estimator + ) self._cv_results = sklearn.model_selection.cross_validate( *args,