Skip to content

Commit f95d6f0

Browse files
committed
continue harmonizing scoring --> metric name in feat imp accessor
1 parent e07750b commit f95d6f0

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

skore/src/skore/_sklearn/_estimator/feature_importance_accessor.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
DataSource = Literal["test", "train", "X_y"]
3232

3333

34-
Metric = Literal[
34+
MetricStr = Literal[
3535
"accuracy",
3636
"precision",
3737
"recall",
@@ -42,17 +42,17 @@
4242
"rmse",
4343
]
4444

45-
# If scoring represents a single score, one can use:
46-
# - a single string (see The scoring parameter: defining model evaluation rules);
45+
# If the metric parameter represents a single metric, one can use:
46+
# - a single string (see The metric parameter: defining model evaluation rules);
4747
# - a callable (see Callable scorers) that returns a single value.
48-
# If scoring represents multiple scores, one can use:
48+
# If the metric parameter represents multiple metrics, one can use:
4949
# - a list or tuple of unique strings;
5050
# - a callable returning a dictionary where the keys are the metric names
5151
# and the values are the metric scores;
5252
# - a dictionary with metric names as keys and callables a values.
53-
Scoring = Metric | Callable | Iterable[Metric] | dict[str, Callable]
53+
Metric = MetricStr | Callable | Iterable[MetricStr] | dict[str, Callable]
5454

55-
metric_to_scorer: dict[Metric, Callable] = {
55+
metric_to_scorer: dict[MetricStr, Callable] = {
5656
"accuracy": make_scorer(metrics.accuracy_score),
5757
"precision": make_scorer(metrics.precision_score),
5858
"recall": make_scorer(metrics.recall_score),
@@ -90,7 +90,7 @@ def _get_feature_names(estimator, X, transformer=None) -> list[str]:
9090
return [f"Feature #{i}" for i in range(X.shape[1])]
9191

9292

93-
def _check_metric(metric: Any) -> Scoring | None:
93+
def _check_metric(metric: Any) -> Metric | None:
9494
"""Check that `metric` is valid, and convert it to a suitable form as needed.
9595
9696
If `metric` is a list of strings, it is checked against our own metric names.
@@ -153,7 +153,7 @@ def _check_metric(metric: Any) -> Scoring | None:
153153
elif isinstance(metric, str):
154154
if metric in metric_to_scorer:
155155
# Convert to scorer
156-
return {metric: metric_to_scorer[cast(Metric, metric)]}
156+
return {metric: metric_to_scorer[cast(MetricStr, metric)]}
157157
raise TypeError(
158158
"If metric is a string, it must be one of "
159159
f"{list(metric_to_scorer.keys())}; got '{metric}'"
@@ -338,7 +338,7 @@ def permutation(
338338
X: ArrayLike | None = None,
339339
y: ArrayLike | None = None,
340340
aggregate: Aggregate | None = None,
341-
metric: Scoring | None = None,
341+
metric: Metric | None = None,
342342
n_repeats: int = 5,
343343
max_samples: float = 1.0,
344344
n_jobs: int | None = None,
@@ -554,7 +554,7 @@ def _feature_permutation(
554554
X: ArrayLike | None,
555555
y: ArrayLike | None,
556556
aggregate: Aggregate | None,
557-
metric: Scoring | None,
557+
metric: Metric | None,
558558
n_repeats: int,
559559
max_samples: float,
560560
n_jobs: int | None,

0 commit comments

Comments
 (0)