Skip to content

Commit 71888dc

Browse files
authored
Merge branch 'master' into feature/replace_3_6_by_3_10
2 parents 4c296da + b6abb31 commit 71888dc

File tree

10 files changed

+426
-79
lines changed

10 files changed

+426
-79
lines changed

requirements.dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ pip==21.3.1
22
numpy>=1.18.0
33
dash==2.3.1
44
catboost>=1.0.1
5-
category-encoders==2.1.0
5+
category-encoders>=2.2.2
66
dash-bootstrap-components==1.1.0
77
dash-core-components==2.0.0
88
dash-daq==0.5.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
extras['lightgbm'] = ['lightgbm>=2.3.0']
5353
extras['catboost'] = ['catboost>=1.0.1']
5454
extras['scikit-learn'] = ['scikit-learn>=0.23.0']
55-
extras['category_encoders'] = ['category_encoders==2.2.2']
55+
extras['category_encoders'] = ['category_encoders>=2.2.2']
5656
extras['acv'] = ['acv-exp==1.2.0']
5757
extras['lime'] = ['lime>=0.2.0.0']
5858

shapash/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
VERSION = (2, 0, 1)
1+
VERSION = (2, 0, 2)
22

33
__version__ = ".".join(map(str, VERSION))

shapash/explainer/consistency.py

Lines changed: 266 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,38 @@
1+
from category_encoders import OrdinalEncoder
2+
import copy
13
import itertools
24
import matplotlib.pyplot as plt
35
import numpy as np
46
import pandas as pd
7+
from plotly import graph_objs as go
8+
from plotly.offline import plot
9+
from plotly.subplots import make_subplots
510
from sklearn.manifold import MDS
11+
612
from shapash import SmartExplainer
13+
from shapash.style.style_utils import colors_loading, select_palette, define_style
714

815

916
class Consistency():
1017

18+
def __init__(self):
    """Set up the default plotting style of the instance.

    The first palette name returned by ``colors_loading()`` is stored in
    ``self._palette_name``; the style dictionary built from that palette
    is stored in ``self._style_dict``.
    """
    palette_names = list(colors_loading().keys())
    self._palette_name = palette_names[0]
    selected_palette = select_palette(colors_loading(), self._palette_name)
    self._style_dict = define_style(selected_palette)
