@@ -1026,6 +1026,8 @@ def explain_global(self, name=None):
10261026
10271027 # Add per feature graph
10281028 data_dicts = []
1029+ feature_list = []
1030+ density_list = []
10291031 for attribute_set_index , attribute_set in enumerate (self .attribute_sets_ ):
10301032 model_graph = self .attribute_set_models_ [attribute_set_index ]
10311033
@@ -1038,21 +1040,37 @@ def explain_global(self, name=None):
10381040 # bin_counts = self.preprocessor_.get_bin_counts(
10391041 # attribute_indexes[0]
10401042 # )
1043+ scores = list (model_graph )
1044+ upper_bounds = list (model_graph + errors )
1045+ lower_bounds = list (model_graph - errors )
1046+ density_dict = {
1047+ "names" : self .preprocessor_ .get_hist_edges (
1048+ attribute_indexes [0 ]
1049+ ),
1050+ "scores" : self .preprocessor_ .get_hist_counts (
1051+ attribute_indexes [0 ]
1052+ ),
1053+ }
1054+
1055+ feature_dict = {
1056+ "type" : "univariate" ,
1057+ "names" : bin_labels ,
1058+ "scores" : scores ,
1059+ "scores_range" : bounds ,
1060+ "upper_bounds" : upper_bounds ,
1061+ "lower_bounds" : lower_bounds ,
1062+ }
1063+ feature_list .append (feature_dict )
1064+ density_list .append (density_dict )
1065+
10411066 data_dict = {
10421067 "type" : "univariate" ,
10431068 "names" : bin_labels ,
1044- "scores" : list ( model_graph ) ,
1069+ "scores" : scores ,
10451070 "scores_range" : bounds ,
1046- "upper_bounds" : list (model_graph + errors ),
1047- "lower_bounds" : list (model_graph - errors ),
1048- "density" : {
1049- "names" : self .preprocessor_ .get_hist_edges (
1050- attribute_indexes [0 ]
1051- ),
1052- "scores" : self .preprocessor_ .get_hist_counts (
1053- attribute_indexes [0 ]
1054- ),
1055- },
1071+ "upper_bounds" : upper_bounds ,
1072+ "lower_bounds" : lower_bounds ,
1073+ "density" : density_dict ,
10561074 }
10571075 data_dicts .append (data_dict )
10581076 elif len (attribute_indexes ) == 2 :
@@ -1062,6 +1080,17 @@ def explain_global(self, name=None):
10621080 bin_labels_right = self .preprocessor_ .get_bin_labels (
10631081 attribute_indexes [1 ]
10641082 )
1083+
1084+ feature_dict = {
1085+ "type" : "pairwise" ,
1086+ "left_names" : bin_labels_left ,
1087+ "right_names" : bin_labels_right ,
1088+ "scores" : model_graph ,
1089+ "scores_range" : bounds ,
1090+ }
1091+ feature_list .append (feature_dict )
1092+ density_list .append ({})
1093+
10651094 data_dict = {
10661095 "type" : "pairwise" ,
10671096 "left_names" : bin_labels_left ,
@@ -1078,7 +1107,20 @@ def explain_global(self, name=None):
10781107 "names" : self .feature_names ,
10791108 "scores" : self .mean_abs_scores_ ,
10801109 }
1081- internal_obj = {"overall" : overall_dict , "specific" : data_dicts }
1110+ internal_obj = {"overall" : overall_dict , "specific" : data_dicts , "mli" : [
1111+ {
1112+ "explanation_type" : "ebm_global" ,
1113+ "value" : {
1114+ "feature_list" : feature_list
1115+ }
1116+ },
1117+ {
1118+ "explanation_type" : "density" ,
1119+ "value" : {
1120+ "density" : density_list
1121+ }
1122+ }
1123+ ]}
10821124
10831125 return EBMExplanation (
10841126 "global" ,
@@ -1134,12 +1176,31 @@ def explain_local(self, X, y=None, name=None):
11341176 else :
11351177 scores = EBMUtils .regressor_predict (instances , self )
11361178
1179+ perf_list = []
11371180 for row_idx in range (n_rows ):
1138- data_dicts [row_idx ]["perf" ] = perf_dict (y , scores , row_idx )
1181+ perf = perf_dict (y , scores , row_idx )
1182+ perf_list .append (perf )
1183+ data_dicts [row_idx ]["perf" ] = perf
11391184
11401185 selector = gen_local_selector (instances , y , scores )
11411186
1142- internal_obj = {"overall" : None , "specific" : data_dicts }
1187+ internal_obj = {"overall" : None , "specific" : data_dicts , "mli" : [
1188+ {
1189+ "explanation_type" : "ebm_local" ,
1190+ "value" : {
1191+ "scores" : self .attribute_set_models_ ,
1192+ "intercept" : self .intercept_ ,
1193+ "perf" : perf_list ,
1194+ },
1195+ }
1196+ ],
1197+ }
1198+ internal_obj ["mli" ].append (
1199+ {
1200+ "explanation_type" : "evaluation_dataset" ,
1201+ "value" : {"dataset_x" : X , "dataset_y" : y },
1202+ }
1203+ )
11431204
11441205 return EBMExplanation (
11451206 "local" ,
0 commit comments