1+ from category_encoders import OrdinalEncoder
2+ import copy
13import itertools
24import matplotlib .pyplot as plt
35import numpy as np
46import pandas as pd
7+ from plotly import graph_objs as go
8+ from plotly .offline import plot
9+ from plotly .subplots import make_subplots
510from sklearn .manifold import MDS
11+
612from shapash import SmartExplainer
13+ from shapash .style .style_utils import colors_loading , select_palette , define_style
714
815
916class Consistency ():
1017
def __init__(self):
    """Initialise the plotting style from the default color palette.

    The default palette is the first entry of the mapping returned by
    ``colors_loading()``. The resulting style dictionary is stored in
    ``self._style_dict`` and read by the plotting methods of the class.
    """
    # Load the palettes once and reuse the mapping for both the name lookup
    # and the palette selection (previously loaded twice).
    palettes = colors_loading()
    self._palette_name = list(palettes.keys())[0]
    self._style_dict = define_style(select_palette(palettes, self._palette_name))
21+
def tuning_colorscale(self, values):
    """Adapts the color scale to the distribution of points

    The stops of the color scale are placed on the min, the deciles
    (10% ... 90%) and the max of ``values``, rescaled to [0, 1], and paired
    in order with the colors of ``self._style_dict["init_contrib_colorscale"]``.

    Parameters
    ----------
    values: 1 column pd.DataFrame or pd.Series
        values whose quantiles must be calculated

    Returns
    -------
    list
        List of ``[position, color]`` pairs
    """
    palette = self._style_dict["init_contrib_colorscale"]

    desc_df = values.describe(percentiles=np.arange(0.1, 1, 0.1).tolist())
    # Keep only the order statistics: min, 10% ... 90%, max
    quantiles = desc_df.loc[~desc_df.index.isin(['count', 'mean', 'std'])]
    quantiles = np.asarray(quantiles, dtype=float).flatten()

    lowest = float(np.asarray(desc_df.loc['min'], dtype=float).ravel()[0])
    highest = float(np.asarray(desc_df.loc['max'], dtype=float).ravel()[0])
    spread = highest - lowest

    if np.isfinite(spread) and spread > 0:
        positions = (quantiles - lowest) / spread
    else:
        # Degenerate distribution (all values equal, or no finite value):
        # dividing by a zero range would yield NaN stops, so fall back to
        # evenly spaced stops instead.
        positions = np.linspace(0, 1, len(quantiles))

    return [[position, color] for position, color in zip(positions, palette)]
