1+ from typing import Optional , Union
2+
13import numpy as np
24import pandas as pd
35from plotly import graph_objs as go
46from plotly .offline import plot
57
8+ from shapash .style .style_utils import define_style , get_palette
69from shapash .utils .sampling import subset_sampling
710from shapash .utils .utils import adjust_title_height , truncate_str , tuning_colorscale
811
@@ -356,7 +359,6 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
356359 fig = go .Figure ()
357360
358361 subtitle = None
359- prediction_error = prediction_error
360362 if prediction_error is not None :
361363 if (y_target == 0 ).any ().iloc [0 ]:
362364 subtitle = "Prediction Error = abs(True Values - Predicted Values)"
@@ -458,8 +460,8 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
458460 "y" : 1.1 ,
459461 }
460462 range_axis = [
461- min (min (y_target_values ), min (y_pred_flatten )),
462- max (max (y_target_values ), max (y_pred_flatten )),
463+ min (y_target_values . min (), y_pred_flatten . min ()),
464+ max (y_target_values . max (), y_pred_flatten . max ()),
463465 ]
464466 fig .update_xaxes (range = range_axis )
465467 fig .update_yaxes (range = range_axis )
@@ -479,3 +481,152 @@ def _prediction_regression_plot(y_target, y_pred, prediction_error, list_ind, st
479481 )
480482
481483 return fig , subtitle
484+
485+
def _shorten_labels(labels: list) -> list:
    """
    Build display strings for axis labels, eliding the middle of long ones.

    Parameters
    ----------
    labels : list
        Class labels of any type. Each one is converted with ``str`` so that
        non-string labels (floats, negative integers, booleans...) are supported.

    Returns
    -------
    list
        One string per label. Labels longer than ``2 * k + 3`` characters keep
        their first ``k + k // 2`` and last ``k - k // 2`` characters around
        ``"..."``, where ``k`` is 10 for fewer than 6 labels and 6 otherwise.
        If shortening would make two labels identical, the full strings are
        returned instead.
    """
    k = 10 if len(labels) < 6 else 6
    head, tail = k + k // 2, k - k // 2
    full = [str(label) for label in labels]
    # Slice explicitly rather than using str.replace on the middle substring:
    # replace() would also rewrite any other occurrence of that substring.
    short = [f"{text[:head]}...{text[-tail:]}" if len(text) > 2 * k + 3 else text for text in full]
    # Identical category names would be drawn at the same axis position, so
    # keep the full labels when shortening creates a collision.
    if len(set(short)) != len(short):
        return full
    return short


