Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _get_scorers_to_add(estimator, y) -> dict[str, Any]:
return {}


def _add_scorers(scorers, scorers_to_add):
def _add_scorers(scorers, scorers_to_add, estimator):
"""Expand `scorers` with more scorers.

The type of the resulting scorers object is dependent on the type of the input
Expand All @@ -79,6 +79,8 @@ def _add_scorers(scorers, scorers_to_add):
The scorer(s) to expand.
scorers_to_add : dict[str, str]
The scorers to be added.
estimator : estimator
A scikit-learn estimator.

Returns
-------
Expand All @@ -89,10 +91,12 @@ def _add_scorers(scorers, scorers_to_add):
in `scorers`).
"""
if scorers is None or isinstance(scorers, str):
new_scorers, added_scorers = _add_scorers({"score": scorers}, scorers_to_add)
new_scorers, added_scorers = _add_scorers(
{"score": scorers}, scorers_to_add, estimator
)
elif isinstance(scorers, (list, tuple)):
new_scorers, added_scorers = _add_scorers(
{s: s for s in scorers}, scorers_to_add
{s: s for s in scorers}, scorers_to_add, estimator
)
elif isinstance(scorers, dict):
# User-defined metrics have priority
Expand All @@ -103,8 +107,11 @@ def _add_scorers(scorers, scorers_to_add):
from sklearn.metrics._scorer import _MultimetricScorer

internal_scorer = _MultimetricScorer(
# NOTE: we pass `estimator` to `check_scoring` for compatibility with
# scikit-learn 1.4. However, because `scoring` is never `None`, this
# estimator will not have any effect.
scorers={
name: check_scoring(estimator=None, scoring=scoring)
name: check_scoring(estimator=estimator, scoring=scoring)
if isinstance(scoring, str)
else scoring
for name, scoring in scorers_to_add.items()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ def __init__(self, *args, **kwargs):

# Extend scorers with other relevant scorers
scorers_to_add = _get_scorers_to_add(self.estimator, self.y)
self._scorers, added_scorers = _add_scorers(self.scorers, scorers_to_add)
self._scorers, added_scorers = _add_scorers(
self.scorers, scorers_to_add, self.estimator
)

self._cv_results = sklearn.model_selection.cross_validate(
*args,
Expand Down
Loading