Skip to content

Commit 21e160c

Browse files
authored
Merge pull request #212 from ThibaudReal/feature/group_features_predictor
Feature/group features smart predictor
2 parents 3c72908 + ad1f490 commit 21e160c

File tree

9 files changed

+1603
-62
lines changed

9 files changed

+1603
-62
lines changed

docs/tutorials/tuto-common01-groups_of_features.rst

Lines changed: 138 additions & 0 deletions
Large diffs are not rendered by default.

shapash/explainer/smart_explainer.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ def _compile_features_groups(self, features_groups):
259259
self.x_pred_groups = create_grouped_features_values(x_pred=self.x_pred, x_init=self.x_init,
260260
preprocessing=self.preprocessing,
261261
features_groups=self.features_groups,
262-
how='tsne')
262+
features_dict=self.features_dict,
263+
how='dict_of_values')
263264
# Compute data attribute for groups of features
264265
self.data_groups = self.state.assign_contributions(
265266
self.state.rank_contributions(
@@ -818,7 +819,8 @@ def to_pandas(
818819
threshold=None,
819820
positive=None,
820821
max_contrib=None,
821-
proba=False
822+
proba=False,
823+
use_groups=None
822824
):
823825
"""
824826
The to_pandas method allows to export the summary of local explainability.
@@ -848,6 +850,9 @@ def to_pandas(
848850
Number of contributions to show in the pandas df
849851
proba : bool, optional (default: False)
850852
adding proba in output df
853+
use_groups : bool (optional)
854+
Whether or not to use groups of features contributions (only available if features_groups
855+
parameter was not empty when calling compile method).
851856
852857
Returns
853858
-------
@@ -864,6 +869,11 @@ def to_pandas(
864869
1 3 0.628911 Sex 2.0 0.585475 Pclass 1.0 0.370504
865870
2 0 0.543308 Sex 2.0 -0.486667 Pclass 3.0 0.255072
866871
"""
872+
use_groups = True if (use_groups is not False and self.features_groups is not None) else False
873+
if use_groups:
874+
data = self.data_groups
875+
else:
876+
data = self.data
867877

868878
# Classification: y_pred is needed
869879
if self.y_pred is None:
@@ -873,21 +883,34 @@ def to_pandas(
873883

874884
# Apply filter method if necessary
875885
if all(var is None for var in [features_to_hide, threshold, positive, max_contrib]) \
876-
and hasattr(self, 'mask_params'):
886+
and hasattr(self, 'mask_params') \
887+
and (
888+
# if the already computed mask does not have the right shape (this can happen when
889+
# we use groups of features once and then use method without groups)
890+
(isinstance(data['contrib_sorted'], pd.DataFrame)
891+
and len(data["contrib_sorted"].columns) == len(self.mask.columns))
892+
or
893+
(isinstance(data['contrib_sorted'], list)
894+
and len(data["contrib_sorted"][0].columns) == len(self.mask[0].columns))
895+
):
877896
print('to_pandas params: ' + str(self.mask_params))
878897
else:
879898
self.filter(features_to_hide=features_to_hide,
880899
threshold=threshold,
881900
positive=positive,
882-
max_contrib=max_contrib)
883-
901+
max_contrib=max_contrib,
902+
display_groups=use_groups)
903+
if use_groups:
904+
columns_dict = {i: col for i, col in enumerate(self.x_pred_groups.columns)}
905+
else:
906+
columns_dict = self.columns_dict
884907
# Summarize information
885-
self.data['summary'] = self.state.summarize(
886-
self.data['contrib_sorted'],
887-
self.data['var_dict'],
888-
self.data['x_sorted'],
908+
data['summary'] = self.state.summarize(
909+
data['contrib_sorted'],
910+
data['var_dict'],
911+
data['x_sorted'],
889912
self.mask,
890-
self.columns_dict,
913+
columns_dict,
891914
self.features_dict
892915
)
893916
# Matching with y_pred
@@ -897,7 +920,7 @@ def to_pandas(
897920
else:
898921
proba_values = None
899922

900-
y_pred, summary = keep_right_contributions(self.y_pred, self.data['summary'],
923+
y_pred, summary = keep_right_contributions(self.y_pred, data['summary'],
901924
self._case, self._classes,
902925
self.label_dict, proba_values)
903926

@@ -1025,7 +1048,7 @@ def to_smartpredictor(self):
10251048
self.features_types = {features: str(self.x_pred[features].dtypes) for features in self.x_pred.columns}
10261049

10271050
listattributes = ["features_dict", "model", "columns_dict", "explainer", "features_types",
1028-
"label_dict", "preprocessing", "postprocessing"]
1051+
"label_dict", "preprocessing", "postprocessing", "features_groups"]
10291052

10301053
params_smartpredictor = [self.check_attributes(attribute) for attribute in listattributes]
10311054

shapash/explainer/smart_plotter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1240,7 +1240,8 @@ def contribution_plot(self,
12401240

12411241
if col_is_group:
12421242
feature_values = project_feature_values_1d(feature_values, col, self.explainer.x_pred,
1243-
self.explainer.x_init, self.explainer.preprocessing)
1243+
self.explainer.x_init, self.explainer.preprocessing,
1244+
features_dict=self.explainer.features_dict)
12441245
contrib = subcontrib.loc[list_ind, col].to_frame()
12451246
if self.explainer.features_imp is None:
12461247
self.explainer.compute_features_import()

shapash/explainer/smart_predictor.py

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from shapash.manipulation.filters import combine_masks
2222
from shapash.manipulation.mask import init_mask
2323
from shapash.manipulation.mask import compute_masked_contributions
24-
from shapash.manipulation.summarize import summarize
24+
from shapash.manipulation.summarize import summarize, create_grouped_features_values, group_contributions
2525
from shapash.decomposition.contributions import rank_contributions, assign_contributions
2626
from shapash.utils.columntransformer_backend import columntransformer
2727
import copy
@@ -100,6 +100,7 @@ def __init__(self, features_dict, model,
100100
columns_dict, explainer, features_types,
101101
label_dict=None, preprocessing=None,
102102
postprocessing=None,
103+
features_groups=None,
103104
mask_params = {"features_to_hide": None,
104105
"threshold": None,
105106
"positive": None,
@@ -130,10 +131,11 @@ def __init__(self, features_dict, model,
130131
self.mask_params = mask_params
131132
self.check_mask_params()
132133
self.postprocessing = postprocessing
134+
self.features_groups = features_groups
133135
list_preprocessing = preprocessing_tolist(self.preprocessing)
134136
check_consistency_model_features(self.features_dict, self.model, self.columns_dict,
135137
self.features_types, self.mask_params, self.preprocessing,
136-
self.postprocessing, list_preprocessing)
138+
self.postprocessing, list_preprocessing, self.features_groups)
137139
check_consistency_model_label(self.columns_dict, self.label_dict)
138140
self._drop_option = check_preprocessing_options(columns_dict, features_dict, preprocessing, list_preprocessing)
139141

@@ -217,9 +219,34 @@ def add_input(self, x=None, ypred=None, contributions=None):
217219
self.data["ypred_init"] = self.check_ypred(ypred)
218220

219221
if contributions is not None:
220-
self.data["ypred"], self.data["contributions"] = self.compute_contributions(contributions=contributions)
222+
self.data["ypred"], self.data["contributions"] = self.compute_contributions(
223+
contributions=contributions,
224+
use_groups=False
225+
)
221226
else:
222-
self.data["ypred"], self.data["contributions"] = self.compute_contributions()
227+
self.data["ypred"], self.data["contributions"] = self.compute_contributions(use_groups=False)
228+
229+
if self.features_groups is not None:
230+
self._add_groups_input()
231+
232+
def _add_groups_input(self):
233+
"""
234+
Compute groups of features values, contributions the same way as add_input method
235+
and stores it in data_groups attribute
236+
"""
237+
self.data_groups = dict()
238+
self.data_groups['x_postprocessed'] = create_grouped_features_values(x_pred=self.data["x_postprocessed"],
239+
x_init=self.data["x_preprocessed"],
240+
preprocessing=self.preprocessing,
241+
features_groups=self.features_groups,
242+
features_dict=self.features_dict,
243+
how='dict_of_values')
244+
self.data_groups['ypred'] = self.data["ypred"]
245+
self.data_groups['contributions'] = group_contributions(
246+
contributions=self.data['contributions'],
247+
features_groups=self.features_groups
248+
)
249+
223250

224251
def check_dataset_type(self, x=None):
225252
"""
@@ -431,7 +458,7 @@ def predict_proba(self):
431458
"""
432459
return predict_proba(self.model, self.data["x_preprocessed"], self._classes)
433460

434-
def compute_contributions(self, contributions=None):
461+
def compute_contributions(self, contributions=None, use_groups=None):
435462
"""
436463
The compute_contributions compute the contributions associated to data ypred specified.
437464
Need a data ypred specified in an add_input to display detail_contributions.
@@ -440,6 +467,8 @@ def compute_contributions(self, contributions=None):
440467
-------
441468
contributions : object (optional)
442469
Local contributions, or list of local contributions.
470+
use_groups : bool (optional)
471+
Whether or not to compute groups of features contributions.
443472
444473
Returns
445474
-------
@@ -449,6 +478,8 @@ def compute_contributions(self, contributions=None):
449478
ypred data with right probabilities associated.
450479
451480
"""
481+
use_groups = True if (use_groups is not False and self.features_groups is not None) else False
482+
452483
if not hasattr(self, "data"):
453484
raise ValueError("add_input method must be called at least once.")
454485
if self.data["x"] is None:
@@ -475,9 +506,12 @@ def compute_contributions(self, contributions=None):
475506
y_pred, match_contrib = keep_right_contributions(self.data["ypred_init"], contributions,
476507
self._case, self._classes,
477508
self.label_dict, proba_values)
509+
if use_groups:
510+
match_contrib = group_contributions(match_contrib, features_groups=self.features_groups)
511+
478512
return y_pred, match_contrib
479513

480-
def detail_contributions(self, contributions=None):
514+
def detail_contributions(self, contributions=None, use_groups=None):
481515
"""
482516
The detail_contributions method associates the right contributions with the right data predicted.
483517
(with ypred specified in add_input or computed automatically)
@@ -486,6 +520,8 @@ def detail_contributions(self, contributions=None):
486520
-------
487521
contributions : object (optional)
488522
Local contributions, or list of local contributions.
523+
use_groups : bool (optional)
524+
Whether or not to compute groups of features contributions.
489525
490526
Returns
491527
-------
@@ -499,7 +535,7 @@ def detail_contributions(self, contributions=None):
499535
>>> predictor.detail_contributions()
500536
501537
"""
502-
y_pred, detail_contrib = self.compute_contributions(contributions=contributions)
538+
y_pred, detail_contrib = self.compute_contributions(contributions=contributions, use_groups=use_groups)
503539
return pd.concat([y_pred, detail_contrib], axis=1)
504540

505541
def apply_preprocessing_for_contributions(self, contributions, preprocessing=None):
@@ -593,7 +629,7 @@ def filter(self):
593629
self.mask
594630
)
595631

596-
def summarize(self):
632+
def summarize(self, use_groups=None):
597633
"""
598634
The summarize method allows to display the summary of local explainability.
599635
This method can be configured with modify_mask method to summarize the explainability to suit needs.
@@ -606,6 +642,11 @@ def summarize(self):
606642
- the right probabilities from predict_proba associated to the right predicted values
607643
- the right contributions ranked and filtered as specify with modify_mask method
608644
645+
Parameters
646+
----------
647+
use_groups : bool (optional)
648+
Whether or not to compute groups of features contributions.
649+
609650
Returns
610651
-------
611652
pandas.DataFrame
@@ -629,39 +670,47 @@ def summarize(self):
629670
2 0 0.543308 Sex 2.0 -0.486667
630671
"""
631672
# data is needed : add_input() method must be called at least once
673+
use_groups = True if (use_groups is not False and self.features_groups is not None) else False
632674

633675
if not hasattr(self, "data"):
634676
raise ValueError("You have to specify dataset x and y_pred arguments. Please use add_input() method.")
635677

678+
if use_groups is True:
679+
data = self.data_groups
680+
else:
681+
data = self.data
682+
636683
if self._drop_option is not None:
637-
x_preprocessed = self.data["x_postprocessed"][self._drop_option["columns_dict_op"].values()]
638-
columns_dict =self._drop_option["columns_dict_op"]
639-
features_dict = self._drop_option["features_dict_op"]
684+
columns_to_keep = [x for x in self._drop_option["columns_dict_op"].values()
685+
if x in data["x_postprocessed"].columns]
686+
if use_groups:
687+
columns_to_keep += list(self.features_groups.keys())
688+
x_preprocessed = data["x_postprocessed"][columns_to_keep]
640689
else:
641-
x_preprocessed = self.data["x_postprocessed"]
642-
columns_dict = self.columns_dict
643-
features_dict = self.features_dict
690+
x_preprocessed = data["x_postprocessed"]
644691

692+
columns_dict = {i: col for i, col in enumerate(x_preprocessed.columns)}
693+
features_dict = {k: v for k, v in self.features_dict.items() if k in x_preprocessed.columns}
645694

646695
self.summary = assign_contributions(
647696
rank_contributions(
648-
self.data["contributions"],
697+
data["contributions"],
649698
x_preprocessed
650699
)
651700
)
652701
# Apply filter method with mask_params attributes parameters
653702
self.filter()
654703

655704
# Summarize information
656-
self.data['summary'] = summarize(self.summary['contrib_sorted'],
705+
data['summary'] = summarize(self.summary['contrib_sorted'],
657706
self.summary['var_dict'],
658707
self.summary['x_sorted'],
659708
self.mask,
660709
columns_dict,
661710
features_dict)
662711

663712
# Matching with y_pred
664-
return pd.concat([self.data["ypred"], self.data['summary']], axis=1)
713+
return pd.concat([data["ypred"], data['summary']], axis=1)
665714

666715
def modify_mask(
667716
self,
@@ -804,5 +853,6 @@ def to_smartexplainer(self):
804853
explainer=self.explainer,
805854
y_pred=copy.deepcopy(self.data["ypred_init"]),
806855
preprocessing=self.preprocessing,
807-
postprocessing=self.postprocessing)
856+
postprocessing=self.postprocessing,
857+
features_groups=self.features_groups)
808858
return xpl

0 commit comments

Comments
 (0)