Skip to content

Commit 490b7ab

Browse files
Fix incompatible types in assignment
Signed-off-by: Daiki Katsuragawa <[email protected]>
1 parent 8f17cfd commit 490b7ab

File tree

8 files changed

+51
-36
lines changed

8 files changed

+51
-36
lines changed

dice_ml/explainer_interfaces/dice_KD.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def vary_valid(self, KD_query_instance, total_CFs, features_to_vary, permitted_r
163163

164164
# TODO: this should be a user-specified parameter
165165
num_queries = min(len(self.dataset_with_predictions), total_CFs * 10)
166-
cfs = []
166+
cfs = pd.DataFrame()
167167

168168
if self.KD_tree is not None and num_queries > 0:
169169
KD_tree_output = self.KD_tree.query(KD_query_instance, num_queries)

dice_ml/explainer_interfaces/dice_genetic.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import copy
66
import random
77
import timeit
8+
from typing import Any, List, Union
89

910
import numpy as np
1011
import pandas as pd
@@ -27,7 +28,7 @@ def __init__(self, data_interface, model_interface):
2728
self.num_output_nodes = None
2829

2930
# variables required to generate CFs - see generate_counterfactuals() for more info
30-
self.cfs = []
31+
self.cfs = pd.DataFrame()
3132
self.features_to_vary = []
3233
self.cf_init_weights = [] # total_CFs, algorithm, features_to_vary
3334
self.loss_weights = [] # yloss_type, diversity_loss_type, feature_weights
@@ -343,12 +344,16 @@ def _predict_fn_custom(self, input_instance, desired_class):
343344

344345
def compute_yloss(self, cfs, desired_range, desired_class):
345346
"""Computes the first part (y-loss) of the loss function."""
346-
yloss = 0.0
347+
yloss: Any = 0.0
347348
if self.model.model_type == ModelTypes.Classifier:
348349
predicted_value = np.array(self.predict_fn_scores(cfs))
349350
if self.yloss_type == 'hinge_loss':
350351
maxvalue = np.full((len(predicted_value)), -np.inf)
351-
for c in range(self.num_output_nodes):
352+
if self.num_output_nodes is None:
353+
num_output_nodes = 0
354+
else:
355+
num_output_nodes = self.num_output_nodes
356+
for c in range(num_output_nodes):
352357
if c != desired_class:
353358
maxvalue = np.maximum(maxvalue, predicted_value[:, c])
354359
yloss = np.maximum(0, maxvalue - predicted_value[:, int(desired_class)])
@@ -429,7 +434,7 @@ def mate(self, k1, k2, features_to_vary, query_instance):
429434
def find_counterfactuals(self, query_instance, desired_range, desired_class,
430435
features_to_vary, maxiterations, thresh, verbose):
431436
"""Finds counterfactuals by generating cfs through the genetic algorithm"""
432-
population = self.cfs.copy()
437+
population: Any = self.cfs.copy()
433438
iterations = 0
434439
previous_best_loss = -np.inf
435440
current_best_loss = np.inf

dice_ml/explainer_interfaces/dice_pytorch.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import copy
55
import random
66
import timeit
7+
from typing import Any, Optional, Type, Union
78

89
import numpy as np
910
import torch
@@ -223,14 +224,16 @@ def do_optimizer_initializations(self, optimizer, learning_rate):
223224
opt_method = optimizer.split(':')[1]
224225

225226
# optimizater initialization
227+
self.optimizer: Optional[Union[torch.optim.Adam, torch.optim.RMSprop]] = None
226228
if opt_method == "adam":
227229
self.optimizer = torch.optim.Adam(self.cfs, lr=learning_rate)
228230
elif opt_method == "rmsprop":
229231
self.optimizer = torch.optim.RMSprop(self.cfs, lr=learning_rate)
230232

231233
def compute_yloss(self):
232234
"""Computes the first part (y-loss) of the loss function."""
233-
yloss = 0.0
235+
yloss: Any = 0.0
236+
criterion: Optional[Union[torch.nn.BCEWithLogitsLoss, torch.nn.ReLU]] = None
234237
for i in range(self.total_CFs):
235238
if self.yloss_type == "l2_loss":
236239
temp_loss = torch.pow((self.get_model_output(self.cfs[i]) - self.target_cf_class), 2)[0]
@@ -307,7 +310,7 @@ def compute_diversity_loss(self):
307310
def compute_regularization_loss(self):
308311
"""Adds a linear equality constraints to the loss functions -
309312
to ensure all levels of a categorical variable sums to one"""
310-
regularization_loss = 0.0
313+
regularization_loss: Any = 0.0
311314
for i in range(self.total_CFs):
312315
for v in self.encoded_categorical_feature_indexes:
313316
regularization_loss += torch.pow((torch.sum(self.cfs[i][v[0]:v[-1]+1]) - 1.0), 2)
@@ -425,7 +428,7 @@ def find_counterfactuals(self, query_instance, desired_class, optimizer, learnin
425428
test_pred = self.predict_fn(torch.tensor(query_instance).float())[0]
426429
if desired_class == "opposite":
427430
desired_class = 1.0 - np.round(test_pred)
428-
self.target_cf_class = torch.tensor(desired_class).float()
431+
self.target_cf_class: Any = torch.tensor(desired_class).float()
429432

430433
self.min_iter = min_iter
431434
self.max_iter = max_iter

dice_ml/explainer_interfaces/dice_random.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import random
77
import timeit
8+
from typing import List, Optional, Union
89

910
import numpy as np
1011
import pandas as pd
@@ -30,10 +31,9 @@ def __init__(self, data_interface, model_interface):
3031
self.model.transformer.initialize_transform_func()
3132

3233
self.precisions = self.data_interface.get_decimal_precisions(output_type="dict")
33-
if self.data_interface.outcome_name in self.precisions:
34-
self.outcome_precision = [self.precisions[self.data_interface.outcome_name]]
35-
else:
36-
self.outcome_precision = 0
34+
self.outcome_precision = [
35+
self.precisions[self.data_interface.outcome_name]
36+
] if self.data_interface.outcome_name in self.precisions else 0
3737

3838
def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=None,
3939
desired_class="opposite", permitted_range=None,

dice_ml/explainer_interfaces/dice_tensorflow2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,7 @@ def initialize_CFs(self, query_instance, init_near_query_instance=False):
341341
one_init.append(np.random.uniform(self.minx[0][i], self.maxx[0][i]))
342342
else:
343343
one_init.append(query_instance[0][i])
344-
one_init = np.array([one_init], dtype=np.float32)
345-
self.cfs[n].assign(one_init)
344+
self.cfs[n].assign(np.array([one_init], dtype=np.float32))
346345

347346
def round_off_cfs(self, assign=False):
348347
"""function for intermediate projection of CFs."""

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pickle
66
from abc import ABC, abstractmethod
77
from collections.abc import Iterable
8+
from typing import Any, Dict, List, Optional, Union
89

910
import numpy as np
1011
import pandas as pd
@@ -152,10 +153,9 @@ def generate_counterfactuals(self, query_instances, total_CFs,
152153
cf_examples_arr = []
153154
query_instances_list = []
154155
if isinstance(query_instances, pd.DataFrame):
155-
for ix in range(query_instances.shape[0]):
156-
query_instances_list.append(query_instances[ix:(ix+1)])
156+
query_instances_list = [query_instances[ix:(ix+1)] for ix in range(query_instances.shape[0])]
157157
elif isinstance(query_instances, Iterable):
158-
query_instances_list = query_instances
158+
query_instances_list = [query_instance for query_instance in query_instances]
159159
for query_instance in tqdm(query_instances_list):
160160
self.data_interface.set_continuous_feature_indexes(query_instance)
161161
res = self._generate_counterfactuals(
@@ -416,7 +416,7 @@ def feature_importance(self, query_instances, cf_examples_list=None,
416416
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
417417
**kwargs).cf_examples_list
418418
allcols = self.data_interface.categorical_feature_names + self.data_interface.continuous_feature_names
419-
summary_importance = None
419+
summary_importance: Optional[Union[Dict[int, float]]] = None
420420
local_importances = None
421421
if global_importance:
422422
summary_importance = {}
@@ -532,7 +532,7 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
532532
for feature in features_sorted:
533533
# current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
534534
# feat_ix = self.data_interface.continuous_feature_names.index(feature)
535-
diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature]
535+
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
536536
if(abs(diff) <= quantiles[feature]):
537537
if posthoc_sparsity_algorithm == "linear":
538538
final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix,
@@ -561,16 +561,17 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
561561
while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and
562562
self.is_cf_valid(current_pred)) and (count_steps < limit_steps_ls):
563563

564-
old_val = final_cfs_sparse.at[cf_ix, feature]
564+
old_val = int(final_cfs_sparse.at[cf_ix, feature])
565565
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change
566566
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
567567
old_diff = diff
568568

569569
if not self.is_cf_valid(current_pred):
570570
final_cfs_sparse.at[cf_ix, feature] = old_val
571+
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
571572
return final_cfs_sparse
572573

573-
diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature]
574+
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
574575

575576
count_steps += 1
576577

@@ -580,7 +581,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
580581
"""Performs a binary search between continuous features of a CF and corresponding values
581582
in query_instance until the prediction class changes."""
582583

583-
old_val = final_cfs_sparse.at[cf_ix, feature]
584+
old_val = int(final_cfs_sparse.at[cf_ix, feature])
584585
final_cfs_sparse.at[cf_ix, feature] = query_instance[feature].iat[0]
585586
# Prediction of the query instance
586587
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
@@ -593,7 +594,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
593594

594595
# move the CF values towards the query_instance
595596
if diff > 0:
596-
left = final_cfs_sparse.at[cf_ix, feature]
597+
left = int(final_cfs_sparse.at[cf_ix, feature])
597598
right = query_instance[feature].iat[0]
598599

599600
while left <= right:
@@ -613,7 +614,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
613614

614615
else:
615616
left = query_instance[feature].iat[0]
616-
right = final_cfs_sparse.at[cf_ix, feature]
617+
right = int(final_cfs_sparse.at[cf_ix, feature])
617618

618619
while right >= left:
619620
current_val = right - ((right - left)/2)
@@ -731,13 +732,16 @@ def is_cf_valid(self, model_score):
731732
model_score = model_score[0]
732733
# Converting target_cf_class to a scalar (tf/torch have it as (1,1) shape)
733734
if self.model.model_type == ModelTypes.Classifier:
734-
target_cf_class = self.target_cf_class
735735
if hasattr(self.target_cf_class, "shape"):
736736
if len(self.target_cf_class.shape) == 1:
737-
target_cf_class = self.target_cf_class[0]
737+
temp_target_cf_class = self.target_cf_class[0]
738738
elif len(self.target_cf_class.shape) == 2:
739-
target_cf_class = self.target_cf_class[0][0]
740-
target_cf_class = int(target_cf_class)
739+
temp_target_cf_class = self.target_cf_class[0][0]
740+
else:
741+
temp_target_cf_class = int(self.target_cf_class)
742+
else:
743+
temp_target_cf_class = int(self.target_cf_class)
744+
target_cf_class = temp_target_cf_class
741745

742746
if len(model_score) == 1: # for tensorflow/pytorch models
743747
pred_1 = model_score[0]
@@ -757,6 +761,7 @@ def is_cf_valid(self, model_score):
757761
return self.target_cf_range[0] <= model_score and model_score <= self.target_cf_range[1]
758762

759763
def get_model_output_from_scores(self, model_scores):
764+
output_type: Any = None
760765
if self.model.model_type == ModelTypes.Classifier:
761766
output_type = np.int32
762767
else:
@@ -806,17 +811,19 @@ def build_KD_tree(self, data_df_copy, desired_range, desired_class, predicted_ou
806811
data_df_copy[predicted_outcome_name] = predictions
807812

808813
# segmenting the dataset according to outcome
809-
dataset_with_predictions = None
810814
if self.model.model_type == ModelTypes.Classifier:
811815
dataset_with_predictions = data_df_copy.loc[[i == desired_class for i in predictions]].copy()
812816

813817
elif self.model.model_type == ModelTypes.Regressor:
814818
dataset_with_predictions = data_df_copy.loc[
815819
[desired_range[0] <= pred <= desired_range[1] for pred in predictions]].copy()
816820

821+
else:
822+
dataset_with_predictions = None
823+
817824
KD_tree = None
818825
# Prepares the KD trees for DiCE
819-
if len(dataset_with_predictions) > 0:
826+
if dataset_with_predictions is not None and len(dataset_with_predictions) > 0:
820827
dummies = pd.get_dummies(dataset_with_predictions[self.data_interface.feature_names])
821828
KD_tree = KDTree(dummies)
822829

dice_ml/explainer_interfaces/feasible_base_vae.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,9 @@ def train(self, pre_trained=False):
136136
train_loss = 0.0
137137
train_size = 0
138138

139-
train_dataset = torch.tensor(self.vae_train_feat).float()
140-
train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
139+
train_dataset = torch.utils.data.DataLoader(
140+
torch.tensor(self.vae_train_feat).float(), # type: ignore
141+
batch_size=self.batch_size, shuffle=True)
141142
for train in enumerate(train_dataset):
142143
self.cf_vae_optimizer.zero_grad()
143144

@@ -178,8 +179,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
178179
final_cf_pred = []
179180
final_test_pred = []
180181
for i in range(len(query_instance)):
181-
train_x = test_dataset[i]
182-
train_x = torch.tensor(train_x).float()
182+
train_x = torch.tensor(test_dataset[i]).float()
183183
train_y = torch.argmax(self.pred_model(train_x), dim=1)
184184

185185
curr_gen_cf = []

dice_ml/explainer_interfaces/feasible_model_approx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ def train(self, constraint_type, constraint_variables, constraint_direction, con
8181
train_loss = 0.0
8282
train_size = 0
8383

84-
train_dataset = torch.tensor(self.vae_train_feat).float()
85-
train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
84+
train_dataset = torch.utils.data.DataLoader(
85+
torch.tensor(self.vae_train_feat).float(), # type: ignore
86+
batch_size=self.batch_size, shuffle=True)
8687
for train in enumerate(train_dataset):
8788
self.cf_vae_optimizer.zero_grad()
8889

0 commit comments

Comments (0)