Skip to content

Commit 8f17cfd

Browse files
authored
fix int(final_cfs_sparse.at[cf_ix, feature]) in do_posthoc_sparsity_enhancement, do_linear_search and do_binary_search (#343)
Signed-off-by: An <[email protected]> Signed-off-by: An <[email protected]> Co-authored-by: An <>
1 parent 1c55f7b commit 8f17cfd

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
532532
for feature in features_sorted:
533533
# current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
534534
# feat_ix = self.data_interface.continuous_feature_names.index(feature)
535-
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
535+
diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature]
536536
if(abs(diff) <= quantiles[feature]):
537537
if posthoc_sparsity_algorithm == "linear":
538538
final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix,
@@ -561,17 +561,16 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
561561
while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and
562562
self.is_cf_valid(current_pred)) and (count_steps < limit_steps_ls):
563563

564-
old_val = int(final_cfs_sparse.at[cf_ix, feature])
564+
old_val = final_cfs_sparse.at[cf_ix, feature]
565565
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change
566566
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
567567
old_diff = diff
568568

569569
if not self.is_cf_valid(current_pred):
570570
final_cfs_sparse.at[cf_ix, feature] = old_val
571-
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
572571
return final_cfs_sparse
573572

574-
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
573+
diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature]
575574

576575
count_steps += 1
577576

@@ -581,7 +580,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
581580
"""Performs a binary search between continuous features of a CF and corresponding values
582581
in query_instance until the prediction class changes."""
583582

584-
old_val = int(final_cfs_sparse.at[cf_ix, feature])
583+
old_val = final_cfs_sparse.at[cf_ix, feature]
585584
final_cfs_sparse.at[cf_ix, feature] = query_instance[feature].iat[0]
586585
# Prediction of the query instance
587586
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
@@ -594,7 +593,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
594593

595594
# move the CF values towards the query_instance
596595
if diff > 0:
597-
left = int(final_cfs_sparse.at[cf_ix, feature])
596+
left = final_cfs_sparse.at[cf_ix, feature]
598597
right = query_instance[feature].iat[0]
599598

600599
while left <= right:
@@ -614,7 +613,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
614613

615614
else:
616615
left = query_instance[feature].iat[0]
617-
right = int(final_cfs_sparse.at[cf_ix, feature])
616+
right = final_cfs_sparse.at[cf_ix, feature]
618617

619618
while right >= left:
620619
current_val = right - ((right - left)/2)

0 commit comments

Comments
 (0)