Skip to content

Commit 00dda52

Browse files
committed
review
Signed-off-by: giandos200 <[email protected]>
1 parent 785bc4a commit 00dda52

File tree

4 files changed

+44
-36
lines changed

4 files changed

+44
-36
lines changed

dice_ml/explainer_interfaces/dice_KD.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
240240
# post-hoc operation on continuous features to enhance sparsity - only for public data
241241
if posthoc_sparsity_param is not None and posthoc_sparsity_param > 0 and 'data_df' in self.data_interface.__dict__:
242242
self.final_cfs_df_sparse = copy.deepcopy(self.final_cfs)
243-
self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse, query_instance,
243+
self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse,
244+
query_instance,
244245
posthoc_sparsity_param,
245246
posthoc_sparsity_algorithm)
246247
else:
@@ -265,9 +266,9 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
265266
'change the query instance or the features to vary...' '; total time taken: %02d' % m,
266267
'min %02d' % s, 'sec')
267268
elif total_cfs_found == 0:
268-
print(
269-
'No Counterfactuals found for the given configuration, perhaps try with different parameters...',
270-
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
269+
print(
270+
'No Counterfactuals found for the given configuration, perhaps try with different parameters...',
271+
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
271272
else:
272273
print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec')
273274

dice_ml/explainer_interfaces/dice_genetic.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir
150150
kx += 1
151151
self.cfs = np.array(row)
152152

153-
#if len(self.cfs) > self.population_size:
154-
# pass
155153
if len(self.cfs) != self.population_size:
156154
print("Pericolo Loop infinito....!!!!")
157155
remaining_cfs = self.do_random_init(
@@ -264,7 +262,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k
264262
(see diverse_counterfactuals.py).
265263
"""
266264

267-
self.population_size = 3 * total_CFs
265+
self.population_size = 10 * total_CFs
268266

269267
self.start_time = timeit.default_timer()
270268

@@ -470,8 +468,8 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class,
470468
if rest_members > 0:
471469
new_generation_2 = np.zeros((rest_members, self.data_interface.number_of_features))
472470
for new_gen_idx in range(rest_members):
473-
parent1 = random.choice(population[:max(int(len(population) / 2),1)])
474-
parent2 = random.choice(population[:max(int(len(population) / 2),1)])
471+
parent1 = random.choice(population[:max(int(len(population) / 2), 1)])
472+
parent2 = random.choice(population[:max(int(len(population) / 2), 1)])
475473
child = self.mate(parent1, parent2, features_to_vary, query_instance)
476474
new_generation_2[new_gen_idx] = child
477475

dice_ml/explainer_interfaces/explainer_base.py

+35-26
Original file line numberDiff line numberDiff line change
@@ -82,21 +82,22 @@ def generate_counterfactuals(self, query_instances, total_CFs,
8282
raise UserConfigValidationException(
8383
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
8484
if total_CFs > 10:
85-
if posthoc_sparsity_algorithm == None:
85+
if posthoc_sparsity_algorithm is None:
8686
posthoc_sparsity_algorithm = 'binary'
87-
elif total_CFs >50 and posthoc_sparsity_algorithm == 'linear':
87+
elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear':
8888
import warnings
89-
warnings.warn("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
90-
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
91-
"'binary' search!".format(total_CFs))
92-
elif posthoc_sparsity_algorithm == None:
89+
warnings.warn(
90+
"The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
91+
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
92+
"'binary' search!".format(total_CFs))
93+
elif posthoc_sparsity_algorithm is None:
9394
posthoc_sparsity_algorithm = 'linear'
9495

9596
cf_examples_arr = []
9697
query_instances_list = []
9798
if isinstance(query_instances, pd.DataFrame):
9899
for ix in range(query_instances.shape[0]):
99-
query_instances_list.append(query_instances[ix:(ix+1)])
100+
query_instances_list.append(query_instances[ix:(ix + 1)])
100101
elif isinstance(query_instances, Iterable):
101102
query_instances_list = query_instances
102103

@@ -190,11 +191,14 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query
190191

191192
if feature not in features_to_vary and permitted_range is not None:
192193
if feature in permitted_range and feature in self.data_interface.continuous_feature_names:
193-
if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][1]:
194-
raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
194+
if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][
195+
1]:
196+
raise ValueError("Feature:", feature,
197+
"is outside the permitted range and isn't allowed to vary.")
195198
elif feature in permitted_range and feature in self.data_interface.categorical_feature_names:
196199
if query_instance[feature].values[0] not in self.feature_range[feature]:
197-
raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
200+
raise ValueError("Feature:", feature,
201+
"is outside the permitted range and isn't allowed to vary.")
198202

199203
def local_feature_importance(self, query_instances, cf_examples_list=None,
200204
total_CFs=10,
@@ -440,12 +444,13 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
440444
cfs_preds_sparse = []
441445

442446
for cf_ix in list(final_cfs_sparse.index):
443-
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
447+
current_pred = self.predict_fn_for_sparsity(
448+
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
444449
for feature in features_sorted:
445450
# current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
446451
# feat_ix = self.data_interface.continuous_feature_names.index(feature)
447452
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
448-
if(abs(diff) <= quantiles[feature]):
453+
if (abs(diff) <= quantiles[feature]):
449454
if posthoc_sparsity_algorithm == "linear":
450455
final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix,
451456
feature, final_cfs_sparse, current_pred)
@@ -466,13 +471,14 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
466471
query_instance greedily until the prediction class changes."""
467472

468473
old_diff = diff
469-
change = (10**-decimal_prec[feature]) # the minimal possible change for a feature
474+
change = (10 ** -decimal_prec[feature]) # the minimal possible change for a feature
470475
current_pred = current_pred_orig
471476
if self.model.model_type == ModelTypes.Classifier:
472-
while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and self.is_cf_valid(current_pred)):
477+
while ((abs(diff) > 10e-4) and (np.sign(diff * old_diff) > 0) and self.is_cf_valid(current_pred)):
473478
old_val = int(final_cfs_sparse.at[cf_ix, feature])
474-
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change
475-
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
479+
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff) * change
480+
current_pred = self.predict_fn_for_sparsity(
481+
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
476482
old_diff = diff
477483

478484
if not self.is_cf_valid(current_pred):
@@ -505,11 +511,12 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
505511
right = query_instance[feature].iat[0]
506512

507513
while left <= right:
508-
current_val = left + ((right - left)/2)
514+
current_val = left + ((right - left) / 2)
509515
current_val = round(current_val, decimal_prec[feature])
510516

511517
final_cfs_sparse.at[cf_ix, feature] = current_val
512-
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
518+
current_pred = self.predict_fn_for_sparsity(
519+
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
513520

514521
if current_val == right or current_val == left:
515522
break
@@ -524,19 +531,20 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
524531
right = int(final_cfs_sparse.at[cf_ix, feature])
525532

526533
while right >= left:
527-
current_val = right - ((right - left)/2)
534+
current_val = right - ((right - left) / 2)
528535
current_val = round(current_val, decimal_prec[feature])
529536

530537
final_cfs_sparse.at[cf_ix, feature] = current_val
531-
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
538+
current_pred = self.predict_fn_for_sparsity(
539+
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
532540

533541
if current_val == right or current_val == left:
534542
break
535543

536544
if self.is_cf_valid(current_pred):
537-
right = current_val - (10**-decimal_prec[feature])
545+
right = current_val - (10 ** -decimal_prec[feature])
538546
else:
539-
left = current_val + (10**-decimal_prec[feature])
547+
left = current_val + (10 ** -decimal_prec[feature])
540548

541549
return final_cfs_sparse
542550

@@ -578,7 +586,7 @@ def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_
578586
raise UserConfigValidationException("Desired class not present in training data!")
579587
else:
580588
raise UserConfigValidationException("The target class for {0} could not be identified".format(
581-
desired_class_input))
589+
desired_class_input))
582590

583591
def infer_target_cfs_range(self, desired_range_input):
584592
target_range = None
@@ -597,7 +605,7 @@ def decide_cf_validity(self, model_outputs):
597605
pred = model_outputs[i]
598606
if self.model.model_type == ModelTypes.Classifier:
599607
if self.num_output_nodes == 2: # binary
600-
pred_1 = pred[self.num_output_nodes-1]
608+
pred_1 = pred[self.num_output_nodes - 1]
601609
validity[i] = 1 if \
602610
((self.target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
603611
(self.target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else 0
@@ -634,7 +642,7 @@ def is_cf_valid(self, model_score):
634642
(target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False
635643
return validity
636644
if self.num_output_nodes == 2: # binary
637-
pred_1 = model_score[self.num_output_nodes-1]
645+
pred_1 = model_score[self.num_output_nodes - 1]
638646
validity = True if \
639647
((target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
640648
(target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False
@@ -710,7 +718,8 @@ def round_to_precision(self):
710718
for ix, feature in enumerate(self.data_interface.continuous_feature_names):
711719
self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix])
712720
if self.final_cfs_df_sparse is not None:
713-
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix])
721+
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(
722+
precisions[ix])
714723

715724
def _check_any_counterfactuals_computed(self, cf_examples_arr):
716725
"""Check if any counterfactuals were generated for any query point."""

tests/test_notebooks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import nbformat
1010
import pytest
1111

12-
NOTEBOOKS_PATH = "../docs/source/notebooks/"
12+
NOTEBOOKS_PATH = "docs/source/notebooks/"
1313
notebooks_list = [f.name for f in os.scandir(NOTEBOOKS_PATH) if f.name.endswith(".ipynb")]
1414
# notebooks that should not be run
1515
advanced_notebooks = [

0 commit comments

Comments
 (0)