Skip to content

Commit f0a0b07

Browse files
committed
adapt var names following review
1 parent 1f29b65 commit f0a0b07

File tree

1 file changed

+38
-51
lines changed

1 file changed

+38
-51
lines changed

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 38 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,11 @@ class is set to the one provided when creating the report. If `None`,
211211
metric_names = None
212212
if isinstance(metric, dict):
213213
metric_names = list(metric.keys())
214-
metric_list = list(metric.values())
214+
metrics = list(metric.values())
215215
elif metric is not None and not isinstance(metric, list):
216-
metric_list = [metric]
216+
metrics = [metric]
217217
elif isinstance(metric, list):
218-
metric_list = metric
218+
metrics = metric
219219

220220
if data_source == "X_y":
221221
# optimization of the hash computation to avoid recomputing it
@@ -231,35 +231,32 @@ class is set to the one provided when creating the report. If `None`,
231231
if metric is None:
232232
# Equivalent to _get_scorers_to_add
233233
if self._parent._ml_task == "binary-classification":
234-
metric_list = ["_precision", "_recall", "_roc_auc"]
234+
metrics = ["_precision", "_recall", "_roc_auc"]
235235
if hasattr(self._parent._estimator, "predict_proba"):
236-
metric_list.append("_brier_score")
236+
metrics.append("_brier_score")
237237
elif self._parent._ml_task == "multiclass-classification":
238-
metric_list = ["_precision", "_recall"]
238+
metrics = ["_precision", "_recall"]
239239
if hasattr(self._parent._estimator, "predict_proba"):
240-
metric_list += ["_roc_auc", "_log_loss"]
240+
metrics += ["_roc_auc", "_log_loss"]
241241
else:
242-
metric_list = ["_r2", "_rmse"]
243-
metric_list += ["_fit_time", "_predict_time"]
242+
metrics = ["_r2", "_rmse"]
243+
metrics += ["_fit_time", "_predict_time"]
244244

245245
if metric_names is None:
246-
metric_names = [None] * len(metric_list) # type: ignore
246+
metric_names = [None] * len(metrics) # type: ignore
247247

