Skip to content

Commit 18ffab6

Browse files
Merge pull request #612 from guillaume-vignal/feature/refacto_report_plot
Enhance Plot Functionality and Consistency for Additional Visualizations
2 parents 455fdc3 + 2264563 commit 18ffab6

File tree

15 files changed

+1878
-311
lines changed

15 files changed

+1878
-311
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ report = [
5454
"nbconvert>=6.0.7",
5555
"papermill>=2.0.0",
5656
"jupyter-client>=7.4.0",
57-
"seaborn==0.12.2",
5857
"notebook",
5958
"Jinja2>=2.11.0",
6059
"phik",

shapash/explainer/smart_plotter.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import math
66
import random
7+
from typing import Optional
78

89
import numpy as np
910
import pandas as pd
@@ -16,11 +17,12 @@
1617
from shapash.plots.plot_bar_chart import plot_bar_chart
1718
from shapash.plots.plot_contribution import plot_scatter, plot_violin
1819
from shapash.plots.plot_correlations import plot_correlations
20+
from shapash.plots.plot_evaluation_metrics import plot_confusion_matrix, plot_scatter_prediction
1921
from shapash.plots.plot_feature_importance import plot_feature_importance
2022
from shapash.plots.plot_interactions import plot_interactions_scatter, plot_interactions_violin, update_interactions_fig
2123
from shapash.plots.plot_line_comparison import plot_line_comparison
22-
from shapash.plots.plot_scatter_prediction import plot_scatter_prediction
2324
from shapash.plots.plot_stability import plot_amplitude_vs_stability, plot_stability_distribution
25+
from shapash.plots.plot_univariate import plot_distribution
2426
from shapash.style.style_utils import colors_loading, define_style, select_palette
2527
from shapash.utils.sampling import subset_sampling
2628
from shapash.utils.utils import (
@@ -1852,3 +1854,119 @@ def scatter_plot_prediction(
18521854
)
18531855

18541856
return fig
1857+
1858+
def confusion_matrix_plot(
    self,
    width: int = 700,
    height: int = 500,
    file_name=None,
    auto_open=False,
):
    """
    Return a Plotly confusion matrix built from the explainer's targets and predictions.

    True labels are read from the first column of ``self._explainer.y_target`` and
    predicted labels from the first column of ``self._explainer.y_pred``. When
    ``self._explainer.label_dict`` is set, both series are mapped through it before
    being handed to ``plot_confusion_matrix`` together with ``self._style_dict``.

    Parameters
    ----------
    width : int, optional, default=700
        The width of the generated figure, in pixels.
    height : int, optional, default=500
        The height of the generated figure, in pixels.
    file_name : str, optional
        Save path of the html file. If None, no file will be saved.
    auto_open : bool, optional, default=False
        Automatically open the plot.

    Returns
    -------
    go.Figure
        The figure returned by ``plot_confusion_matrix``.

    Raises
    ------
    ValueError
        If the explainer's case is not "classification", or if ``y_target`` or
        ``y_pred`` is missing on the explainer.
    """
    explainer = self._explainer

    # Reject every non-classification case up front so an unexpected case value
    # cannot fall through to undefined y_true / y_pred variables.
    if explainer._case != "classification":
        raise ValueError("Confusion matrix is only available for classification case")

    if explainer.y_target is None or explainer.y_pred is None:
        raise ValueError("Confusion matrix requires both y_target and y_pred to be set on the explainer")

    y_true = explainer.y_target.iloc[:, 0]
    y_pred = explainer.y_pred.iloc[:, 0]

    label_dict = explainer.label_dict
    if label_dict is not None:
        y_true = y_true.map(label_dict)
        y_pred = y_pred.map(label_dict)

    return plot_confusion_matrix(
        y_true=y_true,
        y_pred=y_pred,
        colors_dict=self._style_dict,
        width=width,
        height=height,
        file_name=file_name,
        auto_open=auto_open,
    )
1907+
1908+
def distribution_plot(
    self,
    col: str,
    hue: Optional[str] = None,
    width: int = 700,
    height: int = 500,
    nb_cat_max: int = 7,
    nb_hue_max: int = 7,
    file_name=None,
    auto_open=False,
) -> go.Figure:
    """
    Plot the univariate distribution of one column of the explainer's data.

    The plotted frame is ``self._explainer.x_init``; when ``self._explainer.y_target``
    is set, it is concatenated column-wise to the features first. The frame, ``col``,
    ``hue`` and the remaining options are forwarded to ``plot_distribution`` along
    with ``self._style_dict``.

    Parameters
    ----------
    col : str
        The name of the column of interest whose distribution is to be visualized.
    hue : Optional[str], optional
        The name of the column used to differentiate between groups.
    width : int, optional, default=700
        The width of the generated figure, in pixels.
    height : int, optional, default=500
        The height of the generated figure, in pixels.
    nb_cat_max : int, optional, default=7
        Maximum number of categories to display. Categories beyond this limit
        are grouped into a new 'Other' category (only for categorical features).
    nb_hue_max : int, optional, default=7
        Maximum number of hue categories to display. Categories beyond this limit
        are grouped into a new 'Other' category.
    file_name : str, optional
        Path to save the plot as an HTML file. If None, the plot will not be saved.
    auto_open : bool, optional, default=False
        If True, the plot will automatically open in a web browser after being generated.

    Returns
    -------
    go.Figure
        The figure returned by ``plot_distribution``.
    """
    explainer = self._explainer
    target = explainer.y_target

    # Features alone when no target is stored; otherwise features and target side by side.
    if target is None:
        dataset = explainer.x_init
    else:
        dataset = pd.concat([explainer.x_init, target], axis=1)

    plot_options = {
        "hue": hue,
        "colors_dict": self._style_dict,
        "width": width,
        "height": height,
        "nb_cat_max": nb_cat_max,
        "nb_hue_max": nb_hue_max,
        "file_name": file_name,
        "auto_open": auto_open,
    }
    return plot_distribution(dataset, col, **plot_options)

shapash/plots/plot_correlations.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import numpy as np
24
import pandas as pd
35
import scipy.cluster.hierarchy as sch
@@ -6,12 +8,14 @@
68
from plotly.subplots import make_subplots
79

810
from shapash.manipulation.summarize import compute_corr
11+
from shapash.style.style_utils import define_style, get_palette
912
from shapash.utils.utils import adjust_title_height, compute_top_correlations_features, suffix_duplicates
1013

1114

1215
def plot_correlations(
1316
df,
14-
style_dict,
17+
style_dict: Optional[dict] = None,
18+
palette_name: str = "default",
1519
features_dict=None,
1620
optimized=False,
1721
max_features=20,
@@ -35,6 +39,8 @@ def plot_correlations(
3539
DataFrame for which we want to compute correlations.
3640
style_dict: dict
3741
the different styles used in the different outputs of Shapash
42+
palette_name : str, optional, default="default"
43+
The name of the color palette to be used if `style_dict` is not provided or is missing required keys.
3844
features_dict: dict (default: None)
3945
Dictionary mapping technical feature names to domain names.
4046
optimized : boolean, optional
@@ -123,6 +129,15 @@ def prepare_corr_matrix(df_subset):
123129
list_features_shorten = suffix_duplicates(list_features_shorten)
124130
return corr, list_features, list_features_shorten
125131

132+
if style_dict:
133+
style_dict_default = {}
134+
keys = ["dict_title", "init_contrib_colorscale"]
135+
if any(key not in style_dict for key in keys):
136+
style_dict_default = define_style(get_palette(palette_name))
137+
style_dict_default.update(style_dict)
138+
else:
139+
style_dict_default = define_style(get_palette(palette_name))
140+
126141
if features_dict is None:
127142
features_dict = {}
128143

@@ -203,10 +218,10 @@ def prepare_corr_matrix(df_subset):
203218
if len(list_features) < len(df.drop(features_to_hide, axis=1).columns):
204219
subtitle = f"Top {len(list_features)} correlations"
205220
title += f"<span style='font-size: 12px;'><br />{subtitle}</span>"
206-
dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
221+
dict_t = style_dict_default["dict_title"] | {"text": title, "y": adjust_title_height(height)}
207222

208223
fig.update_layout(
209-
coloraxis=dict(colorscale=["rgb(255, 255, 255)"] + style_dict["init_contrib_colorscale"][5:-1]),
224+
coloraxis=dict(colorscale=["rgb(255, 255, 255)"] + style_dict_default["init_contrib_colorscale"][5:-1]),
210225
showlegend=True,
211226
title=dict_t,
212227
width=width,
Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from typing import Optional, Union
2+
13
import numpy as np
24
import pandas as pd
35
from plotly import graph_objs as go
46
from plotly.offline import plot
57

8+
from shapash.style.style_utils import define_style, get_palette
69
from shapash.utils.sampling import subset_sampling
710
from shapash.utils.utils import adjust_title_height, truncate_str, tuning_colorscale
811

@@ -356,7 +359,6 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
356359
fig = go.Figure()
357360

358361
subtitle = None
359-
prediction_error = prediction_error
360362
if prediction_error is not None:
361363
if (y_target == 0).any().iloc[0]:
362364
subtitle = "Prediction Error = abs(True Values - Predicted Values)"
@@ -458,8 +460,8 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
458460
"y": 1.1,
459461
}
460462
range_axis = [
461-
min(min(y_target_values), min(y_pred_flatten)),
462-
max(max(y_target_values), max(y_pred_flatten)),
463+
min(y_target_values.min(), y_pred_flatten.min()),
464+
max(y_target_values.max(), y_pred_flatten.max()),
463465
]
464466
fig.update_xaxes(range=range_axis)
465467
fig.update_yaxes(range=range_axis)
@@ -479,3 +481,152 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
479481
)
480482

481483
return fig, subtitle
484+
485+
486+
def _shorten_label(label, k: int) -> str:
    """Return ``label`` as text, eliding its middle when longer than ``2 * k + 3`` characters."""
    text = str(label)
    if len(text) <= 2 * k + 3:
        return text
    head = k + k // 2
    tail = k - k // 2
    # Slice explicitly instead of str.replace so that only the middle span is
    # removed, even when the same substring also occurs elsewhere in the label.
    return f"{text[:head]}...{text[-tail:]}"


def _shorten_labels(labels: list) -> list:
    """Shorten every label, keeping more characters when there are fewer than 6 labels."""
    k = 10 if len(labels) < 6 else 6
    return [_shorten_label(label, k) for label in labels]


def plot_confusion_matrix(
    y_true: Union[np.ndarray, list, pd.Series],
    y_pred: Union[np.ndarray, list, pd.Series],
    colors_dict: Optional[dict] = None,
    width: int = 700,
    height: int = 500,
    palette_name: str = "default",
    file_name=None,
    auto_open=False,
) -> go.Figure:
    """
    Creates an interactive confusion matrix using Plotly.

    The matrix is a cross-tabulation of ``y_true`` (rows, "Actual") against
    ``y_pred`` (columns, "Predicted"), reindexed on the sorted union of the labels
    found in both inputs so that it is always square; missing combinations count 0.

    Parameters
    ----------
    y_true : array-like
        Ground truth (correct) target values.
    y_pred : array-like
        Estimated targets as returned by a classifier.
    colors_dict : dict, optional
        Style entries overriding the palette defaults. The keys read here are
        "dict_title", "init_confusion_matrix_colorscale", "dict_xaxis" and
        "dict_yaxis"; if any is missing, the style derived from ``palette_name``
        supplies it.
    width : int, optional
        The width of the figure in pixels.
    height : int, optional
        The height of the figure in pixels.
    palette_name : str, optional
        The color palette used when ``colors_dict`` is empty or incomplete.
    file_name: string, optional
        Specify the save path of html files. If None, no file will be saved.
    auto_open: bool, optional
        Automatically open the plot.

    Returns
    -------
    go.Figure
        The generated confusion matrix as a Plotly figure.
    """
    # Create a square confusion matrix as a DataFrame
    labels = sorted(set(y_true).union(set(y_pred)))
    se_y_true = pd.Series(y_true, name="Actual")
    se_y_pred = pd.Series(y_pred, name="Predicted")
    df_cm = pd.crosstab(se_y_true, se_y_pred).reindex(index=labels, columns=labels, fill_value=0)

    # Resolve the style: user entries take precedence, palette defaults fill the gaps
    required_keys = ("dict_title", "init_confusion_matrix_colorscale", "dict_xaxis", "dict_yaxis")
    if colors_dict and all(key in colors_dict for key in required_keys):
        style_dict = dict(colors_dict)
    else:
        style_dict = define_style(get_palette(palette_name))
        if colors_dict:
            style_dict.update(colors_dict)

    init_colorscale = style_dict["init_confusion_matrix_colorscale"]
    linspace = np.linspace(0, 1, len(init_colorscale))
    col_scale = list(zip(linspace, init_colorscale))

    # Rows are actual labels (index), columns are predicted labels
    x_labels = list(df_cm.columns)
    y_labels = list(df_cm.index)
    z = df_cm.loc[y_labels, x_labels].values

    title = "Confusion Matrix"
    dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
    dict_xaxis = style_dict["dict_xaxis"] | {"text": se_y_pred.name}
    dict_yaxis = style_dict["dict_yaxis"] | {"text": se_y_true.name}

    # Determine if labels are numeric
    x_numeric = all(str(label).isdigit() for label in x_labels)
    y_numeric = all(str(label).isdigit() for label in y_labels)

    # Hover text is built before shortening so that it shows the full labels
    hv_text = [
        [f"Actual: {y}<br>Predicted: {x}<br>Count: {value}" for x, value in zip(x_labels, row)]
        for y, row in zip(y_labels, z)
    ]

    # Shorten long labels; labels are converted to text first so that
    # non-string labels (e.g. negative integers, floats, booleans) do not fail
    if not x_numeric:
        x_labels = _shorten_labels(x_labels)
    if not y_numeric:
        y_labels = _shorten_labels(y_labels)

    # Create the heatmap using go.Heatmap
    heatmap = go.Heatmap(
        z=z,
        x=x_labels,
        y=y_labels,
        colorscale=col_scale,
        hovertext=hv_text,
        hovertemplate="%{hovertext}<extra></extra>",
        showscale=True,
    )

    fig = go.Figure(data=[heatmap])

    # Add annotations for each cell: dark text on the light half of the scale,
    # white text on the dark half. The threshold is computed once; the size
    # guard keeps an empty matrix from raising on max().
    half_max = z.max() / 2 if z.size else 0
    annotations = [
        dict(
            x=x_label,
            y=y_label,
            text=str(z[i][j]),
            showarrow=False,
            font=dict(color="black" if z[i][j] < half_max else "white"),
        )
        for i, y_label in enumerate(y_labels)
        for j, x_label in enumerate(x_labels)
    ]

    # Update layout
    fig.update_layout(
        annotations=annotations,
        title=dict_t,
        xaxis=dict(
            title=dict_xaxis,
            tickangle=45,
            tickmode="array" if x_numeric else "linear",
            tickvals=[int(label) for label in x_labels] if x_numeric else None,
            ticktext=x_labels if x_numeric else None,
        ),
        yaxis=dict(
            title=dict_yaxis,
            autorange="reversed",  # Reverse y-axis to match conventional confusion matrix
            tickmode="array" if y_numeric else "linear",
            tickvals=[int(label) for label in y_labels] if y_numeric else None,
            ticktext=y_labels if y_numeric else None,
        ),
        width=width,
        height=height,
        margin=dict(l=150, r=20, t=100, b=70),
    )

    if file_name:
        plot(fig, filename=file_name, auto_open=auto_open)

    return fig

0 commit comments

Comments
 (0)