21+
22+
def tuning_colorscale(self, values):
23+
"""Adapts the color scale to the distribution of points
24+
Parameters
25+
----------
26+
values: 1 column pd.DataFrame
27+
values ​​whose quantiles must be calculated
28+
"""
29+
desc_df = values.describe(percentiles=np.arange(0.1, 1, 0.1).tolist())
30+
min_pred, max_init = list(desc_df.loc[['min', 'max']].values)
31+
desc_pct_df = (desc_df.loc[~desc_df.index.isin(['count', 'mean', 'std'])] - min_pred) / \
32+
(max_init - min_pred)
33+
color_scale = list(map(list, (zip(desc_pct_df.values.flatten(), self._style_dict["init_contrib_colorscale"]))))
34+
return color_scale
35+
1136
def compile(self, x=None, model=None, preprocessing=None, contributions=None, methods=["shap", "acv", "lime"]):
1237
"""If not provided, compute contributions according to provided methods (default are shap, acv, lime).
1338
If provided, check whether they respect the correct format:
@@ -36,10 +61,12 @@ def compile(self, x=None, model=None, preprocessing=None, contributions=None, me
3661
methods : list
3762
Methods used to compute contributions, by default ["shap", "acv", "lime"]
3863
"""
64+
self.x = x
65+
self.preprocessing = preprocessing
3966
if contributions is None:
40-
if (x is None) or (model is None):
67+
if (self.x is None) or (model is None):
4168
raise ValueError('If no contributions are provided, parameters "x" and "model" must be defined')
42-
contributions = self.compute_contributions(x, model, methods, preprocessing)
69+
contributions = self.compute_contributions(self.x, model, methods, self.preprocessing)
4370
else:
4471
if not isinstance(contributions, dict):
4572
raise ValueError('Contributions must be a dictionary')
@@ -49,8 +76,6 @@ def compile(self, x=None, model=None, preprocessing=None, contributions=None, me
4976
self.check_consistency_contributions(self.weights)
5077
self.index = self.weights[0].index
5178

52-
self.weights = [weight.values for weight in self.weights]
53-
5479
def compute_contributions(self, x, model, methods, preprocessing):
5580
"""
5681
Compute contributions based on specified methods
@@ -139,12 +164,12 @@ def consistency_plot(self, selection=None, max_features=20):
139164
"""
140165
# Selection
141166
if selection is None:
142-
weights = self.weights
167+
weights = [weight.values for weight in self.weights]
143168
elif isinstance(selection, list):
144169
if len(selection) == 1:
145170
raise ValueError('Selection must include multiple points')
146171
else:
147-
weights = [weight[selection] for weight in self.weights]
172+
weights = [weight.values[selection] for weight in self.weights]
148173
else:
149174
raise ValueError('Parameter selection must be a list')
150175

@@ -458,3 +483,238 @@ def plot_examples(self, method_1, method_2, l2, index, backend_name_1, backend_n
458483
axes[n].set_yticks([])
459484

460485
return fig
486+
487+
def pairwise_consistency_plot(self, methods, selection=None,
488+
max_features=10, max_points=100, file_name=None, auto_open=False):
489+
"""The Pairwise_Consistency_plot compares the difference of 2 explainability methods across each feature and each data point,
490+
and plots the distribution of those differences.
491+
492+
This plot goes one step deeper than the consistency_plot which compares methods on a global level
493+
by expressing differences in terms of mean across the entire dataset.
494+
Not only we get an understanding of how differences are distributed across the dataset,
495+
but we can also identify whether there are patterns based on feature values,
496+
and understand when a method overestimates contributions compared to the other
497+
498+
Parameters
499+
----------
500+
methods : list
501+
List of explainbility methods to compare
502+
selection: list
503+
Contains list of index, subset of the input DataFrame that we use
504+
for the compute of consitency statistics, by default None
505+
max_features: int, optional
506+
Maximum number of displayed features, by default 10
507+
max_points : int, optional
508+
Maximum number of displayed datapoints per feature, by default 100
509+
file_name: string, optional
510+
Specify the save path of html files. If it is not provided, no file will be saved.
511+
auto_open: bool
512+
open automatically the plot, by default False
513+
514+
515+
Returns
516+
-------
517+
figure
518+
"""
519+
if self.x is None:
520+
raise ValueError('x must be defined in the compile to display the plot')
521+
if not isinstance(self.x, pd.DataFrame):
522+
raise ValueError('x must be a pandas DataFrame')
523+
if len(methods) != 2:
524+
raise ValueError('Choose 2 methods among methods of the contributions')
525+
526+
# Select contributions of input methods
527+
pair_indices = [self.methods.index(x) for x in methods]
528+
pair_weights = [self.weights[i] for i in pair_indices]
529+
530+
# Selection
531+
if selection is None:
532+
ind_max_points = self.x.sample(min(max_points, len(self.x))).index
533+
weights = [weight.iloc[ind_max_points] for weight in pair_weights]
534+
x = self.x.iloc[ind_max_points]
535+
elif isinstance(selection, list):
536+
if len(selection) == 1:
537+
raise ValueError('Selection must include multiple points')
538+
else:
539+
weights = [weight.iloc[selection] for weight in pair_weights]
540+
x = self.x.iloc[selection]
541+
else:
542+
raise ValueError('Parameter selection must be a list')
543+
544+
# Remove constant columns
545+
const_cols = x.loc[:, x.apply(pd.Series.nunique) == 1]
546+
x = x.drop(const_cols, axis=1)
547+
weights = [weight.drop(const_cols, axis=1) for weight in weights]
548+
549+
# Only keep features based on largest mean of absolute values
550+
mean_contributions = np.mean(np.abs(pd.concat(weights)))
551+
top_features = np.flip(mean_contributions.sort_values(ascending=False)[:max_features].keys())
552+
553+
fig = self.plot_pairwise_consistency(weights, x, top_features, methods, file_name, auto_open)
554+
555+
return fig
556+
557+
def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name, auto_open):
    """Plot the main graph displaying distances between methods across each feature and data point

    For every feature of ``top_features`` a violin of the differences
    ``weights[0] - weights[1]`` is drawn, overlaid with one jittered marker
    per data point colored by the (encoded) feature value.

    Parameters
    ----------
    weights : list
        List of 2 dataframes containing contributions for the selected points
    x : DataFrame
        Original input data filtered on selected points
    top_features : array
        Top features to display ordered by mean of absolute contributions across all the selected points
    methods : list
        List of explainability methods to compare
    file_name : string
        Specify the save path of html files. If it is not provided, no file will be saved.
    auto_open : bool
        open automatically the plot

    Returns
    -------
    figure
    """
    # Look for existing OrdinalEncoder. If none, create one for string columns
    if isinstance(self.preprocessing, OrdinalEncoder):
        # NOTE(review): x is presumably already encoded when an
        # OrdinalEncoder was given to compile - confirm against callers.
        encoder = self.preprocessing
    else:
        categorical_features = [col for col in x.columns if x[col].dtype == 'object']
        encoder = OrdinalEncoder(cols=categorical_features,
                                 handle_unknown='ignore',
                                 return_df=True).fit(x)
        x = encoder.transform(x)

    xaxis_title = "Difference of contributions between the 2 methods" \
                  + f"<span style='font-size: 12px;'><br />{methods[0]} - {methods[1]}</span>"
    yaxis_title = "Top features<span style='font-size: 12px;'><br />(Ordered by mean of absolute contributions)</span>"

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    n_points = len(x)

    # Plot the distribution
    for position, feature in enumerate(top_features):
        first_contrib = weights[0][feature]
        second_contrib = weights[1][feature]
        difference = first_contrib - second_contrib

        if feature in encoder.cols:
            # Encoded categorical feature: show the original label on hover
            mapping = encoder.mapping[encoder.cols.index(feature)]["mapping"]
            inverse_mapping = {code: label for label, code in mapping.to_dict().items()}
            displayed_values = x[feature].map(inverse_mapping)
        else:
            displayed_values = x[feature].round(3)

        hv_text = [
            f"<b>Feature value</b>: {value}<br><b>{methods[0]}</b>: {contrib_1}"
            f"<br><b>{methods[1]}</b>: {contrib_2}<br><b>Diff</b>: {diff}"
            for value, contrib_1, contrib_2, diff in zip(displayed_values,
                                                         first_contrib.round(2),
                                                         second_contrib.round(2),
                                                         difference.round(2))
        ]

        fig.add_trace(
            go.Violin(
                x=difference.values,
                name=feature,
                points=False,
                fillcolor="rgba(255, 0, 0, 0.1)",
                line={"color": "black", "width": 0.5},
                showlegend=False,
            ), secondary_y=False
        )

        fig.add_trace(
            go.Scatter(
                x=difference.values,
                # Vertical jitter around the feature row so markers do not overlap
                y=position + np.random.normal(0, 0.1, n_points),
                mode='markers',
                marker={"color": x[feature].values,
                        "colorscale": self.tuning_colorscale(x[feature]),
                        "opacity": 0.7},
                name=feature,
                text=n_points * [feature],
                hovertext=hv_text,
                hovertemplate="<b>%{text}</b><br><br>" +
                              "%{hovertext}<br>" +
                              "<extra></extra>",
                showlegend=False,
            ), secondary_y=True
        )

    # Dummy invisible plot to add the color scale
    lowest_value = x.min().min()
    highest_value = x.max().max()
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode="markers",
        marker=dict(
            size=1,
            # Scalar bounds, the same ones used for the tick positions below
            color=[lowest_value, highest_value],
            colorscale=self.tuning_colorscale(pd.Series(np.linspace(lowest_value, highest_value, 10))),
            colorbar=dict(thickness=20,
                          lenmode="pixels",
                          len=400,
                          yanchor="top",
                          y=1.1,
                          ypad=20,
                          title="Feature values",
                          tickvals=[lowest_value, highest_value],
                          ticktext=["Low", "High"]),
            showscale=True,
        ),
        hoverinfo="none",
        showlegend=False,
    )

    fig.add_trace(colorbar_trace)

    self._update_pairwise_consistency_fig(fig=fig,
                                          top_features=top_features,
                                          xaxis_title=xaxis_title,
                                          yaxis_title=yaxis_title,
                                          file_name=file_name,
                                          auto_open=auto_open)

    return fig
