Skip to content

Commit 8a50775

Browse files
Merge pull request #582 from guillaume-vignal/feature/smartplotter_simplification
SmartPlotter simplification by delegating each plot type to a separate function file
2 parents d03dca1 + e972df3 commit 8a50775

17 files changed

+3661
-2981
lines changed

shapash/explainer/smart_explainer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def compile(
331331
else self._compile_additional_features_dict(additional_features_dict)
332332
)
333333
self.additional_data = self._compile_additional_data(additional_data)
334+
self.plot._tuning_round_digit()
334335

335336
def _get_contributions_from_backend_or_user(self, x, contributions):
336337
# Computing contributions using backend
@@ -1320,3 +1321,31 @@ def generate_report(
13201321
if rm_working_dir:
13211322
shutil.rmtree(working_dir)
13221323
raise e
1324+
1325+
def _local_pred(self, index, label=None):
1326+
"""
1327+
compute a local pred to display in local_plot
1328+
Parameters
1329+
----------
1330+
index: string, int, float, ...
1331+
specify the row we want to pred
1332+
label: int (default: None)
1333+
Returns
1334+
-------
1335+
float: Predict or predict_proba value
1336+
"""
1337+
if self._case == "classification":
1338+
if self.proba_values is not None:
1339+
value = self.proba_values.iloc[:, [label]].loc[index].values[0]
1340+
else:
1341+
value = None
1342+
elif self._case == "regression":
1343+
if self.y_pred is not None:
1344+
value = self.y_pred.loc[index]
1345+
else:
1346+
value = self.model.predict(self.x_encoded.loc[[index]])[0]
1347+
1348+
if isinstance(value, pd.Series):
1349+
value = value.values[0]
1350+
1351+
return value

0 commit comments

Comments
 (0)