@@ -331,6 +331,7 @@ def compile(
331331 else self ._compile_additional_features_dict (additional_features_dict )
332332 )
333333 self .additional_data = self ._compile_additional_data (additional_data )
334+ self .plot ._tuning_round_digit ()
334335
335336 def _get_contributions_from_backend_or_user (self , x , contributions ):
336337 # Computing contributions using backend
@@ -1320,3 +1321,31 @@ def generate_report(
13201321 if rm_working_dir :
13211322 shutil .rmtree (working_dir )
13221323 raise e
1324+
1325+ def _local_pred (self , index , label = None ):
1326+ """
1327+ compute a local pred to display in local_plot
1328+ Parameters
1329+ ----------
1330+ index: string, int, float, ...
1331+ specify the row we want to pred
1332+ label: int (default: None)
1333+ Returns
1334+ -------
1335+ float: Predict or predict_proba value
1336+ """
1337+ if self ._case == "classification" :
1338+ if self .proba_values is not None :
1339+ value = self .proba_values .iloc [:, [label ]].loc [index ].values [0 ]
1340+ else :
1341+ value = None
1342+ elif self ._case == "regression" :
1343+ if self .y_pred is not None :
1344+ value = self .y_pred .loc [index ]
1345+ else :
1346+ value = self .model .predict (self .x_encoded .loc [[index ]])[0 ]
1347+
1348+ if isinstance (value , pd .Series ):
1349+ value = value .values [0 ]
1350+
1351+ return value
0 commit comments