Skip to content

Commit b440eb1

Browse files
committed
import update
1 parent 2a88433 commit b440eb1

21 files changed

+132
-86
lines changed

dice_ml/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .data import Data
2-
from .model import Model
32
from .dice import Dice
3+
from .model import Model
44

55
__all__ = ["Data",
66
"Model",

dice_ml/counterfactual_explanations.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
2-
import jsonschema
32
import os
43

5-
from dice_ml.diverse_counterfactuals import CounterfactualExamples
6-
from dice_ml.utils.exception import UserConfigValidationException
7-
from dice_ml.diverse_counterfactuals import _DiverseCFV2SchemaConstants
4+
import jsonschema
5+
86
from dice_ml.constants import _SchemaVersions
7+
from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
8+
_DiverseCFV2SchemaConstants)
9+
from dice_ml.utils.exception import UserConfigValidationException
910

1011

1112
class _CommonSchemaConstants:

dice_ml/data_interfaces/private_data_interface.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Module containing meta data information about private data."""
22

3-
import sys
4-
import pandas as pd
5-
import numpy as np
63
import collections
74
import logging
5+
import sys
6+
7+
import numpy as np
8+
import pandas as pd
89

910
from dice_ml.data_interfaces.base_data_interface import _BaseData
1011

dice_ml/data_interfaces/public_data_interface.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Module containing all required information about the interface between raw (or transformed)
22
public data and DiCE explainers."""
33

4-
import pandas as pd
5-
import numpy as np
64
import logging
75
from collections import defaultdict
86

9-
from dice_ml.data_interfaces.base_data_interface import _BaseData
10-
from dice_ml.utils.exception import SystemException, UserConfigValidationException
7+
import numpy as np
8+
import pandas as pd
119

10+
from dice_ml.data_interfaces.base_data_interface import _BaseData
11+
from dice_ml.utils.exception import (SystemException,
12+
UserConfigValidationException)
1213

1314
class PublicData(_BaseData):
1415
"""A data interface for public data. This class is an interface to DiCE explainers
@@ -258,7 +259,7 @@ def get_valid_feature_range(self, feature_range_input, normalized=True):
258259
"""
259260
feature_range = {}
260261

261-
for idx, feature_name in enumerate(self.feature_names):
262+
for _, feature_name in enumerate(self.feature_names):
262263
feature_range[feature_name] = []
263264
if feature_name in self.continuous_feature_names:
264265
max_value = self.data_df[feature_name].max()

dice_ml/dice.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
such as RandomSampling, DiCEKD or DiCEGenetic"""
44

55
from dice_ml.constants import BackEndTypes, SamplingStrategy
6-
from dice_ml.utils.exception import UserConfigValidationException
76
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
87
from dice_ml.data_interfaces.private_data_interface import PrivateData
8+
from dice_ml.utils.exception import UserConfigValidationException
99

1010

1111
class Dice(ExplainerBase):
@@ -67,12 +67,14 @@ def decide(model_interface, method):
6767

6868
elif model_interface.backend == BackEndTypes.Tensorflow1:
6969
# pretrained Keras Sequential model with Tensorflow 1.x backend
70-
from dice_ml.explainer_interfaces.dice_tensorflow1 import DiceTensorFlow1
70+
from dice_ml.explainer_interfaces.dice_tensorflow1 import \
71+
DiceTensorFlow1
7172
return DiceTensorFlow1
7273

7374
elif model_interface.backend == BackEndTypes.Tensorflow2:
7475
# pretrained Keras Sequential model with Tensorflow 2.x backend
75-
from dice_ml.explainer_interfaces.dice_tensorflow2 import DiceTensorFlow2
76+
from dice_ml.explainer_interfaces.dice_tensorflow2 import \
77+
DiceTensorFlow2
7678
return DiceTensorFlow2
7779

7880
elif model_interface.backend == BackEndTypes.Pytorch:

dice_ml/diverse_counterfactuals.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
import pandas as pd
21
import copy
32
import json
3+
4+
import pandas as pd
5+
6+
from dice_ml.constants import ModelTypes, _SchemaVersions
47
from dice_ml.utils.serialize import DummyDataInterface
5-
from dice_ml.constants import _SchemaVersions, ModelTypes
68

79

810
class _DiverseCFV1SchemaConstants:
@@ -115,6 +117,7 @@ def _visualize_internal(self, display_sparse_df=True, show_only_changes=False,
115117

116118
def visualize_as_dataframe(self, display_sparse_df=True, show_only_changes=False):
117119
from IPython.display import display
120+
118121
# original instance
119122
print('Query instance (original outcome : %i)' % round(self.test_pred))
120123
display(self.test_instance_df) # works only in Jupyter notebook

dice_ml/explainer_interfaces/dice_KD.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
Module to generate counterfactual explanations from a KD-Tree
33
This code is similar to 'Interpretable Counterfactual Explanations Guided by Prototypes': https://arxiv.org/pdf/1907.02584.pdf
44
"""
5-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
6-
import numpy as np
5+
import copy
76
import timeit
7+
8+
import numpy as np
89
import pandas as pd
9-
import copy
1010

