@@ -82,21 +82,22 @@ def generate_counterfactuals(self, query_instances, total_CFs,
82
82
raise UserConfigValidationException (
83
83
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer." )
84
84
if total_CFs > 10 :
85
- if posthoc_sparsity_algorithm == None :
85
+ if posthoc_sparsity_algorithm is None :
86
86
posthoc_sparsity_algorithm = 'binary'
87
- elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear' :
87
+ elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear' :
88
88
import warnings
89
- warnings .warn ("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
90
- "if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
91
- "'binary' search!" .format (total_CFs ))
92
- elif posthoc_sparsity_algorithm == None :
89
+ warnings .warn (
90
+ "The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
91
+ "if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
92
+ "'binary' search!" .format (total_CFs ))
93
+ elif posthoc_sparsity_algorithm is None :
93
94
posthoc_sparsity_algorithm = 'linear'
94
95
95
96
cf_examples_arr = []
96
97
query_instances_list = []
97
98
if isinstance (query_instances , pd .DataFrame ):
98
99
for ix in range (query_instances .shape [0 ]):
99
- query_instances_list .append (query_instances [ix :(ix + 1 )])
100
+ query_instances_list .append (query_instances [ix :(ix + 1 )])
100
101
elif isinstance (query_instances , Iterable ):
101
102
query_instances_list = query_instances
102
103
@@ -190,11 +191,14 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query
190
191
191
192
if feature not in features_to_vary and permitted_range is not None :
192
193
if feature in permitted_range and feature in self .data_interface .continuous_feature_names :
193
- if not permitted_range [feature ][0 ] <= query_instance [feature ].values [0 ] <= permitted_range [feature ][1 ]:
194
- raise ValueError ("Feature:" , feature , "is outside the permitted range and isn't allowed to vary." )
194
+ if not permitted_range [feature ][0 ] <= query_instance [feature ].values [0 ] <= permitted_range [feature ][
195
+ 1 ]:
196
+ raise ValueError ("Feature:" , feature ,
197
+ "is outside the permitted range and isn't allowed to vary." )
195
198
elif feature in permitted_range and feature in self .data_interface .categorical_feature_names :
196
199
if query_instance [feature ].values [0 ] not in self .feature_range [feature ]:
197
- raise ValueError ("Feature:" , feature , "is outside the permitted range and isn't allowed to vary." )
200
+ raise ValueError ("Feature:" , feature ,
201
+ "is outside the permitted range and isn't allowed to vary." )
198
202
199
203
def local_feature_importance (self , query_instances , cf_examples_list = None ,
200
204
total_CFs = 10 ,
@@ -440,12 +444,13 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
440
444
cfs_preds_sparse = []
441
445
442
446
for cf_ix in list (final_cfs_sparse .index ):
443
- current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
447
+ current_pred = self .predict_fn_for_sparsity (
448
+ final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
444
449
for feature in features_sorted :
445
450
# current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
446
451
# feat_ix = self.data_interface.continuous_feature_names.index(feature)
447
452
diff = query_instance [feature ].iat [0 ] - int (final_cfs_sparse .at [cf_ix , feature ])
448
- if (abs (diff ) <= quantiles [feature ]):
453
+ if (abs (diff ) <= quantiles [feature ]):
449
454
if posthoc_sparsity_algorithm == "linear" :
450
455
final_cfs_sparse = self .do_linear_search (diff , decimal_prec , query_instance , cf_ix ,
451
456
feature , final_cfs_sparse , current_pred )
@@ -466,13 +471,14 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
466
471
query_instance greedily until the prediction class changes."""
467
472
468
473
old_diff = diff
469
- change = (10 ** - decimal_prec [feature ]) # the minimal possible change for a feature
474
+ change = (10 ** - decimal_prec [feature ]) # the minimal possible change for a feature
470
475
current_pred = current_pred_orig
471
476
if self .model .model_type == ModelTypes .Classifier :
472
- while ((abs (diff ) > 10e-4 ) and (np .sign (diff * old_diff ) > 0 ) and self .is_cf_valid (current_pred )):
477
+ while ((abs (diff ) > 10e-4 ) and (np .sign (diff * old_diff ) > 0 ) and self .is_cf_valid (current_pred )):
473
478
old_val = int (final_cfs_sparse .at [cf_ix , feature ])
474
- final_cfs_sparse .at [cf_ix , feature ] += np .sign (diff )* change
475
- current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
479
+ final_cfs_sparse .at [cf_ix , feature ] += np .sign (diff ) * change
480
+ current_pred = self .predict_fn_for_sparsity (
481
+ final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
476
482
old_diff = diff
477
483
478
484
if not self .is_cf_valid (current_pred ):
@@ -505,11 +511,12 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
505
511
right = query_instance [feature ].iat [0 ]
506
512
507
513
while left <= right :
508
- current_val = left + ((right - left )/ 2 )
514
+ current_val = left + ((right - left ) / 2 )
509
515
current_val = round (current_val , decimal_prec [feature ])
510
516
511
517
final_cfs_sparse .at [cf_ix , feature ] = current_val
512
- current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
518
+ current_pred = self .predict_fn_for_sparsity (
519
+ final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
513
520
514
521
if current_val == right or current_val == left :
515
522
break
@@ -524,19 +531,20 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
524
531
right = int (final_cfs_sparse .at [cf_ix , feature ])
525
532
526
533
while right >= left :
527
- current_val = right - ((right - left )/ 2 )
534
+ current_val = right - ((right - left ) / 2 )
528
535
current_val = round (current_val , decimal_prec [feature ])
529
536
530
537
final_cfs_sparse .at [cf_ix , feature ] = current_val
531
- current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
538
+ current_pred = self .predict_fn_for_sparsity (
539
+ final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
532
540
533
541
if current_val == right or current_val == left :
534
542
break
535
543
536
544
if self .is_cf_valid (current_pred ):
537
- right = current_val - (10 ** - decimal_prec [feature ])
545
+ right = current_val - (10 ** - decimal_prec [feature ])
538
546
else :
539
- left = current_val + (10 ** - decimal_prec [feature ])
547
+ left = current_val + (10 ** - decimal_prec [feature ])
540
548
541
549
return final_cfs_sparse
542
550
@@ -578,7 +586,7 @@ def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_
578
586
raise UserConfigValidationException ("Desired class not present in training data!" )
579
587
else :
580
588
raise UserConfigValidationException ("The target class for {0} could not be identified" .format (
581
- desired_class_input ))
589
+ desired_class_input ))
582
590
583
591
def infer_target_cfs_range (self , desired_range_input ):
584
592
target_range = None
@@ -597,7 +605,7 @@ def decide_cf_validity(self, model_outputs):
597
605
pred = model_outputs [i ]
598
606
if self .model .model_type == ModelTypes .Classifier :
599
607
if self .num_output_nodes == 2 : # binary
600
- pred_1 = pred [self .num_output_nodes - 1 ]
608
+ pred_1 = pred [self .num_output_nodes - 1 ]
601
609
validity [i ] = 1 if \
602
610
((self .target_cf_class == 0 and pred_1 <= self .stopping_threshold ) or
603
611
(self .target_cf_class == 1 and pred_1 >= self .stopping_threshold )) else 0
@@ -634,7 +642,7 @@ def is_cf_valid(self, model_score):
634
642
(target_cf_class == 1 and pred_1 >= self .stopping_threshold )) else False
635
643
return validity
636
644
if self .num_output_nodes == 2 : # binary
637
- pred_1 = model_score [self .num_output_nodes - 1 ]
645
+ pred_1 = model_score [self .num_output_nodes - 1 ]
638
646
validity = True if \
639
647
((target_cf_class == 0 and pred_1 <= self .stopping_threshold ) or
640
648
(target_cf_class == 1 and pred_1 >= self .stopping_threshold )) else False
@@ -710,7 +718,8 @@ def round_to_precision(self):
710
718
for ix , feature in enumerate (self .data_interface .continuous_feature_names ):
711
719
self .final_cfs_df [feature ] = self .final_cfs_df [feature ].astype (float ).round (precisions [ix ])
712
720
if self .final_cfs_df_sparse is not None :
713
- self .final_cfs_df_sparse [feature ] = self .final_cfs_df_sparse [feature ].astype (float ).round (precisions [ix ])
721
+ self .final_cfs_df_sparse [feature ] = self .final_cfs_df_sparse [feature ].astype (float ).round (
722
+ precisions [ix ])
714
723
715
724
def _check_any_counterfactuals_computed (self , cf_examples_arr ):
716
725
"""Check if any counterfactuals were generated for any query point."""
0 commit comments