35+
1136 def compile (self , x = None , model = None , preprocessing = None , contributions = None , methods = ["shap" , "acv" , "lime" ]):
1237 """If not provided, compute contributions according to provided methods (default are shap, acv, lime).
1338 If provided, check whether they respect the correct format:
@@ -36,10 +61,12 @@ def compile(self, x=None, model=None, preprocessing=None, contributions=None, me
3661 methods : list
3762 Methods used to compute contributions, by default ["shap", "acv", "lime"]
3863 """
64+ self .x = x
65+ self .preprocessing = preprocessing
3966 if contributions is None :
40- if (x is None ) or (model is None ):
67+ if (self . x is None ) or (model is None ):
4168 raise ValueError ('If no contributions are provided, parameters "x" and "model" must be defined' )
42- contributions = self .compute_contributions (x , model , methods , preprocessing )
69+ contributions = self .compute_contributions (self . x , model , methods , self . preprocessing )
4370 else :
4471 if not isinstance (contributions , dict ):
4572 raise ValueError ('Contributions must be a dictionary' )
@@ -49,8 +76,6 @@ def compile(self, x=None, model=None, preprocessing=None, contributions=None, me
4976 self .check_consistency_contributions (self .weights )
5077 self .index = self .weights [0 ].index
5178
52- self .weights = [weight .values for weight in self .weights ]
53-
5479 def compute_contributions (self , x , model , methods , preprocessing ):
5580 """
5681 Compute contributions based on specified methods
@@ -139,12 +164,12 @@ def consistency_plot(self, selection=None, max_features=20):
139164 """
140165 # Selection
141166 if selection is None :
142- weights = self .weights
167+ weights = [ weight . values for weight in self .weights ]
143168 elif isinstance (selection , list ):
144169 if len (selection ) == 1 :
145170 raise ValueError ('Selection must include multiple points' )
146171 else :
147- weights = [weight [selection ] for weight in self .weights ]
172+ weights = [weight . values [selection ] for weight in self .weights ]
148173 else :
149174 raise ValueError ('Parameter selection must be a list' )
150175
@@ -458,3 +483,238 @@ def plot_examples(self, method_1, method_2, l2, index, backend_name_1, backend_n
458483 axes [n ].set_yticks ([])
459484
460485 return fig
486+
def pairwise_consistency_plot(self, methods, selection=None,
                              max_features=10, max_points=100, file_name=None, auto_open=False):
    """The Pairwise_Consistency_plot compares the difference of 2 explainability methods
    across each feature and each data point, and plots the distribution of those differences.

    This plot goes one step deeper than the consistency_plot which compares methods on a global level
    by expressing differences in terms of mean across the entire dataset.
    Not only do we get an understanding of how differences are distributed across the dataset,
    but we can also identify whether there are patterns based on feature values,
    and understand when a method overestimates contributions compared to the other.

    Parameters
    ----------
    methods : list
        List of the 2 explainability methods to compare
    selection: list
        Row positions (as used by ``DataFrame.iloc``) of the subset of the input DataFrame
        used to compute the consistency statistics, by default None.
        When None, at most ``max_points`` rows are drawn at random.
    max_features: int, optional
        Maximum number of displayed features, by default 10
    max_points : int, optional
        Maximum number of displayed datapoints per feature, by default 100
    file_name: string, optional
        Specify the save path of html files. If it is not provided, no file will be saved.
    auto_open: bool
        open automatically the plot, by default False

    Returns
    -------
    figure

    Raises
    ------
    ValueError
        If x was not given to compile or is not a DataFrame, if ``methods`` does not
        contain exactly 2 known methods, or if ``selection`` is not a list of several points.
    """
    if self.x is None:
        raise ValueError('x must be defined in the compile to display the plot')
    if not isinstance(self.x, pd.DataFrame):
        raise ValueError('x must be a pandas DataFrame')
    if len(methods) != 2:
        raise ValueError('Choose 2 methods among methods of the contributions')

    # Select contributions of input methods
    pair_weights = [self.weights[self.methods.index(method)] for method in methods]

    # Selection: always work with row positions. Index labels must not be fed to
    # .iloc, otherwise a non-default index raises or silently selects wrong rows.
    if selection is None:
        n_points = min(max_points, len(self.x))
        positions = np.random.choice(len(self.x), size=n_points, replace=False)
    elif isinstance(selection, list):
        if len(selection) == 1:
            raise ValueError('Selection must include multiple points')
        positions = selection
    else:
        raise ValueError('Parameter selection must be a list')

    weights = [weight.iloc[positions] for weight in pair_weights]
    x = self.x.iloc[positions]

    # Remove constant columns
    const_cols = x.columns[(x.nunique() == 1).to_numpy()]
    x = x.drop(columns=const_cols)
    weights = [weight.drop(columns=const_cols) for weight in weights]

    # Only keep features based on largest mean of absolute values.
    # DataFrame.mean() is column-wise on every pandas version, whereas
    # np.mean(DataFrame) collapses to a scalar with pandas >= 2.0.
    mean_contributions = pd.concat(weights).abs().mean()
    ranked_features = mean_contributions.sort_values(ascending=False).index[:max_features]
    # Reversed so that the most important feature gets the highest position on the y-axis
    top_features = ranked_features[::-1]

    fig = self.plot_pairwise_consistency(weights, x, top_features, methods, file_name, auto_open)

    return fig
556+
def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name, auto_open):
    """Build the figure showing, for each top feature, how the contributions of 2 methods differ

    For every feature, a violin (primary y-axis) shows the distribution of the
    differences and a jittered scatter (secondary y-axis) shows each data point,
    colored by its feature value.

    Parameters
    ----------
    weights : list
        List of 2 dataframes containing contributions for the selected points
    x : DataFrame
        Original input data filtered on selected points
    top_features : array
        Top features to display ordered by mean of absolute contributions across all the selected points
    methods : list
        List of explainability methods to compare
    file_name: string
        Specify the save path of html files. If it is not provided, no file will be saved.
    auto_open: bool
        open automatically the plot

    Returns
    -------
    figure
    """
    # Reuse the OrdinalEncoder given at compile time when there is one,
    # otherwise fit a new one on the string (object dtype) columns.
    # NOTE(review): the source view lost its indentation; the transform is
    # assumed to belong to the `else` branch only - confirm against the file.
    if isinstance(self.preprocessing, OrdinalEncoder):
        encoder = self.preprocessing
    else:
        object_columns = [name for name in x.columns if x[name].dtype == 'object']
        encoder = OrdinalEncoder(cols=object_columns,
                                 handle_unknown='ignore',
                                 return_df=True).fit(x)
        x = encoder.transform(x)

    xaxis_title = "Difference of contributions between the 2 methods" \
        + f"<span style='font-size: 12px;'><br />{methods[0]} - {methods[1]}</span>"
    yaxis_title = "Top features<span style='font-size: 12px;'><br />(Ordered by mean of absolute contributions)</span>"

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    n_points = len(x)

    # One violin + one scatter per feature
    for position, feature in enumerate(top_features):
        first_contrib = weights[0][feature]
        second_contrib = weights[1][feature]
        gap = first_contrib - second_contrib

        if feature in encoder.cols:
            # Encoded feature: show the original category in the hover text
            mapping = encoder.mapping[encoder.cols.index(feature)]["mapping"]
            decode = {code: label for label, code in mapping.to_dict().items()}
            shown_values = x[feature].map(decode)
        else:
            shown_values = x[feature].round(3)

        hv_text = [
            f"<b>Feature value</b>: {value}<br><b>{methods[0]}</b>: {first}<br><b>{methods[1]}</b>: {second}<br><b>Diff</b>: {delta}"
            for value, first, second, delta in zip(shown_values,
                                                   first_contrib.round(2),
                                                   second_contrib.round(2),
                                                   gap.round(2))
        ]

        violin = go.Violin(
            x=gap.values,
            name=feature,
            points=False,
            fillcolor="rgba(255, 0, 0, 0.1)",
            line={"color": "black", "width": 0.5},
            showlegend=False,
        )
        fig.add_trace(violin, secondary_y=False)

        # Vertical jitter around the feature position so that points do not overlap
        jittered_y = n_points * [position] + np.random.normal(0, 0.1, n_points)
        points = go.Scatter(
            x=gap.values,
            y=jittered_y,
            mode='markers',
            marker={"color": x[feature].values,
                    "colorscale": self.tuning_colorscale(x[feature]),
                    "opacity": 0.7},
            name=feature,
            text=n_points * [feature],
            hovertext=hv_text,
            hovertemplate="<b>%{text}</b><br><br>" +
                          "%{hovertext}<br>" +
                          "<extra></extra>",
            showlegend=False,
        )
        fig.add_trace(points, secondary_y=True)

    # Dummy invisible plot to add the color scale
    overall_min = x.min().min()
    overall_max = x.max().max()
    colorbar_settings = dict(thickness=20,
                             lenmode="pixels",
                             len=400,
                             yanchor="top",
                             y=1.1,
                             ypad=20,
                             title="Feature values",
                             tickvals=[overall_min, overall_max],
                             ticktext=["Low", "High"])
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode="markers",
        marker=dict(
            size=1,
            color=[x.min(), x.max()],
            colorscale=self.tuning_colorscale(pd.Series(np.linspace(overall_min, overall_max, 10))),
            colorbar=colorbar_settings,
            showscale=True,
        ),
        hoverinfo="none",
        showlegend=False,
    )
    fig.add_trace(colorbar_trace)

    self._update_pairwise_consistency_fig(fig=fig,
                                          top_features=top_features,
                                          xaxis_title=xaxis_title,
                                          yaxis_title=yaxis_title,
                                          file_name=file_name,
                                          auto_open=auto_open)

    return fig
