|
31 | 31 | DataSource = Literal["test", "train", "X_y"] |
32 | 32 |
|
33 | 33 |
|
34 | | -Metric = Literal[ |
| 34 | +MetricStr = Literal[ |
35 | 35 | "accuracy", |
36 | 36 | "precision", |
37 | 37 | "recall", |
|
42 | 42 | "rmse", |
43 | 43 | ] |
44 | 44 |
|
45 | | -# If scoring represents a single score, one can use: |
46 | | -# - a single string (see The scoring parameter: defining model evaluation rules); |
| 45 | +# If the metric parameter represents a single metric, one can use: |
| 46 | +# - a single string (see The metric parameter: defining model evaluation rules); |
47 | 47 | # - a callable (see Callable scorers) that returns a single value. |
48 | | -# If scoring represents multiple scores, one can use: |
| 48 | +# If the metric parameter represents multiple metrics, one can use: |
49 | 49 | # - a list or tuple of unique strings; |
50 | 50 | # - a callable returning a dictionary where the keys are the metric names |
51 | 51 | # and the values are the metric scores; |
52 | 52 | # - a dictionary with metric names as keys and callables as values. |
53 | | -Scoring = Metric | Callable | Iterable[Metric] | dict[str, Callable] |
| 53 | +Metric = MetricStr | Callable | Iterable[MetricStr] | dict[str, Callable] |
54 | 54 |
|
55 | | -metric_to_scorer: dict[Metric, Callable] = { |
| 55 | +metric_to_scorer: dict[MetricStr, Callable] = { |
56 | 56 | "accuracy": make_scorer(metrics.accuracy_score), |
57 | 57 | "precision": make_scorer(metrics.precision_score), |
58 | 58 | "recall": make_scorer(metrics.recall_score), |
@@ -90,7 +90,7 @@ def _get_feature_names(estimator, X, transformer=None) -> list[str]: |
90 | 90 | return [f"Feature #{i}" for i in range(X.shape[1])] |
91 | 91 |
|
92 | 92 |
|
93 | | -def _check_metric(metric: Any) -> Scoring | None: |
| 93 | +def _check_metric(metric: Any) -> Metric | None: |
94 | 94 | """Check that `metric` is valid, and convert it to a suitable form as needed. |
95 | 95 |
|
96 | 96 | If `metric` is a list of strings, it is checked against our own metric names. |
@@ -153,7 +153,7 @@ def _check_metric(metric: Any) -> Scoring | None: |
153 | 153 | elif isinstance(metric, str): |
154 | 154 | if metric in metric_to_scorer: |
155 | 155 | # Convert to scorer |
156 | | - return {metric: metric_to_scorer[cast(Metric, metric)]} |
| 156 | + return {metric: metric_to_scorer[cast(MetricStr, metric)]} |
157 | 157 | raise TypeError( |
158 | 158 | "If metric is a string, it must be one of " |
159 | 159 | f"{list(metric_to_scorer.keys())}; got '{metric}'" |
@@ -338,7 +338,7 @@ def permutation( |
338 | 338 | X: ArrayLike | None = None, |
339 | 339 | y: ArrayLike | None = None, |
340 | 340 | aggregate: Aggregate | None = None, |
341 | | - metric: Scoring | None = None, |
| 341 | + metric: Metric | None = None, |
342 | 342 | n_repeats: int = 5, |
343 | 343 | max_samples: float = 1.0, |
344 | 344 | n_jobs: int | None = None, |
@@ -554,7 +554,7 @@ def _feature_permutation( |
554 | 554 | X: ArrayLike | None, |
555 | 555 | y: ArrayLike | None, |
556 | 556 | aggregate: Aggregate | None, |
557 | | - metric: Scoring | None, |
| 557 | + metric: Metric | None, |
558 | 558 | n_repeats: int, |
559 | 559 | max_samples: float, |
560 | 560 | n_jobs: int | None, |
|
0 commit comments