677+
678+
def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis_title, file_name, auto_open):
    """Function used for the pairwise_consistency_plot to update the layout of the plotly figure.

    The figure is modified in place; nothing is returned.

    Parameters
    ----------
    fig : figure
        Plotly figure
    top_features : array
        Top features to display ordered by mean of absolute contributions across all the selected points
    xaxis_title : str
        Title for the x-axis
    yaxis_title : str
        Title for the y-axis
    file_name : string
        Specify the save path of html files. If it is not provided, no file will be saved.
    auto_open : bool
        open automatically the plot
    """
    headline = (
        "Pairwise comparison of Consistency:"
        "<span style='font-size: 16px;'>"
        "<br />How are differences in contributions distributed across features?</span>"
    )

    # Work on copies so the shared style dictionary is never mutated
    title_style = copy.deepcopy(self._style_dict["dict_title_stability"])
    x_style = copy.deepcopy(self._style_dict["dict_xaxis"])
    y_style = copy.deepcopy(self._style_dict["dict_yaxis"])
    x_style['text'] = xaxis_title
    y_style['text'] = yaxis_title
    title_style['text'] = headline

    n_features = len(top_features)
    lower_bound, upper_bound = -0.7, n_features - 0.3

    # Only the primary y-axis shows the feature names
    fig.layout.yaxis.update(showticklabels=True)
    fig.layout.yaxis2.update(showticklabels=False)
    fig.update_layout(
        template="none",
        title=title_style,
        xaxis_title=x_style,
        yaxis_title=y_style,
        # Both y-axes share the same range so violins and markers line up
        yaxis={"range": [lower_bound, upper_bound]},
        yaxis2={"range": [lower_bound, upper_bound]},
        height=max(500, 40 * n_features),
    )

    fig.update_yaxes(automargin=True, zeroline=False)
    fig.update_xaxes(automargin=True)

    if file_name is None:
        return
    plot(fig, filename=file_name, auto_open=auto_open)

0 commit comments

Comments
 (0)