2121from shapash .manipulation .filters import combine_masks
2222from shapash .manipulation .mask import init_mask
2323from shapash .manipulation .mask import compute_masked_contributions
24- from shapash .manipulation .summarize import summarize
24+ from shapash .manipulation .summarize import summarize , create_grouped_features_values , group_contributions
2525from shapash .decomposition .contributions import rank_contributions , assign_contributions
2626from shapash .utils .columntransformer_backend import columntransformer
2727import copy
@@ -100,6 +100,7 @@ def __init__(self, features_dict, model,
100100 columns_dict , explainer , features_types ,
101101 label_dict = None , preprocessing = None ,
102102 postprocessing = None ,
103+ features_groups = None ,
103104 mask_params = {"features_to_hide" : None ,
104105 "threshold" : None ,
105106 "positive" : None ,
@@ -130,10 +131,11 @@ def __init__(self, features_dict, model,
130131 self .mask_params = mask_params
131132 self .check_mask_params ()
132133 self .postprocessing = postprocessing
134+ self .features_groups = features_groups
133135 list_preprocessing = preprocessing_tolist (self .preprocessing )
134136 check_consistency_model_features (self .features_dict , self .model , self .columns_dict ,
135137 self .features_types , self .mask_params , self .preprocessing ,
136- self .postprocessing , list_preprocessing )
138+ self .postprocessing , list_preprocessing , self . features_groups )
137139 check_consistency_model_label (self .columns_dict , self .label_dict )
138140 self ._drop_option = check_preprocessing_options (columns_dict , features_dict , preprocessing , list_preprocessing )
139141
@@ -217,9 +219,34 @@ def add_input(self, x=None, ypred=None, contributions=None):
217219 self .data ["ypred_init" ] = self .check_ypred (ypred )
218220
219221 if contributions is not None :
220- self .data ["ypred" ], self .data ["contributions" ] = self .compute_contributions (contributions = contributions )
222+ self .data ["ypred" ], self .data ["contributions" ] = self .compute_contributions (
223+ contributions = contributions ,
224+ use_groups = False
225+ )
221226 else :
222- self .data ["ypred" ], self .data ["contributions" ] = self .compute_contributions ()
227+ self .data ["ypred" ], self .data ["contributions" ] = self .compute_contributions (use_groups = False )
228+
229+ if self .features_groups is not None :
230+ self ._add_groups_input ()
231+
def _add_groups_input(self):
    """
    Build the grouped counterpart of the data prepared by the add_input method
    and store it in the data_groups attribute.

    The data_groups dictionary is created empty, then filled with:

    - 'x_postprocessed' : output of create_grouped_features_values, called with the
      postprocessed and preprocessed datasets held in self.data, the preprocessing,
      the features groups, the features dict and how='dict_of_values'
    - 'ypred' : the very same object as self.data["ypred"]
    - 'contributions' : output of group_contributions applied to
      self.data['contributions'] with the features groups
    """
    # The container is attached to the instance before any helper is called,
    # so the attribute exists as soon as this method starts its work.
    grouped = self.data_groups = {}

    grouping_kwargs = {
        "x_pred": self.data["x_postprocessed"],
        "x_init": self.data["x_preprocessed"],
        "preprocessing": self.preprocessing,
        "features_groups": self.features_groups,
        "features_dict": self.features_dict,
        "how": 'dict_of_values',
    }
    grouped['x_postprocessed'] = create_grouped_features_values(**grouping_kwargs)

    # Predictions are not affected by grouping: reuse them as they are.
    grouped['ypred'] = self.data["ypred"]

    grouped['contributions'] = group_contributions(
        contributions=self.data['contributions'],
        features_groups=self.features_groups,
    )
249+
223250
224251 def check_dataset_type (self , x = None ):
225252 """
@@ -431,7 +458,7 @@ def predict_proba(self):
431458 """
432459 return predict_proba (self .model , self .data ["x_preprocessed" ], self ._classes )
433460
434- def compute_contributions (self , contributions = None ):
461+ def compute_contributions (self , contributions = None , use_groups = None ):
435462 """
436463 The compute_contributions method computes the contributions associated with the specified ypred data.
437464 Requires ypred data specified in an add_input call to display detail_contributions.
@@ -440,6 +467,8 @@ def compute_contributions(self, contributions=None):
440467 -------
441468 contributions : object (optional)
442469 Local contributions, or list of local contributions.
470+ use_groups : bool (optional)
471+ Whether or not to compute groups of features contributions.
443472
444473 Returns
445474 -------
@@ -449,6 +478,8 @@ def compute_contributions(self, contributions=None):
449478 ypred data with right probabilities associated.
450479
451480 """
481+ use_groups = True if (use_groups is not False and self .features_groups is not None ) else False
482+
452483 if not hasattr (self , "data" ):
453484 raise ValueError ("add_input method must be called at least once." )
454485 if self .data ["x" ] is None :
@@ -475,9 +506,12 @@ def compute_contributions(self, contributions=None):
475506 y_pred , match_contrib = keep_right_contributions (self .data ["ypred_init" ], contributions ,
476507 self ._case , self ._classes ,
477508 self .label_dict , proba_values )
509+ if use_groups :
510+ match_contrib = group_contributions (match_contrib , features_groups = self .features_groups )
511+
478512 return y_pred , match_contrib
479513
480- def detail_contributions (self , contributions = None ):
514+ def detail_contributions (self , contributions = None , use_groups = None ):
481515 """
482516 The detail_contributions method associates the right contributions with the right data predicted.
483517 (with ypred specified in add_input or computed automatically)
@@ -486,6 +520,8 @@ def detail_contributions(self, contributions=None):
486520 -------
487521 contributions : object (optional)
488522 Local contributions, or list of local contributions.
523+ use_groups : bool (optional)
524+ Whether or not to compute groups of features contributions.
489525
490526 Returns
491527 -------
@@ -499,7 +535,7 @@ def detail_contributions(self, contributions=None):
499535 >>> predictor.detail_contributions()
500536
501537 """
502- y_pred , detail_contrib = self .compute_contributions (contributions = contributions )
538+ y_pred , detail_contrib = self .compute_contributions (contributions = contributions , use_groups = use_groups )
503539 return pd .concat ([y_pred , detail_contrib ], axis = 1 )
504540
505541 def apply_preprocessing_for_contributions (self , contributions , preprocessing = None ):
@@ -593,7 +629,7 @@ def filter(self):
593629 self .mask
594630 )
595631
596- def summarize (self ):
632+ def summarize (self , use_groups = None ):
597633 """
598634 The summarize method displays the summary of local explainability.
599635 This method can be configured with the modify_mask method to tailor the summary of the explainability to specific needs.
@@ -606,6 +642,11 @@ def summarize(self):
606642 - the right probabilities from predict_proba associated to the right predicted values
607643 - the right contributions ranked and filtered as specified with modify_mask method
608644
645+ Parameters
646+ ----------
647+ use_groups : bool (optional)
648+ Whether or not to compute groups of features contributions.
649+
609650 Returns
610651 -------
611652 pandas.DataFrame
@@ -629,39 +670,47 @@ def summarize(self):
629670 2 0 0.543308 Sex 2.0 -0.486667
630671 """
631672 # data is needed : add_input() method must be called at least once
673+ use_groups = True if (use_groups is not False and self .features_groups is not None ) else False
632674
633675 if not hasattr (self , "data" ):
634676 raise ValueError ("You have to specify dataset x and y_pred arguments. Please use add_input() method." )
635677
678+ if use_groups is True :
679+ data = self .data_groups
680+ else :
681+ data = self .data
682+
636683 if self ._drop_option is not None :
637- x_preprocessed = self .data ["x_postprocessed" ][self ._drop_option ["columns_dict_op" ].values ()]
638- columns_dict = self ._drop_option ["columns_dict_op" ]
639- features_dict = self ._drop_option ["features_dict_op" ]
684+ columns_to_keep = [x for x in self ._drop_option ["columns_dict_op" ].values ()
685+ if x in data ["x_postprocessed" ].columns ]
686+ if use_groups :
687+ columns_to_keep += list (self .features_groups .keys ())
688+ x_preprocessed = data ["x_postprocessed" ][columns_to_keep ]
640689 else :
641- x_preprocessed = self .data ["x_postprocessed" ]
642- columns_dict = self .columns_dict
643- features_dict = self .features_dict
690+ x_preprocessed = data ["x_postprocessed" ]
644691
692+ columns_dict = {i : col for i , col in enumerate (x_preprocessed .columns )}
693+ features_dict = {k : v for k , v in self .features_dict .items () if k in x_preprocessed .columns }
645694
646695 self .summary = assign_contributions (
647696 rank_contributions (
648- self . data ["contributions" ],
697+ data ["contributions" ],
649698 x_preprocessed
650699 )
651700 )
652701 # Apply filter method with mask_params attributes parameters
653702 self .filter ()
654703
655704 # Summarize information
656- self . data ['summary' ] = summarize (self .summary ['contrib_sorted' ],
705+ data ['summary' ] = summarize (self .summary ['contrib_sorted' ],
657706 self .summary ['var_dict' ],
658707 self .summary ['x_sorted' ],
659708 self .mask ,
660709 columns_dict ,
661710 features_dict )
662711
663712 # Matching with y_pred
664- return pd .concat ([self . data ["ypred" ], self . data ['summary' ]], axis = 1 )
713+ return pd .concat ([data ["ypred" ], data ['summary' ]], axis = 1 )
665714
666715 def modify_mask (
667716 self ,
@@ -804,5 +853,6 @@ def to_smartexplainer(self):
804853 explainer = self .explainer ,
805854 y_pred = copy .deepcopy (self .data ["ypred_init" ]),
806855 preprocessing = self .preprocessing ,
807- postprocessing = self .postprocessing )
856+ postprocessing = self .postprocessing ,
857+ features_groups = self .features_groups )
808858 return xpl
0 commit comments