diff --git a/dice_ml/data_interfaces/private_data_interface.py b/dice_ml/data_interfaces/private_data_interface.py index 3fd36856..e48c105d 100644 --- a/dice_ml/data_interfaces/private_data_interface.py +++ b/dice_ml/data_interfaces/private_data_interface.py @@ -182,11 +182,7 @@ def get_features_range(self, permitted_range_input=None, features_dict=None): ranges = {} # Getting default ranges based on the dataset - for feature in features_dict: - if type(features_dict[feature][0]) is int: # continuous feature - ranges[feature] = features_dict[feature] - else: - ranges[feature] = features_dict[feature] + ranges[feature] = features_dict[feature] feature_ranges_orig = ranges.copy() # Overwriting the ranges for a feature if input provided if permitted_range_input is not None: diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 349d1959..811126d1 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -250,12 +250,7 @@ def setup(self, features_to_vary, permitted_range, query_instance, feature_weigh if features_to_vary == 'all': features_to_vary = self.data_interface.feature_names - if permitted_range is None: # use the precomputed default - self.feature_range = self.data_interface.permitted_range - feature_ranges_orig = self.feature_range - else: # compute the new ranges based on user input - self.feature_range, feature_ranges_orig = self.data_interface.get_features_range(permitted_range) - + self.feature_range, feature_ranges_orig = self.data_interface.get_features_range(permitted_range) self.check_query_instance_validity(features_to_vary, permitted_range, query_instance, feature_ranges_orig) return features_to_vary diff --git a/tests/conftest.py b/tests/conftest.py index fdc36b1d..0498b6bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -411,6 +411,34 @@ def _load_custom_vars_dataset_model(): return model +def _load_adult_income_binary_model(): + dataset = helpers.load_adult_income_dataset() + X_train = dataset.drop('income', axis=1) + y_train = dataset["income"] + num_feature_names = ["age", "hours_per_week"] + cat_feature_names = X_train.columns.difference(num_feature_names) + model = create_complex_classification_pipeline( + X_train, y_train, num_feature_names, cat_feature_names) + return model + + +def sample_adult_income_custom_query_11(): + """ + Returns multiple query instance for adult income dataset + """ + data_point = 2 + query_instances = pd.DataFrame({'age': [22]*data_point, + 'workclass': ['Private']*data_point, + 'education': ['HS-grad']*data_point, + 'marital_status': ['Single']*data_point, + 'occupation': ['Service']*data_point, + 'race': ['White']*data_point, + 'gender': ['Female']*data_point, + 'hours_per_week': [45]*data_point}, + index=list(range(data_point))) + return query_instances + + @pytest.fixture(scope='session') def sample_adultincome_query(): """ diff --git a/tests/test_dice_interface/test_explainer_base.py b/tests/test_dice_interface/test_explainer_base.py index e3b63892..e7386338 100644 --- a/tests/test_dice_interface/test_explainer_base.py +++ b/tests/test_dice_interface/test_explainer_base.py @@ -12,7 +12,10 @@ from dice_ml.explainer_interfaces.explainer_base import ExplainerBase from dice_ml.utils import helpers -from ..conftest import _load_custom_testing_binary_model +from ..conftest import (private_data_object, + sample_adult_income_custom_query_11, + _load_adult_income_binary_model, + _load_custom_testing_binary_model) @pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree']) @@ -349,6 +352,20 @@ def test_cfs_type_consistency( assert cf_explanations.cf_examples_list[0].final_cfs_df[col].dtype == sample_custom_query[col].dtype if cf_explanations.cf_examples_list[0].final_cfs_df_sparse is not None: assert cf_explanations.cf_examples_list[0].final_cfs_df_sparse[col].dtype == sample_custom_query[col].dtype + + @pytest.mark.parametrize("method", ["genetic"]) + def test_genetic_private_data(method): + d = private_data_object() + query = sample_adult_income_custom_query_11() + model = _load_adult_income_binary_model() + m = dice_ml.Model(model=model, backend='sklearn') + exp = dice_ml.Dice(d, m, method=method) + + return exp.generate_counterfactuals( + query_instances=query, + total_CFs=1, + desired_class="opposite", + initialization="random") @pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree'])