55import pickle
66from abc import ABC , abstractmethod
77from collections .abc import Iterable
8+ from typing import Any , Dict , List , Optional , Union
89
910import numpy as np
1011import pandas as pd
@@ -152,10 +153,9 @@ def generate_counterfactuals(self, query_instances, total_CFs,
152153 cf_examples_arr = []
153154 query_instances_list = []
154155 if isinstance (query_instances , pd .DataFrame ):
155- for ix in range (query_instances .shape [0 ]):
156- query_instances_list .append (query_instances [ix :(ix + 1 )])
156+ query_instances_list = [query_instances [ix :(ix + 1 )] for ix in range (query_instances .shape [0 ])]
157157 elif isinstance (query_instances , Iterable ):
158- query_instances_list = query_instances
158+ query_instances_list = [ query_instance for query_instance in query_instances ]
159159 for query_instance in tqdm (query_instances_list ):
160160 self .data_interface .set_continuous_feature_indexes (query_instance )
161161 res = self ._generate_counterfactuals (
@@ -416,7 +416,7 @@ def feature_importance(self, query_instances, cf_examples_list=None,
416416 posthoc_sparsity_algorithm = posthoc_sparsity_algorithm ,
417417 ** kwargs ).cf_examples_list
418418 allcols = self .data_interface .categorical_feature_names + self .data_interface .continuous_feature_names
419- summary_importance = None
419+ summary_importance : Optional [ Dict [ int , float ]] = None
420420 local_importances = None
421421 if global_importance :
422422 summary_importance = {}
@@ -532,7 +532,7 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
532532 for feature in features_sorted :
533533 # current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
534534 # feat_ix = self.data_interface.continuous_feature_names.index(feature)
535- diff = query_instance [feature ].iat [0 ] - final_cfs_sparse .at [cf_ix , feature ]
535+ diff = query_instance [feature ].iat [0 ] - int ( final_cfs_sparse .at [cf_ix , feature ])
536536 if (abs (diff ) <= quantiles [feature ]):
537537 if posthoc_sparsity_algorithm == "linear" :
538538 final_cfs_sparse = self .do_linear_search (diff , decimal_prec , query_instance , cf_ix ,
@@ -561,16 +561,17 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
561561 while ((abs (diff ) > 10e-4 ) and (np .sign (diff * old_diff ) > 0 ) and
562562 self .is_cf_valid (current_pred )) and (count_steps < limit_steps_ls ):
563563
564- old_val = final_cfs_sparse .at [cf_ix , feature ]
564+ old_val = int ( final_cfs_sparse .at [cf_ix , feature ])
565565 final_cfs_sparse .at [cf_ix , feature ] += np .sign (diff )* change
566566 current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
567567 old_diff = diff
568568
569569 if not self .is_cf_valid (current_pred ):
570570 final_cfs_sparse .at [cf_ix , feature ] = old_val
571+ diff = query_instance [feature ].iat [0 ] - int (final_cfs_sparse .at [cf_ix , feature ])
571572 return final_cfs_sparse
572573
573- diff = query_instance [feature ].iat [0 ] - final_cfs_sparse .at [cf_ix , feature ]
574+ diff = query_instance [feature ].iat [0 ] - int ( final_cfs_sparse .at [cf_ix , feature ])
574575
575576 count_steps += 1
576577
@@ -580,7 +581,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
580581 """Performs a binary search between continuous features of a CF and corresponding values
581582 in query_instance until the prediction class changes."""
582583
583- old_val = final_cfs_sparse .at [cf_ix , feature ]
584+ old_val = int ( final_cfs_sparse .at [cf_ix , feature ])
584585 final_cfs_sparse .at [cf_ix , feature ] = query_instance [feature ].iat [0 ]
585586 # Prediction of the query instance
586587 current_pred = self .predict_fn_for_sparsity (final_cfs_sparse .loc [[cf_ix ]][self .data_interface .feature_names ])
@@ -593,7 +594,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
593594
594595 # move the CF values towards the query_instance
595596 if diff > 0 :
596- left = final_cfs_sparse .at [cf_ix , feature ]
597+ left = int ( final_cfs_sparse .at [cf_ix , feature ])
597598 right = query_instance [feature ].iat [0 ]
598599
599600 while left <= right :
@@ -613,7 +614,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
613614
614615 else :
615616 left = query_instance [feature ].iat [0 ]
616- right = final_cfs_sparse .at [cf_ix , feature ]
617+ right = int ( final_cfs_sparse .at [cf_ix , feature ])
617618
618619 while right >= left :
619620 current_val = right - ((right - left )/ 2 )
@@ -731,13 +732,16 @@ def is_cf_valid(self, model_score):
731732 model_score = model_score [0 ]
732733 # Converting target_cf_class to a scalar (tf/torch have it as (1,1) shape)
733734 if self .model .model_type == ModelTypes .Classifier :
734- target_cf_class = self .target_cf_class
735735 if hasattr (self .target_cf_class , "shape" ):
736736 if len (self .target_cf_class .shape ) == 1 :
737- target_cf_class = self .target_cf_class [0 ]
737+ temp_target_cf_class = self .target_cf_class [0 ]
738738 elif len (self .target_cf_class .shape ) == 2 :
739- target_cf_class = self .target_cf_class [0 ][0 ]
740- target_cf_class = int (target_cf_class )
739+ temp_target_cf_class = self .target_cf_class [0 ][0 ]
740+ else :
741+ temp_target_cf_class = int (self .target_cf_class )
742+ else :
743+ temp_target_cf_class = int (self .target_cf_class )
744+ target_cf_class = temp_target_cf_class
741745
742746 if len (model_score ) == 1 : # for tensorflow/pytorch models
743747 pred_1 = model_score [0 ]
@@ -757,6 +761,7 @@ def is_cf_valid(self, model_score):
757761 return self .target_cf_range [0 ] <= model_score and model_score <= self .target_cf_range [1 ]
758762
759763 def get_model_output_from_scores (self , model_scores ):
764+ output_type : Any = None
760765 if self .model .model_type == ModelTypes .Classifier :
761766 output_type = np .int32
762767 else :
@@ -806,17 +811,19 @@ def build_KD_tree(self, data_df_copy, desired_range, desired_class, predicted_ou
806811 data_df_copy [predicted_outcome_name ] = predictions
807812
808813 # segmenting the dataset according to outcome
809- dataset_with_predictions = None
810814 if self .model .model_type == ModelTypes .Classifier :
811815 dataset_with_predictions = data_df_copy .loc [[i == desired_class for i in predictions ]].copy ()
812816
813817 elif self .model .model_type == ModelTypes .Regressor :
814818 dataset_with_predictions = data_df_copy .loc [
815819 [desired_range [0 ] <= pred <= desired_range [1 ] for pred in predictions ]].copy ()
816820
821+ else :
822+ dataset_with_predictions = None
823+
817824 KD_tree = None
818825 # Prepares the KD trees for DiCE
819- if len (dataset_with_predictions ) > 0 :
826+ if dataset_with_predictions is not None and len (dataset_with_predictions ) > 0 :
820827 dummies = pd .get_dummies (dataset_with_predictions [self .data_interface .feature_names ])
821828 KD_tree = KDTree (dummies )
822829
0 commit comments