Skip to content

Commit 86dc77f

Browse files
committed
Enhancement in CF Generation
1 parent e053592 commit 86dc77f

22 files changed

+132
-149
lines changed

dice_ml/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .data import Data
2-
from .dice import Dice
32
from .model import Model
3+
from .dice import Dice
44

55
__all__ = ["Data",
66
"Model",

dice_ml/counterfactual_explanations.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import json
2-
import os
3-
42
import jsonschema
3+
import os
54

6-
from dice_ml.constants import _SchemaVersions
7-
from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
8-
_DiverseCFV2SchemaConstants)
5+
from dice_ml.diverse_counterfactuals import CounterfactualExamples
96
from dice_ml.utils.exception import UserConfigValidationException
7+
from dice_ml.diverse_counterfactuals import _DiverseCFV2SchemaConstants
8+
from dice_ml.constants import _SchemaVersions
109

1110

1211
class _CommonSchemaConstants:

dice_ml/data_interfaces/private_data_interface.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""Module containing meta data information about private data."""
22

3-
import collections
4-
import logging
53
import sys
6-
7-
import numpy as np
84
import pandas as pd
5+
import numpy as np
6+
import collections
7+
import logging
98

109
from dice_ml.data_interfaces.base_data_interface import _BaseData
1110

dice_ml/data_interfaces/public_data_interface.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
"""Module containing all required information about the interface between raw (or transformed)
22
public data and DiCE explainers."""
33

4+
import pandas as pd
5+
import numpy as np
46
import logging
57
from collections import defaultdict
68

7-
import numpy as np
8-
import pandas as pd
9-
109
from dice_ml.data_interfaces.base_data_interface import _BaseData
11-
from dice_ml.utils.exception import (SystemException,
12-
UserConfigValidationException)
10+
from dice_ml.utils.exception import SystemException, UserConfigValidationException
1311

1412

1513
class PublicData(_BaseData):
@@ -260,7 +258,7 @@ def get_valid_feature_range(self, feature_range_input, normalized=True):
260258
"""
261259
feature_range = {}
262260

263-
for _, feature_name in enumerate(self.feature_names):
261+
for idx, feature_name in enumerate(self.feature_names):
264262
feature_range[feature_name] = []
265263
if feature_name in self.continuous_feature_names:
266264
max_value = self.data_df[feature_name].max()

dice_ml/dice.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
such as RandomSampling, DiCEKD or DiCEGenetic"""
44

55
from dice_ml.constants import BackEndTypes, SamplingStrategy
6-
from dice_ml.data_interfaces.private_data_interface import PrivateData
7-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
86
from dice_ml.utils.exception import UserConfigValidationException
7+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
8+
from dice_ml.data_interfaces.private_data_interface import PrivateData
99

1010

1111
class Dice(ExplainerBase):
@@ -67,14 +67,12 @@ def decide(model_interface, method):
6767

6868
elif model_interface.backend == BackEndTypes.Tensorflow1:
6969
# pretrained Keras Sequential model with Tensorflow 1.x backend
70-
from dice_ml.explainer_interfaces.dice_tensorflow1 import \
71-
DiceTensorFlow1
70+
from dice_ml.explainer_interfaces.dice_tensorflow1 import DiceTensorFlow1
7271
return DiceTensorFlow1
7372

7473
elif model_interface.backend == BackEndTypes.Tensorflow2:
7574
# pretrained Keras Sequential model with Tensorflow 2.x backend
76-
from dice_ml.explainer_interfaces.dice_tensorflow2 import \
77-
DiceTensorFlow2
75+
from dice_ml.explainer_interfaces.dice_tensorflow2 import DiceTensorFlow2
7876
return DiceTensorFlow2
7977

8078
elif model_interface.backend == BackEndTypes.Pytorch:

dice_ml/diverse_counterfactuals.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1+
import pandas as pd
12
import copy
23
import json
3-
4-
import pandas as pd
5-
6-
from dice_ml.constants import ModelTypes, _SchemaVersions
74
from dice_ml.utils.serialize import DummyDataInterface
5+
from dice_ml.constants import _SchemaVersions, ModelTypes
86

97

108
class _DiverseCFV1SchemaConstants:
@@ -117,7 +115,6 @@ def _visualize_internal(self, display_sparse_df=True, show_only_changes=False,
117115

118116
def visualize_as_dataframe(self, display_sparse_df=True, show_only_changes=False):
119117
from IPython.display import display
120-
121118
# original instance
122119
print('Query instance (original outcome : %i)' % round(self.test_pred))
123120
display(self.test_instance_df) # works only in Jupyter notebook

dice_ml/explainer_interfaces/dice_KD.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
Module to generate counterfactual explanations from a KD-Tree
33
This code is similar to 'Interpretable Counterfactual Explanations Guided by Prototypes': https://arxiv.org/pdf/1907.02584.pdf
44
"""
5-
import copy
6-
import timeit
7-
5+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
86
import numpy as np
7+
import timeit
98
import pandas as pd
9+
import copy
1010

