2
2
Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch.
3
3
All methods are in dice_ml.explainer_interfaces"""
4
4
5
- import warnings
6
5
from abc import ABC , abstractmethod
6
+ from collections .abc import Iterable
7
+
7
8
import numpy as np
8
9
import pandas as pd
10
+ from sklearn .neighbors import KDTree
9
11
from tqdm import tqdm
10
12
11
- from collections .abc import Iterable
12
- from sklearn .neighbors import KDTree
13
+ from dice_ml .constants import ModelTypes
13
14
from dice_ml .counterfactual_explanations import CounterfactualExplanations
14
15
from dice_ml .utils .exception import UserConfigValidationException
15
- from dice_ml .constants import ModelTypes
16
16
17
17
18
18
class ExplainerBase (ABC ):
@@ -85,6 +85,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
85
85
if posthoc_sparsity_algorithm == None :
86
86
posthoc_sparsity_algorithm = 'binary'
87
87
elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear' :
88
+ import warnings
88
89
warnings .warn ("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
89
90
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
90
91
"'binary' search!" .format (total_CFs ))
@@ -98,6 +99,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
98
99
query_instances_list .append (query_instances [ix :(ix + 1 )])
99
100
elif isinstance (query_instances , Iterable ):
100
101
query_instances_list = query_instances
102
+
101
103
for query_instance in tqdm (query_instances_list ):
102
104
self .data_interface .set_continuous_feature_indexes (query_instance )
103
105
res = self ._generate_counterfactuals (
@@ -112,6 +114,9 @@ def generate_counterfactuals(self, query_instances, total_CFs,
112
114
verbose = verbose ,
113
115
** kwargs )
114
116
cf_examples_arr .append (res )
117
+
118
+ self ._check_any_counterfactuals_computed (cf_examples_arr = cf_examples_arr )
119
+
115
120
return CounterfactualExplanations (cf_examples_list = cf_examples_arr )
116
121
117
122
@abstractmethod
@@ -217,10 +222,12 @@ def local_feature_importance(self, query_instances, cf_examples_list=None,
217
222
if any ([len (cf_examples .final_cfs_df ) < 10 for cf_examples in cf_examples_list ]):
218
223
raise UserConfigValidationException (
219
224
"The number of counterfactuals generated per query instance should be "
220
- "greater than or equal to 10" )
225
+ "greater than or equal to 10 to compute feature importance for all query points " )
221
226
elif total_CFs < 10 :
222
- raise UserConfigValidationException ("The number of counterfactuals generated per "
223
- "query instance should be greater than or equal to 10" )
227
+ raise UserConfigValidationException (
228
+ "The number of counterfactuals requested per "
229
+ "query instance should be greater than or equal to 10 "
230
+ "to compute feature importance for all query points" )
224
231
importances = self .feature_importance (
225
232
query_instances ,
226
233
cf_examples_list = cf_examples_list ,
@@ -261,16 +268,25 @@ def global_feature_importance(self, query_instances, cf_examples_list=None,
261
268
input, and the global feature importance summarized over all inputs.
262
269
"""
263
270
if query_instances is not None and len (query_instances ) < 10 :
264
- raise UserConfigValidationException ("The number of query instances should be greater than or equal to 10" )
271
+ raise UserConfigValidationException (
272
+ "The number of query instances should be greater than or equal to 10 "
273
+ "to compute global feature importance over all query points" )
265
274
if cf_examples_list is not None :
266
- if any ([len (cf_examples .final_cfs_df ) < 10 for cf_examples in cf_examples_list ]):
275
+ if len (cf_examples_list ) < 10 :
276
+ raise UserConfigValidationException (
277
+ "The number of points for which counterfactuals generated should be "
278
+ "greater than or equal to 10 "
279
+ "to compute global feature importance" )
280
+ elif any ([len (cf_examples .final_cfs_df ) < 10 for cf_examples in cf_examples_list ]):
267
281
raise UserConfigValidationException (
268
282
"The number of counterfactuals generated per query instance should be "
269
- "greater than or equal to 10" )
283
+ "greater than or equal to 10"
284
+ "to compute global feature importance over all query points" )
270
285
elif total_CFs < 10 :
271
286
raise UserConfigValidationException (
272
287
"The number of counterfactuals generated per query instance should be greater "
273
- "than or equal to 10" )
288
+ "than or equal to 10"
289
+ "to compute global feature importance over all query points" )
274
290
importances = self .feature_importance (
275
291
query_instances ,
276
292
cf_examples_list = cf_examples_list ,
@@ -349,7 +365,7 @@ def feature_importance(self, query_instances, cf_examples_list=None,
349
365
continue
350
366
351
367
per_query_point_cfs = 0
352
- for index , row in df .iterrows ():
368
+ for _ , row in df .iterrows ():
353
369
per_query_point_cfs += 1
354
370
for col in self .data_interface .continuous_feature_names :
355
371
if not np .isclose (org_instance [col ].iat [0 ], row [col ]):
@@ -530,7 +546,7 @@ def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred)
530
546
self .target_cf_class = np .array (
531
547
[[self .infer_target_cfs_class (desired_class , test_pred , self .num_output_nodes )]],
532
548
dtype = np .float32 )
533
- desired_class = self .target_cf_class [0 ][0 ]
549
+ desired_class = int ( self .target_cf_class [0 ][0 ])
534
550
if self .target_cf_class == 0 and self .stopping_threshold > 0.5 :
535
551
self .stopping_threshold = 0.25
536
552
elif self .target_cf_class == 1 and self .stopping_threshold < 0.5 :
@@ -695,3 +711,15 @@ def round_to_precision(self):
695
711
self .final_cfs_df [feature ] = self .final_cfs_df [feature ].astype (float ).round (precisions [ix ])
696
712
if self .final_cfs_df_sparse is not None :
697
713
self .final_cfs_df_sparse [feature ] = self .final_cfs_df_sparse [feature ].astype (float ).round (precisions [ix ])
714
+
715
+ def _check_any_counterfactuals_computed (self , cf_examples_arr ):
716
+ """Check if any counterfactuals were generated for any query point."""
717
+ no_cf_generated = True
718
+ # Check if any counterfactuals were generated for any query point
719
+ for cf_examples in cf_examples_arr :
720
+ if cf_examples .final_cfs_df is not None and len (cf_examples .final_cfs_df ) > 0 :
721
+ no_cf_generated = False
722
+ break
723
+ if no_cf_generated :
724
+ raise UserConfigValidationException (
725
+ "No counterfactuals found for any of the query points! Kindly check your configuration." )
0 commit comments