def plot_confusion_matrix(
    y_true: Union[np.ndarray, list],
    y_pred: Union[np.ndarray, list],
    colors_dict: Optional[dict] = None,
    width: int = 700,
    height: int = 500,
    palette_name: str = "default",
    file_name=None,
    auto_open=False,
) -> go.Figure:
    """
    Creates an interactive confusion matrix using Plotly.

    Rows are the actual classes and columns the predicted classes. Both axes
    use the sorted union of the values found in ``y_true`` and ``y_pred``, so
    the matrix is always square.

    Parameters
    ----------
    y_true : array-like
        Ground truth (correct) target values.
    y_pred : array-like
        Estimated targets as returned by a classifier.
    colors_dict : dict, optional
        Custom style entries. The keys read by this function are
        "dict_title", "init_confusion_matrix_colorscale", "dict_xaxis" and
        "dict_yaxis"; any of them missing is taken from the palette style.
    width : int, optional
        The width of the figure in pixels.
    height : int, optional
        The height of the figure in pixels.
    palette_name : str, optional
        The color palette to use for the heatmap.
    file_name: string, optional
        Specify the save path of html files. If None, no file will be saved.
    auto_open: bool, optional
        Automatically open the plot.

    Returns
    -------
    go.Figure
        The generated confusion matrix as a Plotly figure.
    """
    # Create a confusion matrix as a DataFrame
    labels = sorted(set(y_true).union(set(y_pred)))
    se_y_true = pd.Series(y_true, name="Actual")
    se_y_pred = pd.Series(y_pred, name="Predicted")
    df_cm = pd.crosstab(se_y_true, se_y_pred).reindex(index=labels, columns=labels, fill_value=0)

    # The palette style is only needed when the caller did not supply every key.
    required_keys = ("dict_title", "init_confusion_matrix_colorscale", "dict_xaxis", "dict_yaxis")
    if colors_dict and all(key in colors_dict for key in required_keys):
        style_dict = dict(colors_dict)
    else:
        style_dict = define_style(get_palette(palette_name))
        if colors_dict:
            style_dict.update(colors_dict)

    # Spread the palette colors evenly over the [0, 1] colorscale range
    init_colorscale = style_dict["init_confusion_matrix_colorscale"]
    linspace = np.linspace(0, 1, len(init_colorscale))
    col_scale = list(zip(linspace, init_colorscale))

    # Convert the DataFrame to a NumPy array (rows = actual, columns = predicted)
    x_labels = list(df_cm.columns)
    y_labels = list(df_cm.index)
    z = df_cm.loc[y_labels, x_labels].values

    title = "Confusion Matrix"
    dict_t = style_dict["dict_title"] | {"text": title, "y": adjust_title_height(height)}
    dict_xaxis = style_dict["dict_xaxis"] | {"text": se_y_pred.name}
    dict_yaxis = style_dict["dict_yaxis"] | {"text": se_y_true.name}

    # Determine if labels are numeric. isdecimal() accepts exactly the digits
    # that int() can parse, which the tick values below rely on.
    x_numeric = all(str(label).isdecimal() for label in x_labels)
    y_numeric = all(str(label).isdecimal() for label in y_labels)

    # Hover text is built before shortening so it always shows the full labels
    hv_text = [
        [f"Actual: { y } <br>Predicted: { x } <br>Count: { value } " for x, value in zip(x_labels, row)]
        for y, row in zip(y_labels, z)
    ]

    # Shorten labels that exceed the threshold
    if not x_numeric:
        x_labels = _shorten_labels(x_labels)
    if not y_numeric:
        y_labels = _shorten_labels(y_labels)

    # Create the heatmap using go.Heatmap
    heatmap = go.Heatmap(
        z=z,
        x=x_labels,
        y=y_labels,
        colorscale=col_scale,
        hovertext=hv_text,
        hovertemplate="%{hovertext}<extra></extra>",
        showscale=True,
    )

    fig = go.Figure(data=[heatmap])

    # Add annotations for each cell; counts in the upper half of the range are
    # written in white, the others in black.
    half_max = z.max() / 2 if z.size else 0
    annotations = [
        dict(
            x=x_label,
            y=y_label,
            text=str(z[i][j]),
            showarrow=False,
            font=dict(color="black" if z[i][j] < half_max else "white"),
        )
        for i, y_label in enumerate(y_labels)
        for j, x_label in enumerate(x_labels)
    ]

    # Update layout
    fig.update_layout(
        annotations=annotations,
        title=dict_t,
        xaxis=dict(
            title=dict_xaxis,
            tickangle=45,
            tickmode="array" if x_numeric else "linear",
            tickvals=[int(label) for label in x_labels] if x_numeric else None,
            ticktext=x_labels if x_numeric else None,
        ),
        yaxis=dict(
            title=dict_yaxis,
            autorange="reversed",  # Reverse y-axis to match conventional confusion matrix
            tickmode="array" if y_numeric else "linear",
            tickvals=[int(label) for label in y_labels] if y_numeric else None,
            ticktext=y_labels if y_numeric else None,
        ),
        width=width,
        height=height,
        margin=dict(l=150, r=20, t=100, b=70),
    )

    if file_name:
        plot(fig, filename=file_name, auto_open=auto_open)

    return fig
0 commit comments