77# imports from sklearn
88from sklearn .model_selection import train_test_split
99from sklearn .ensemble import RandomForestRegressor
10- from sklearn .linear_model import ElasticNetCV
10+ from sklearn .linear_model import LinearRegression , ElasticNetCV
1111from local_mdi import local_mdi_score
1212
1313# timing imports
@@ -28,7 +28,7 @@ def simulate_data(rho, pve, seed):
2828
2929 np .random .seed (seed )
3030
31- n = 500 # number of samples
31+ n = 250 # number of samples
3232 p1 = 50 # number of correlated features
3333 p2 = 50 # number of uncorrelated features
3434
@@ -53,29 +53,23 @@ def simulate_data(rho, pve, seed):
5353 X = np .random .multivariate_normal (mu , Sigma , size = n )
5454
5555 y = partial_linear_lss_model (X = X , s = 2 , m = 3 , r = 2 , tau = 0 , beta = 1 , heritability = pve )
56-
56+
5757 return X , y
5858
59- def split_data (X , y , test_size , seed ):
60- # split data into train and test sets
61- X_train , X_test , y_train , y_test = train_test_split (X , y ,
62- test_size = test_size ,
63- random_state = seed )
64- return X_train , X_test , y_train , y_test
65-
6659def fit_models (X_train , y_train ):
6760
68- rf = RandomForestRegressor (n_estimators = 100 , min_samples_leaf = 5 ,
69- max_features = 0.33 , random_state = 42 )
61+ # fit rf
62+ rf = RandomForestRegressor (n_estimators = 100 , min_samples_leaf = 5 ,
63+ max_features = 0.33 , random_state = 42 )
7064 rf .fit (X_train , y_train )
71-
72- # elastic net rf+
65+
66+ # fit rf+
7367 rf_plus_elastic = RandomForestPlusRegressor (rf_model = rf ,
74- prediction_model = ElasticNetCV (cv = 3 ,
75- l1_ratio = [0.1 ,0.5 ,0.99 ],
76- max_iter = 2000 ,random_state = 42 ))
68+ prediction_model = ElasticNetCV (cv = 3 ,
69+ l1_ratio = [0.1 ,0.5 ,0.99 ],
70+ max_iter = 2000 , random_state = 42 ))
7771 rf_plus_elastic .fit (X_train , y_train )
78-
72+
7973 return rf , rf_plus_elastic
8074
8175def get_shap (X , shap_explainer ):
@@ -118,10 +112,10 @@ def get_lime(X: np.ndarray, rf):
118112
119113 return lime_values , lime_rankings
120114
121- def get_lmdi_plus ( X , lmdi_plus_explainer , ranking ):
115+ def get_lmdi ( X , y , lmdi_plus_explainer , ranking ):
122116
123117 # get feature importances
124- lmdi_plus = lmdi_plus_explainer .get_lmdi_plus_scores (X , ranking = ranking )
118+ lmdi_plus = lmdi_plus_explainer .get_lmdi_plus_scores (X , y , ranking = ranking )
125119
126120 lmdi_plus_rankings = np .argsort (- np .abs (lmdi_plus ), axis = 1 )
127121
@@ -149,8 +143,7 @@ def get_lmdi_plus(X, lmdi_plus_explainer, ranking):
149143 pve = args_dict ['pve' ]
150144 njobs = args_dict ['njobs' ]
151145
152- X , y = simulate_data (rho , pve , seed )
153- X_train , X_test , y_train , y_test = split_data (X , y , test_size = 0.5 , seed = seed )
146+ X_train , y_train = simulate_data (rho , pve , seed )
154147
155148 # end time
156149 end = time .time ()
@@ -176,13 +169,8 @@ def get_lmdi_plus(X, lmdi_plus_explainer, ranking):
176169
177170 # obtain shap feature importances
178171 shap_rf_explainer = shap .TreeExplainer (rf )
179- shap_rf_values , shap_rf_rankings = get_shap (X_test , shap_rf_explainer )
172+ shap_rf_values , shap_rf_rankings = get_shap (X_train , shap_rf_explainer )
180173
181- # obtain interventional shap feature importances
182- background = shap .sample (X_train , 150 , random_state = 150 )
183- interventional_shap_rf_explainer = shap .TreeExplainer (rf , data = background , feature_perturbation = "interventional" )
184- interventional_shap_rf_values , interventional_shap_rf_rankings = get_shap (X_test , interventional_shap_rf_explainer )
185-
186174 # end time
187175 end = time .time ()
188176
@@ -193,7 +181,7 @@ def get_lmdi_plus(X, lmdi_plus_explainer, ranking):
193181 start = time .time ()
194182
195183 # obtain LIME feature importances
196- lime_rf_values , lime_rf_rankings = get_lime (X_test , rf )
184+ lime_rf_values , lime_rf_rankings = get_lime (X_train , rf )
197185
198186 # end time
199187 end = time .time ()
@@ -204,15 +192,15 @@ def get_lmdi_plus(X, lmdi_plus_explainer, ranking):
204192 # start time
205193 start = time .time ()
206194
207- _ , lmdi_sutera_values = local_mdi_score (X_train , X_test , model = rf , absolute = False )
195+ _ , lmdi_sutera_values = local_mdi_score (X_train , X_train , model = rf , absolute = False )
208196 lmdi_sutera_rankings = np .argsort (- np .abs (lmdi_sutera_values ), axis = 1 )
209197
210198 # end time
211199 end = time .time ()
212200
213- print (f"Progress Message 5/6: LMDI values/rankings obtained." )
201+ print (f"Progress Message 5/6: Local MDI values/rankings obtained." )
214202 print (f"Step #5 took { end - start } seconds." )
215-
203+
216204 # start time
217205 start = time .time ()
218206
@@ -224,27 +212,24 @@ def get_lmdi_plus(X, lmdi_plus_explainer, ranking):
224212 lfi_rankings = {}
225213
226214 # obtain feature importances
227- lmdi_plus_values , lmdi_plus_rankings = get_lmdi_plus ( X_test ,
215+ lmdi_plus_values , lmdi_plus_rankings = get_lmdi ( X_train , y_train ,
228216 lmdi_plus_rf_explainer ,
229217 ranking = True )
230-
231- # end time
232- end = time .time ()
233-
234- print (f"Progress Message 6/6: LMDI+ values/rankings obtained." )
235- print (f"Step #6 took { end - start } seconds." )
236-
237218 lfi_values ["lmdi_plus" ] = lmdi_plus_values
238219 lfi_rankings ["lmdi_plus" ] = lmdi_plus_rankings
239220 lfi_rankings ["shap" ] = shap_rf_rankings
240221 lfi_values ["shap" ] = shap_rf_values
241- lfi_rankings ["interventional_shap" ] = interventional_shap_rf_rankings
242- lfi_values ["interventional_shap" ] = interventional_shap_rf_values
243222 lfi_rankings ["lime" ] = lime_rf_rankings
244223 lfi_values ["lime" ] = lime_rf_values
245224 lfi_rankings ["lmdi_sutera" ] = lmdi_sutera_rankings
246225 lfi_values ["lmdi_sutera" ] = lmdi_sutera_values
247226
227+ # end time
228+ end = time .time ()
229+
230+ print (f"Progress Message 6/6: LMDI+ values/rankings obtained." )
231+ print (f"Step #6 took { end - start } seconds." )
232+
248233 result_dir = oj (os .path .dirname (os .path .realpath (__file__ )),
249234 f'results/pve{ pve } /rho{ rho } /seed{ seed } ' )
250235
0 commit comments