1111
from dice_ml import diverse_counterfactuals as exp
1212
from dice_ml.constants import ModelTypes
13+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1314

1415

1516
class DiceKD(ExplainerBase):

dice_ml/explainer_interfaces/dice_genetic.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22
Module to generate diverse counterfactual explanations based on genetic algorithm
33
This code is similar to 'GeCo: Quality Counterfactual Explanations in Real Time': https://arxiv.org/pdf/2101.01292.pdf
44
"""
5-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
6-
import numpy as np
7-
import pandas as pd
5+
import copy
86
import random
97
import timeit
10-
import copy
8+
9+
import numpy as np
10+
import pandas as pd
1111
from sklearn.preprocessing import LabelEncoder
1212

1313
from dice_ml import diverse_counterfactuals as exp
1414
from dice_ml.constants import ModelTypes
15+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1516

1617

1718
class DiceGenetic(ExplainerBase):

dice_ml/explainer_interfaces/dice_pytorch.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
"""
22
Module to generate diverse counterfactual explanations based on PyTorch framework
33
"""
4-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
5-
import torch
6-
7-
import numpy as np
4+
import copy
85
import random
96
import timeit
10-
import copy
7+
import numpy as np
8+
9+
import torch
1110

1211
from dice_ml import diverse_counterfactuals as exp
1312
from dice_ml.counterfactual_explanations import CounterfactualExplanations
13+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1414

1515

1616
class DicePyTorch(ExplainerBase):

dice_ml/explainer_interfaces/dice_random.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
Module to generate diverse counterfactual explanations based on random sampling.
44
A simple implementation.
55
"""
6-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
7-
import numpy as np
8-
import pandas as pd
96
import random
107
import timeit
118

9+
import numpy as np
10+
import pandas as pd
11+
1212
from dice_ml import diverse_counterfactuals as exp
1313
from dice_ml.constants import ModelTypes
14+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1415

1516

1617
class DiceRandom(ExplainerBase):

dice_ml/explainer_interfaces/dice_tensorflow1.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
"""
22
Module to generate diverse counterfactual explanations based on tensorflow 1.x
33
"""
4-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
5-
import tensorflow as tf
6-
7-
import numpy as np
8-
import random
94
import collections
10-
import timeit
115
import copy
6+
import random
7+
import timeit
8+
9+
import numpy as np
10+
import tensorflow as tf
1211

1312
from dice_ml import diverse_counterfactuals as exp
1413
from dice_ml.counterfactual_explanations import CounterfactualExplanations
14+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1515

1616

1717
class DiceTensorFlow1(ExplainerBase):

dice_ml/explainer_interfaces/dice_tensorflow2.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
"""
22
Module to generate diverse counterfactual explanations based on tensorflow 2.x
33
"""
4-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
5-
import tensorflow as tf
6-
7-
import numpy as np
4+
import copy
85
import random
96
import timeit
10-
import copy
7+
8+
import numpy as np
9+
import tensorflow as tf
1110

1211
from dice_ml import diverse_counterfactuals as exp
1312
from dice_ml.counterfactual_explanations import CounterfactualExplanations
13+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
1414

1515

1616
class DiceTensorFlow2(ExplainerBase):
@@ -177,7 +177,7 @@ def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
177177
# CF initialization
178178
if len(self.cfs) != self.total_CFs:
179179
self.cfs = []
180-
for ix in range(self.total_CFs):
180+
for _ in range(self.total_CFs):
181181
one_init = [[]]
182182
for jx in range(self.minx.shape[1]):
183183
one_init[0].append(np.random.uniform(self.minx[0][jx], self.maxx[0][jx]))

dice_ml/explainer_interfaces/explainer_base.py