248248
scores = []
249249
favorability_indicator = []
250-
for metric_name, metric_item in zip(metric_names, metric_list, strict=False):
251-
if isinstance(metric_item, str) and not (
252-
(
253-
metric_item.startswith("_")
254-
and metric_item[1:] in self._score_or_loss_info
255-
)
256-
or metric_item in self._score_or_loss_info
250+
for metric_name, metric_ in zip(metric_names, metrics, strict=False):
251+
if isinstance(metric_, str) and not (
252+
(metric_.startswith("_") and metric_[1:] in self._score_or_loss_info)
253+
or metric_ in self._score_or_loss_info
257254
):
258255
try:
259-
metric_item = sklearn_metrics.get_scorer(metric_item)
256+
metric_ = sklearn_metrics.get_scorer(metric_)
260257
except ValueError as err:
261258
raise ValueError(
262-
f"Invalid metric: {metric_item!r}. "
259+
f"Invalid metric: {metric_!r}. "
263260
f"Please use a valid metric from the "
264261
f"list of supported metrics: "
265262
f"{list(self._score_or_loss_info.keys())} "
@@ -275,19 +272,19 @@ class is set to the one provided when creating the report. If `None`,
275272

276273
# NOTE: we have to check specifically for `_BaseScorer` first because this
277274
# is also a callable but it has a special private API that we can leverage
278-
if isinstance(metric_item, _BaseScorer):
275+
if isinstance(metric_, _BaseScorer):
279276
# scorers have the advantage to have scoped defined kwargs
280-
metric_function: Callable = metric_item._score_func
281-
response_method: str | list[str] = metric_item._response_method
277+
metric_function: Callable = metric_._score_func
278+
response_method: str | list[str] = metric_._response_method
282279
metric_fn = partial(
283280
self._custom_metric,
284281
metric_function=metric_function,
285282
response_method=response_method,
286283
)
287284
# forward the additional parameters specific to the scorer
288-
metrics_kwargs = {**metric_item._kwargs}
285+
metrics_kwargs = {**metric_._kwargs}
289286
metrics_kwargs["data_source_hash"] = data_source_hash
290-
metrics_params = inspect.signature(metric_item._score_func).parameters
287+
metrics_params = inspect.signature(metric_._score_func).parameters
291288
if "pos_label" in metrics_params:
292289
if pos_label is not None and "pos_label" in metrics_kwargs:
293290
if pos_label != metrics_kwargs["pos_label"]:
@@ -300,60 +297,50 @@ class is set to the one provided when creating the report. If `None`,
300297
elif pos_label is not None:
301298
metrics_kwargs["pos_label"] = pos_label
302299
if metric_name is None:
303-
metric_name = metric_item._score_func.__name__.replace(
304-
"_", " "
305-
).title()
306-
metric_favorability = "(↗︎)" if metric_item._sign == 1 else "(↘︎)"
300+
metric_name = metric_._score_func.__name__.replace("_", " ").title()
301+
metric_favorability = "(↗︎)" if metric_._sign == 1 else "(↘︎)"
307302
favorability_indicator.append(metric_favorability)
308-
elif isinstance(metric_item, str) or callable(metric_item):
309-
if isinstance(metric_item, str):
303+
elif isinstance(metric_, str) or callable(metric_):
304+
if isinstance(metric_, str):
310305
# Handle built-in metrics (with underscore prefix)
311306
if (
312-
metric_item.startswith("_")
313-
and metric_item[1:] in self._score_or_loss_info
307+
metric_.startswith("_")
308+
and metric_[1:] in self._score_or_loss_info
314309
):
315-
metric_fn = getattr(self, metric_item)
310+
metric_fn = getattr(self, metric_)
316311
metrics_kwargs = {"data_source_hash": data_source_hash}
317312
if metric_name is None:
318313
metric_name = (
319-
f"{self._score_or_loss_info[metric_item[1:]]['name']}"
314+
f"{self._score_or_loss_info[metric_[1:]]['name']}"
320315
)
321-
metric_favorability = self._score_or_loss_info[metric_item[1:]][
316+
metric_favorability = self._score_or_loss_info[metric_[1:]][
322317
"icon"
323318
]
324319

325320
# Handle built-in metrics (without underscore prefix)
326-
elif metric_item in self._score_or_loss_info:
327-
metric_fn = getattr(self, f"_{metric_item}")
321+
elif metric_ in self._score_or_loss_info:
322+
metric_fn = getattr(self, f"_{metric_}")
328323
metrics_kwargs = {"data_source_hash": data_source_hash}
329324
if metric_name is None:
330-
metric_name = (
331-
f"{self._score_or_loss_info[metric_item]['name']}"
332-
)
333-
metric_favorability = self._score_or_loss_info[metric_item][
334-
"icon"
335-
]
325+
metric_name = f"{self._score_or_loss_info[metric_]['name']}"
326+
metric_favorability = self._score_or_loss_info[metric_]["icon"]
336327
else:
337328
# Handle callable metrics
338-
metric_fn = partial(
339-
self._custom_metric, metric_function=metric_item
340-
)
329+
metric_fn = partial(self._custom_metric, metric_function=metric_)
341330
if metric_kwargs is None:
342331
metrics_kwargs = {}
343332
else:
344333
# check if we should pass any parameters specific to the metric
345334
# callable
346-
metric_callable_params = inspect.signature(
347-
metric_item
348-
).parameters
335+
metric_callable_params = inspect.signature(metric_).parameters
349336
metrics_kwargs = {
350337
param: metric_kwargs[param]
351338
for param in metric_callable_params
352339
if param in metric_kwargs
353340
}
354341
metrics_kwargs["data_source_hash"] = data_source_hash
355342
if metric_name is None:
356-
metric_name = metric_item.__name__
343+
metric_name = metric_.__name__
357344
metric_favorability = ""
358345
favorability_indicator.append(metric_favorability)
359346

@@ -366,7 +353,7 @@ class is set to the one provided when creating the report. If `None`,
366353
metrics_kwargs["pos_label"] = pos_label
367354
else:
368355
raise ValueError(
369-
f"Invalid type of metric: {type(metric_item)} for {metric_item!r}"
356+
f"Invalid type of metric: {type(metric_)} for {metric_!r}"
370357
)
371358

372359
score = metric_fn(data_source=data_source, X=X, y=y, **metrics_kwargs)

0 commit comments

Comments
 (0)