1111
from dice_ml import diverse_counterfactuals as exp
1212
from dice_ml.constants import ModelTypes
13-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1413

1514

1615
class DiceKD(ExplainerBase):
@@ -260,10 +259,14 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
260259
if total_cfs_found < total_CFs:
261260
self.elapsed = timeit.default_timer() - start_time
262261
m, s = divmod(self.elapsed, 60)
263-
print('Only %d (required %d) ' % (total_cfs_found, self.total_CFs),
262+
print('Only %d (required %d) ' % (total_cfs_found, total_CFs),
264263
'Diverse Counterfactuals found for the given configuation, perhaps ',
265264
'change the query instance or the features to vary...' '; total time taken: %02d' % m,
266265
'min %02d' % s, 'sec')
266+
elif total_cfs_found == 0:
267+
print(
268+
'No Counterfactuals found for the given configuration, perhaps try with different parameters...',
269+
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
267270
else:
268271
print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec')
269272

dice_ml/explainer_interfaces/dice_genetic.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22
Module to generate diverse counterfactual explanations based on genetic algorithm
33
This code is similar to 'GeCo: Quality Counterfactual Explanations in Real Time': https://arxiv.org/pdf/2101.01292.pdf
44
"""
5-
import copy
6-
import random
7-
import timeit
8-
5+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
96
import numpy as np
107
import pandas as pd
8+
import random
9+
import timeit
10+
import copy
1111
from sklearn.preprocessing import LabelEncoder
1212

1313
from dice_ml import diverse_counterfactuals as exp
1414
from dice_ml.constants import ModelTypes
15-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1615

1716

1817
class DiceGenetic(ExplainerBase):
@@ -116,9 +115,8 @@ def do_random_init(self, num_inits, features_to_vary, query_instance, desired_cl
116115
def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desired_range):
117116
cfs = self.label_encode(cfs)
118117
cfs = cfs.reset_index(drop=True)
119-
120-
self.cfs = np.zeros((self.population_size, self.data_interface.number_of_features))
121-
for kx in range(self.population_size):
118+
row = []
119+
for kx in range(self.population_size*5):
122120
if kx >= len(cfs):
123121
break
124122
one_init = np.zeros(self.data_interface.number_of_features)
@@ -143,16 +141,21 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir
143141
one_init[jx] = query_instance[jx]
144142
else:
145143
one_init[jx] = np.random.choice(self.feature_range[feature])
146-
self.cfs[kx] = one_init
144+
t = tuple(one_init)
145+
if t not in row:
146+
row.append(t)
147+
if len(row) == self.population_size:
148+
break
147149
kx += 1
150+
self.cfs = np.array(row)
148151

149-
new_array = [tuple(row) for row in self.cfs]
150-
uniques = np.unique(new_array, axis=0)
151-
152-
if len(uniques) != self.population_size:
152+
#if len(self.cfs) > self.population_size:
153+
# pass
154+
if len(self.cfs) != self.population_size:
155+
print("Pericolo Loop infinito....!!!!")
153156
remaining_cfs = self.do_random_init(
154-
self.population_size - len(uniques), features_to_vary, query_instance, desired_class, desired_range)
155-
self.cfs = np.concatenate([uniques, remaining_cfs])
157+
self.population_size - len(self.cfs), features_to_vary, query_instance, desired_class, desired_range)
158+
self.cfs = np.concatenate([self.cfs, remaining_cfs])
156159

157160
def do_cf_initializations(self, total_CFs, initialization, algorithm, features_to_vary, desired_range,
158161
desired_class,
@@ -260,7 +263,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k
260263
(see diverse_counterfactuals.py).
261264
"""
262265

263-
self.population_size = 10 * total_CFs
266+
self.population_size = 3 * total_CFs
264267

265268
self.start_time = timeit.default_timer()
266269

@@ -514,6 +517,9 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class,
514517
if len(self.final_cfs) == self.total_CFs:
515518
print('Diverse Counterfactuals found! total time taken: %02d' %
516519
m, 'min %02d' % s, 'sec')
520+
elif len(self.final_cfs) == 0:
521+
print('No Counterfactuals found for the given configuration, perhaps try with different parameters...',
522+
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
517523
else:
518524
print('Only %d (required %d) ' % (len(self.final_cfs), self.total_CFs),
519525
'Diverse Counterfactuals found for the given configuation, perhaps ',

dice_ml/explainer_interfaces/dice_pytorch.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
"""
22
Module to generate diverse counterfactual explanations based on PyTorch framework
33
"""
4-
import copy
5-
import random
6-
import timeit
4+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
5+
import torch
76

87
import numpy as np
9-
import torch
8+
import random
9+
import timeit
10+
import copy
1011

1112
from dice_ml import diverse_counterfactuals as exp
1213
from dice_ml.counterfactual_explanations import CounterfactualExplanations
13-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1414

1515

1616
class DicePyTorch(ExplainerBase):

dice_ml/explainer_interfaces/dice_random.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
Module to generate diverse counterfactual explanations based on random sampling.
44
A simple implementation.
55
"""
6-
import random
7-
import timeit
8-
6+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
97
import numpy as np
108
import pandas as pd
9+
import random
10+
import timeit
1111

1212
from dice_ml import diverse_counterfactuals as exp
1313
from dice_ml.constants import ModelTypes
14-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1514

1615

1716
class DiceRandom(ExplainerBase):
@@ -109,11 +108,17 @@ class of query_instance for binary classification.
109108
cfs_df = None
110109
candidate_cfs = pd.DataFrame(
111110
np.repeat(query_instance.values, sample_size, axis=0), columns=query_instance.columns)
112-
# Loop to change one feature at a time, then two features, and so on.
111+
# Loop to change one feature at a time ##->(NOT TRUE), then two features, and so on.
113112
for num_features_to_vary in range(1, len(self.features_to_vary)+1):
113+
# The commented-out lines allow more values to change as num_features_to_vary increases; when enabling them, use .loc instead of .at.
114+
# They are deliberately left commented out so that you can choose.
115+
# That variant is slower, but more complete, and still faster than the genetic/KD-tree methods.
116+
# selected_features = np.random.choice(self.features_to_vary, (sample_size, num_features_to_vary), replace=True)
114117
selected_features = np.random.choice(self.features_to_vary, (sample_size, 1), replace=True)
115118
for k in range(sample_size):
116-
candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
119+
candidate_cfs.at[k, selected_features[k][0]] = random_instances._get_value(k, selected_features[k][0])
120+
# If you only want to change one feature, you should use _get_value
121+
# candidate_cfs.iloc[k][selected_features[k]]=random_instances.iloc[k][selected_features[k]]
117122
scores = self.predict_fn(candidate_cfs)
118123
validity = self.decide_cf_validity(scores)
119124
if sum(validity) > 0:

dice_ml/explainer_interfaces/dice_tensorflow1.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
"""
22
Module to generate diverse counterfactual explanations based on tensorflow 1.x
33
"""
4-
import collections
5-
import copy
6-
import random
7-
import timeit
4+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
5+
import tensorflow as tf
86

97
import numpy as np
10-
import tensorflow as tf
8+
import random
9+
import collections
10+
import timeit
11+
import copy
1112

1213
from dice_ml import diverse_counterfactuals as exp
1314
from dice_ml.counterfactual_explanations import CounterfactualExplanations
14-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1515

1616

1717
class DiceTensorFlow1(ExplainerBase):

dice_ml/explainer_interfaces/dice_tensorflow2.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
"""
22
Module to generate diverse counterfactual explanations based on tensorflow 2.x
33
"""
4-
import copy
5-
import random
6-
import timeit
4+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
5+
import tensorflow as tf
76

87
import numpy as np
9-
import tensorflow as tf
8+
import random
9+
import timeit
10+
import copy
1011

1112
from dice_ml import diverse_counterfactuals as exp
1213
from dice_ml.counterfactual_explanations import CounterfactualExplanations
13-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1414

1515

1616
class DiceTensorFlow2(ExplainerBase):
@@ -177,7 +177,7 @@ def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
177177
# CF initialization
178178
if len(self.cfs) != self.total_CFs:
179179
self.cfs = []
180-
for _ in range(self.total_CFs):
180+
for ix in range(self.total_CFs):
181181
one_init = [[]]
182182
for jx in range(self.minx.shape[1]):
183183
one_init[0].append(np.random.uniform(self.minx[0][jx], self.maxx[0][jx]))

0 commit comments

Comments
 (0)