677+
def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis_title, file_name, auto_open):
    """Function used for the pairwise_consistency_plot to update the layout of the plotly figure.

    The figure is modified in place. When ``file_name`` is given, the figure
    is also written to that path with ``plotly.offline.plot``.

    Parameters
    ----------
    fig : figure
        Plotly figure
    top_features : array
        Top features to display ordered by mean of absolute contributions across all the selected points
    xaxis_title : str
        Title for the x-axis
    yaxis_title : str
        Title for the y-axis
    file_name: string
        Specify the save path of html files. If it is not provided, no file will be saved.
    auto_open: bool
        open automatically the plot
    """
    # Implicit string concatenation: a backslash continuation inside the literal
    # would embed the indentation of the following source line in the title.
    title = (
        "Pairwise comparison of Consistency:"
        "<span style='font-size: 16px;'>"
        "<br />How are differences in contributions distributed across features?</span>"
    )

    # Deep copies so that the shared style dictionary is never mutated
    dict_t = copy.deepcopy(self._style_dict["dict_title_stability"])
    dict_xaxis = copy.deepcopy(self._style_dict["dict_xaxis"])
    dict_yaxis = copy.deepcopy(self._style_dict["dict_yaxis"])
    dict_xaxis['text'] = xaxis_title
    dict_yaxis['text'] = yaxis_title
    dict_t['text'] = title

    n_features = len(top_features)

    # Feature names are shown on the primary axis only
    fig.layout.yaxis.update(showticklabels=True)
    fig.layout.yaxis2.update(showticklabels=False)
    # Both y-axes share the same range so that violins and points stay aligned
    fig.update_layout(template="none",
                      title=dict_t,
                      xaxis_title=dict_xaxis,
                      yaxis_title=dict_yaxis,
                      yaxis=dict(range=[-0.7, n_features - 0.3]),
                      yaxis2=dict(range=[-0.7, n_features - 0.3]),
                      height=max(500, 40 * n_features))

    fig.update_yaxes(automargin=True, zeroline=False)
    fig.update_xaxes(automargin=True)

    if file_name is not None:
        plot(fig, filename=file_name, auto_open=auto_open)
0 commit comments