1717from shapash .backend import ShapBackend
1818from shapash .explainer .multi_decorator import MultiDecorator
1919from shapash .explainer .smart_state import SmartState
20+ from shapash .plots .plot_bar_chart import plot_bar_chart
21+ from shapash .plots .plot_feature_importance import _plot_features_import
2022from shapash .plots .plot_line_comparison import plot_line_comparison
2123from shapash .style .style_utils import get_palette
2224from shapash .utils .check import check_model
@@ -660,7 +662,7 @@ def test_plot_bar_chart_1(self):
660662 )
661663 expected_output_fig = go .Figure (data = bars , layout = go .Layout (yaxis = dict (type = "category" )))
662664 self .smart_explainer ._case = "regression"
663- fig_output = self . smart_explainer . plot . _plot_bar_chart ("ind" , var_dict , x_val , contributions )
665+ fig_output = plot_bar_chart ("ind" , var_dict , x_val , contributions , self . smart_explainer . plot . _style_dict )
664666 for part in list (zip (fig_output .data , expected_output_fig .data )):
665667 assert part [0 ].x == part [1 ].x
666668 assert part [0 ].y == part [1 ].y
@@ -683,7 +685,7 @@ def test_plot_bar_chart_2(self):
683685 expected_output_fig = go .Figure (data = bars , layout = go .Layout (yaxis = dict (type = "category" )))
684686
685687 self .smart_explainer ._case = "regression"
686- fig_output = self . smart_explainer . plot . _plot_bar_chart ("ind" , var_dict , x_val , contributions )
688+ fig_output = plot_bar_chart ("ind" , var_dict , x_val , contributions , self . smart_explainer . plot . _style_dict )
687689 for part in list (zip (fig_output .data , expected_output_fig .data )):
688690 assert part [0 ].x == part [1 ].x
689691 assert part [0 ].y == part [1 ].y
@@ -1126,8 +1128,9 @@ def test_plot_features_import_1(self):
11261128 """
11271129 Unit test plot features import 1
11281130 """
1131+ xpl = self .smart_explainer
11291132 serie1 = pd .Series ([0.131 , 0.51 ], index = ["col1" , "col2" ])
1130- output = self . smart_explainer . plot . _plot_features_import (serie1 )
1133+ output = _plot_features_import (serie1 , xpl . plot . _style_dict , {} )
11311134 data = go .Bar (x = serie1 , y = serie1 .index , name = "Global" , orientation = "h" )
11321135
11331136 expected_output = go .Figure (data = data )
@@ -1140,9 +1143,10 @@ def test_plot_features_import_2(self):
11401143 """
11411144 Unit test plot features import 2
11421145 """
1146+ xpl = self .smart_explainer
11431147 serie1 = pd .Series ([0.131 , 0.51 ], index = ["col1" , "col2" ])
11441148 serie2 = pd .Series ([0.33 , 0.11 ], index = ["col1" , "col2" ])
1145- output = self . smart_explainer . plot . _plot_features_import (serie1 , serie2 )
1149+ output = _plot_features_import (serie1 , xpl . plot . _style_dict , {}, feature_imp2 = serie2 )
11461150 data1 = go .Bar (x = serie1 , y = serie1 .index , name = "Global" , orientation = "h" )
11471151 data2 = go .Bar (x = serie2 , y = serie2 .index , name = "Subset" , orientation = "h" )
11481152 expected_output = go .Figure (data = [data2 , data1 ])
0 commit comments