@@ -211,11 +211,11 @@ class is set to the one provided when creating the report. If `None`,
         metric_names = None
         if isinstance(metric, dict):
             metric_names = list(metric.keys())
-            metric_list = list(metric.values())
+            metrics = list(metric.values())
         elif metric is not None and not isinstance(metric, list):
-            metric_list = [metric]
+            metrics = [metric]
         elif isinstance(metric, list):
-            metric_list = metric
+            metrics = metric
 
         if data_source == "X_y":
             # optimization of the hash computation to avoid recomputing it
@@ -231,35 +231,32 @@ class is set to the one provided when creating the report. If `None`,
         if metric is None:
             # Equivalent to _get_scorers_to_add
             if self._parent._ml_task == "binary-classification":
-                metric_list = ["_precision", "_recall", "_roc_auc"]
+                metrics = ["_precision", "_recall", "_roc_auc"]
                 if hasattr(self._parent._estimator, "predict_proba"):
-                    metric_list.append("_brier_score")
+                    metrics.append("_brier_score")
             elif self._parent._ml_task == "multiclass-classification":
-                metric_list = ["_precision", "_recall"]
+                metrics = ["_precision", "_recall"]
                 if hasattr(self._parent._estimator, "predict_proba"):
-                    metric_list += ["_roc_auc", "_log_loss"]
+                    metrics += ["_roc_auc", "_log_loss"]
             else:
-                metric_list = ["_r2", "_rmse"]
-            metric_list += ["_fit_time", "_predict_time"]
+                metrics = ["_r2", "_rmse"]
+            metrics += ["_fit_time", "_predict_time"]
 
         if metric_names is None:
-            metric_names = [None] * len(metric_list)  # type: ignore
+            metric_names = [None] * len(metrics)  # type: ignore
 
         scores = []
         favorability_indicator = []
-        for metric_name, metric_item in zip(metric_names, metric_list, strict=False):
-            if isinstance(metric_item, str) and not (
-                (
-                    metric_item.startswith("_")
-                    and metric_item[1:] in self._score_or_loss_info
-                )
-                or metric_item in self._score_or_loss_info
+        for metric_name, metric_ in zip(metric_names, metrics, strict=False):
+            if isinstance(metric_, str) and not (
+                (metric_.startswith("_") and metric_[1:] in self._score_or_loss_info)
+                or metric_ in self._score_or_loss_info
             ):
                 try:
-                    metric_item = sklearn_metrics.get_scorer(metric_item)
+                    metric_ = sklearn_metrics.get_scorer(metric_)
                 except ValueError as err:
                     raise ValueError(
-                        f"Invalid metric: {metric_item!r}. "
+                        f"Invalid metric: {metric_!r}. "
                         f"Please use a valid metric from the "
                         f"list of supported metrics: "
                         f"{list(self._score_or_loss_info.keys())}"
@@ -275,19 +272,19 @@ class is set to the one provided when creating the report. If `None`,
 
             # NOTE: we have to check specifically for `_BaseScorer` first because this
             # is also a callable but it has a special private API that we can leverage
-            if isinstance(metric_item, _BaseScorer):
+            if isinstance(metric_, _BaseScorer):
                 # scorers have the advantage to have scoped defined kwargs
-                metric_function: Callable = metric_item._score_func
-                response_method: str | list[str] = metric_item._response_method
+                metric_function: Callable = metric_._score_func
+                response_method: str | list[str] = metric_._response_method
                 metric_fn = partial(
                     self._custom_metric,
                     metric_function=metric_function,
                     response_method=response_method,
                 )
                 # forward the additional parameters specific to the scorer
-                metrics_kwargs = {**metric_item._kwargs}
+                metrics_kwargs = {**metric_._kwargs}
                 metrics_kwargs["data_source_hash"] = data_source_hash
-                metrics_params = inspect.signature(metric_item._score_func).parameters
+                metrics_params = inspect.signature(metric_._score_func).parameters
                 if "pos_label" in metrics_params:
                     if pos_label is not None and "pos_label" in metrics_kwargs:
                         if pos_label != metrics_kwargs["pos_label"]:
@@ -300,60 +297,50 @@ class is set to the one provided when creating the report. If `None`,
                 elif pos_label is not None:
                     metrics_kwargs["pos_label"] = pos_label
                 if metric_name is None:
-                    metric_name = metric_item._score_func.__name__.replace(
-                        "_", " "
-                    ).title()
-                metric_favorability = "(↗︎)" if metric_item._sign == 1 else "(↘︎)"
+                    metric_name = metric_._score_func.__name__.replace("_", " ").title()
+                metric_favorability = "(↗︎)" if metric_._sign == 1 else "(↘︎)"
                 favorability_indicator.append(metric_favorability)
-            elif isinstance(metric_item, str) or callable(metric_item):
-                if isinstance(metric_item, str):
+            elif isinstance(metric_, str) or callable(metric_):
+                if isinstance(metric_, str):
                     # Handle built-in metrics (with underscore prefix)
                     if (
-                        metric_item.startswith("_")
-                        and metric_item[1:] in self._score_or_loss_info
+                        metric_.startswith("_")
+                        and metric_[1:] in self._score_or_loss_info
                     ):
-                        metric_fn = getattr(self, metric_item)
+                        metric_fn = getattr(self, metric_)
                         metrics_kwargs = {"data_source_hash": data_source_hash}
                         if metric_name is None:
                             metric_name = (
-                                f"{self._score_or_loss_info[metric_item[1:]]['name']}"
+                                f"{self._score_or_loss_info[metric_[1:]]['name']}"
                             )
-                        metric_favorability = self._score_or_loss_info[metric_item[1:]][
+                        metric_favorability = self._score_or_loss_info[metric_[1:]][
                             "icon"
                         ]
 
                     # Handle built-in metrics (without underscore prefix)
-                    elif metric_item in self._score_or_loss_info:
-                        metric_fn = getattr(self, f"_{metric_item}")
+                    elif metric_ in self._score_or_loss_info:
+                        metric_fn = getattr(self, f"_{metric_}")
                         metrics_kwargs = {"data_source_hash": data_source_hash}
                         if metric_name is None:
-                            metric_name = (
-                                f"{self._score_or_loss_info[metric_item]['name']}"
-                            )
-                        metric_favorability = self._score_or_loss_info[metric_item][
-                            "icon"
-                        ]
+                            metric_name = f"{self._score_or_loss_info[metric_]['name']}"
+                        metric_favorability = self._score_or_loss_info[metric_]["icon"]
                 else:
                     # Handle callable metrics
-                    metric_fn = partial(
-                        self._custom_metric, metric_function=metric_item
-                    )
+                    metric_fn = partial(self._custom_metric, metric_function=metric_)
                     if metric_kwargs is None:
                         metrics_kwargs = {}
                     else:
                         # check if we should pass any parameters specific to the metric
                         # callable
-                        metric_callable_params = inspect.signature(
-                            metric_item
-                        ).parameters
+                        metric_callable_params = inspect.signature(metric_).parameters
                         metrics_kwargs = {
                             param: metric_kwargs[param]
                             for param in metric_callable_params
                             if param in metric_kwargs
                         }
                     metrics_kwargs["data_source_hash"] = data_source_hash
                     if metric_name is None:
-                        metric_name = metric_item.__name__
+                        metric_name = metric_.__name__
                     metric_favorability = ""
                 favorability_indicator.append(metric_favorability)
 
@@ -366,7 +353,7 @@ class is set to the one provided when creating the report. If `None`,
                     metrics_kwargs["pos_label"] = pos_label
             else:
                 raise ValueError(
-                    f"Invalid type of metric: {type(metric_item)} for {metric_item!r}"
+                    f"Invalid type of metric: {type(metric_)} for {metric_!r}"
                 )
 
             score = metric_fn(data_source=data_source, X=X, y=y, **metrics_kwargs)
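
For context on the loop being renamed here: it dispatches on the metric's type. Built-in names (with or without the leading underscore) resolve to report methods via getattr, scorer objects are unwrapped through their private `_score_func`, `_sign`, and `_kwargs` attributes, and any other string falls back to `sklearn_metrics.get_scorer`. Below is a minimal sketch of that fallback, using only the public `get_scorer` entry point plus the same private scorer attributes the code reads; the attribute names are assumed stable in recent scikit-learn releases.

    from sklearn.metrics import get_scorer

    # A plain metric name resolves to a scorer object; an unknown name
    # raises ValueError, which the loop above re-raises together with the
    # list of supported metrics.
    scorer = get_scorer("accuracy")

    # The report derives the display name and the favorability indicator
    # from these private scorer attributes.
    print(scorer._score_func.__name__)  # "accuracy_score" -> "Accuracy Score"
    print(scorer._sign)                 # 1 -> higher is better, shown as (↗︎)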