+41-13
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch.
33
All methods are in dice_ml.explainer_interfaces"""
44

5-
import warnings
65
from abc import ABC, abstractmethod
6+
from collections.abc import Iterable
7+
78
import numpy as np
89
import pandas as pd
10+
from sklearn.neighbors import KDTree
911
from tqdm import tqdm
1012

11-
from collections.abc import Iterable
12-
from sklearn.neighbors import KDTree
13+
from dice_ml.constants import ModelTypes
1314
from dice_ml.counterfactual_explanations import CounterfactualExplanations
1415
from dice_ml.utils.exception import UserConfigValidationException
15-
from dice_ml.constants import ModelTypes
1616

1717

1818
class ExplainerBase(ABC):
@@ -85,6 +85,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
8585
if posthoc_sparsity_algorithm == None:
8686
posthoc_sparsity_algorithm = 'binary'
8787
elif total_CFs >50 and posthoc_sparsity_algorithm == 'linear':
88+
import warnings
8889
warnings.warn("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
8990
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
9091
"'binary' search!".format(total_CFs))
@@ -98,6 +99,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
9899
query_instances_list.append(query_instances[ix:(ix+1)])
99100
elif isinstance(query_instances, Iterable):
100101
query_instances_list = query_instances
102+
101103
for query_instance in tqdm(query_instances_list):
102104
self.data_interface.set_continuous_feature_indexes(query_instance)
103105
res = self._generate_counterfactuals(
@@ -112,6 +114,9 @@ def generate_counterfactuals(self, query_instances, total_CFs,
112114
verbose=verbose,
113115
**kwargs)
114116
cf_examples_arr.append(res)
117+
118+
self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)
119+
115120
return CounterfactualExplanations(cf_examples_list=cf_examples_arr)
116121

117122
@abstractmethod
@@ -217,10 +222,12 @@ def local_feature_importance(self, query_instances, cf_examples_list=None,
217222
if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
218223
raise UserConfigValidationException(
219224
"The number of counterfactuals generated per query instance should be "
220-
"greater than or equal to 10")
225+
"greater than or equal to 10 to compute feature importance for all query points")
221226
elif total_CFs < 10:
222-
raise UserConfigValidationException("The number of counterfactuals generated per "
223-
"query instance should be greater than or equal to 10")
227+
raise UserConfigValidationException(
228+
"The number of counterfactuals requested per "
229+
"query instance should be greater than or equal to 10 "
230+
"to compute feature importance for all query points")
224231
importances = self.feature_importance(
225232
query_instances,
226233
cf_examples_list=cf_examples_list,
@@ -261,16 +268,25 @@ def global_feature_importance(self, query_instances, cf_examples_list=None,
261268
input, and the global feature importance summarized over all inputs.
262269
"""
263270
if query_instances is not None and len(query_instances) < 10:
264-
raise UserConfigValidationException("The number of query instances should be greater than or equal to 10")
271+
raise UserConfigValidationException(
272+
"The number of query instances should be greater than or equal to 10 "
273+
"to compute global feature importance over all query points")
265274
if cf_examples_list is not None:
266-
if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
275+
if len(cf_examples_list) < 10:
276+
raise UserConfigValidationException(
277+
"The number of points for which counterfactuals generated should be "
278+
"greater than or equal to 10 "
279+
"to compute global feature importance")
280+
elif any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
267281
raise UserConfigValidationException(
268282
"The number of counterfactuals generated per query instance should be "
269-
"greater than or equal to 10")
283+
"greater than or equal to 10 "
284+
"to compute global feature importance over all query points")
270285
elif total_CFs < 10:
271286
raise UserConfigValidationException(
272287
"The number of counterfactuals generated per query instance should be greater "
273-
"than or equal to 10")
288+
"than or equal to 10 "
289+
"to compute global feature importance over all query points")
274290
importances = self.feature_importance(
275291
query_instances,
276292
cf_examples_list=cf_examples_list,
@@ -349,7 +365,7 @@ def feature_importance(self, query_instances, cf_examples_list=None,
349365
continue
350366

351367
per_query_point_cfs = 0
352-
for index, row in df.iterrows():
368+
for _, row in df.iterrows():
353369
per_query_point_cfs += 1
354370
for col in self.data_interface.continuous_feature_names:
355371
if not np.isclose(org_instance[col].iat[0], row[col]):
@@ -530,7 +546,7 @@ def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred)
530546
self.target_cf_class = np.array(
531547
[[self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes)]],
532548
dtype=np.float32)
533-
desired_class = self.target_cf_class[0][0]
549+
desired_class = int(self.target_cf_class[0][0])
534550
if self.target_cf_class == 0 and self.stopping_threshold > 0.5:
535551
self.stopping_threshold = 0.25
536552
elif self.target_cf_class == 1 and self.stopping_threshold < 0.5:
@@ -695,3 +711,15 @@ def round_to_precision(self):
695711
self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix])
696712
if self.final_cfs_df_sparse is not None:
697713
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix])
714+
715+
def _check_any_counterfactuals_computed(self, cf_examples_arr):
    """Ensure that at least one query point produced a counterfactual.

    :param cf_examples_arr: iterable of per-query results; the
        ``final_cfs_df`` attribute of each element is inspected.
    :raises UserConfigValidationException: if every element's
        ``final_cfs_df`` is None or empty (an empty ``cf_examples_arr``
        also counts as "nothing generated").
    """
    # any() stops at the first non-empty result, mirroring an early break.
    has_any_cf = any(
        result.final_cfs_df is not None and len(result.final_cfs_df) > 0
        for result in cf_examples_arr)
    if has_any_cf:
        return
    raise UserConfigValidationException(
        "No counterfactuals found for any of the query points! Kindly check your configuration.")

0 commit comments

Comments
 (0)