From 40a7d314e673c930a98310d41ef8894b86e107c6 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Mon, 14 Dec 2020 18:54:55 -0800 Subject: [PATCH 1/3] Add validations for explainable model arguments in MimicExplainer Signed-off-by: Gaurav Gupta --- .../interpret_community/common/constants.py | 8 +++ .../mimic/mimic_explainer.py | 50 ++++++++++++++++--- test/test_mimic_explainer.py | 24 ++++++++- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/python/interpret_community/common/constants.py b/python/interpret_community/common/constants.py index 09247b26..c1987fe5 100644 --- a/python/interpret_community/common/constants.py +++ b/python/interpret_community/common/constants.py @@ -159,6 +159,14 @@ class LightGBMParams(object): """Provide constants for LightGBM.""" CATEGORICAL_FEATURE = 'categorical_feature' + N_JOBS = 'n_jobs' + ALL = [CATEGORICAL_FEATURE, N_JOBS] + + +class LinearExplainableModelParams(object): + """Provide constants for LinearExplainableModel.""" + SPARSE_DATA = 'sparse_data' + ALL = [SPARSE_DATA] class ShapValuesOutput(str, Enum): diff --git a/python/interpret_community/mimic/mimic_explainer.py b/python/interpret_community/mimic/mimic_explainer.py index e5bb6703..8d68cd87 100644 --- a/python/interpret_community/mimic/mimic_explainer.py +++ b/python/interpret_community/mimic/mimic_explainer.py @@ -21,7 +21,7 @@ from ..common.blackbox_explainer import BlackBoxExplainer from .model_distill import _model_distill -from .models import LGBMExplainableModel +from .models import LGBMExplainableModel, LinearExplainableModel from ..explanation.explanation import _create_local_explanation, _create_global_explanation, \ _aggregate_global_from_local_explanation, _aggregate_streamed_local_explanations, \ _create_raw_feats_global_explanation, _create_raw_feats_local_explanation, \ @@ -30,7 +30,7 @@ from ..dataset.dataset_wrapper import DatasetWrapper from ..common.constants import ExplainParams, ExplainType, ModelTask, \ ShapValuesOutput, 
MimicSerializationConstants, ExplainableModelType, \ - LightGBMParams, Defaults, Extension, ResetIndex + LightGBMParams, Defaults, Extension, ResetIndex, LinearExplainableModelParams import logging import json @@ -236,6 +236,8 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl """ if transformations is not None and explain_subset is not None: raise ValueError("explain_subset not supported with transformations") + self._validate_explainable_model_args(explainable_model=explainable_model, + explainable_model_args=explainable_model_args) self.reset_index = reset_index self._datamapper = None if transformations is not None: @@ -250,8 +252,7 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl wrapped_model, eval_ml_domain = _wrap_model(model, initialization_examples, model_task, is_function) super(MimicExplainer, self).__init__(wrapped_model, is_function=is_function, model_task=eval_ml_domain, **kwargs) - if explainable_model_args is None: - explainable_model_args = {} + if categorical_features is None: categorical_features = [] self._logger.debug('Initializing MimicExplainer') @@ -288,7 +289,6 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl # Index the categorical string columns for training data self._column_indexer = initialization_examples.string_index(columns=categorical_features) self._one_hot_encoder = None - explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features else: # One-hot-encode categoricals for models that don't support categoricals natively self._column_indexer = initialization_examples.string_index(columns=categorical_features) @@ -304,15 +304,49 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl if isinstance(training_data, DenseData): training_data = training_data.data - explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag - if 
self._supports_shap_values_output(explainable_model): - explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output + explainable_model_args = self._supplement_explainable_model_args( + explainable_model=explainable_model, + explainable_model_args=explainable_model_args, + categorical_features=categorical_features, + shap_values_output=shap_values_output) self.surrogate_model = _model_distill(self.function, explainable_model, training_data, original_training_data, explainable_model_args) self._method = self.surrogate_model._method self._original_eval_examples = None self._allow_all_transformations = allow_all_transformations + def _validate_explainable_model_args(self, explainable_model, explainable_model_args): + if explainable_model_args is None: + return + + if isinstance(explainable_model, LGBMExplainableModel): + for linear_param in LinearExplainableModelParams.ALL: + if linear_param in explainable_model_args: + raise Exception(linear_param + + " found in params for LightGBM explainable model") + + if isinstance(explainable_model, LinearExplainableModel): + for lightgbm_param in LightGBMParams.ALL: + if lightgbm_param in explainable_model_args: + raise Exception(lightgbm_param + + " found in params for Linear explainable model") + + def _supplement_explainable_model_args(self, explainable_model, explainable_model_args, + categorical_features, shap_values_output): + if explainable_model_args is None: + explainable_model_args = {} + + if explainable_model.explainable_model_type == ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE and \ + self._supports_categoricals(explainable_model): + explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features + + explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag + + if self._supports_shap_values_output(explainable_model): + explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output + + return explainable_model_args + def 
_supports_categoricals(self, explainable_model): return issubclass(explainable_model, LGBMExplainableModel) diff --git a/test/test_mimic_explainer.py b/test/test_mimic_explainer.py index 76ef5b3c..61848188 100644 --- a/test/test_mimic_explainer.py +++ b/test/test_mimic_explainer.py @@ -18,7 +18,8 @@ from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split from sys import platform -from interpret_community.common.constants import ShapValuesOutput, ModelTask +from interpret_community.common.constants import ShapValuesOutput, ModelTask, \ + LinearExplainableModelParams, LightGBMParams from interpret_community.mimic.models.lightgbm_model import LGBMExplainableModel from interpret_community.mimic.models.linear_model import LinearExplainableModel from common_utils import create_sklearn_svm_classifier, create_sklearn_linear_regressor, \ @@ -523,6 +524,27 @@ def test_dense_wide_data(self, mimic_explainer): global_explanation = explainer.explain_global(df_X) assert global_explanation.method == LIGHTGBM_METHOD + @pytest.mark.parametrize("error_config", + [(LGBMExplainableModel, {LinearExplainableModelParams.SPARSE_DATA: True}), + (LinearExplainableModel, {LightGBMParams.N_JOBS: -1}), + (LinearExplainableModel, {LightGBMParams.CATEGORICAL_FEATURE: []})]) + def test_validate_explainable_model_args(self, error_config): + num_features = 100 + num_rows = 1000 + test_size = 0.2 + X, y = make_regression(n_samples=num_rows, n_features=num_features) + x_train, x_test, y_train, _ = train_test_split(X, y, test_size=test_size, random_state=42) + + model = LinearRegression(normalize=True) + model.fit(x_train, y_train) + + explainable_model = error_config[0] + explainable_model_args = error_config[1] + with pytest.raises(Exception): + mimic_explainer(model, x_train, explainable_model, + explainable_model_args=explainable_model_args, + transformations=transformations, augment_data=False) + @property def iris_overall_expected_features(self): return 
[['petal length', 'petal width', 'sepal width', 'sepal length'], From 40e04dfa79f4da4fc7fd50e2400862e030cecdd0 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Tue, 15 Dec 2020 06:33:02 -0800 Subject: [PATCH 2/3] Fix broken gates Signed-off-by: Gaurav Gupta --- python/interpret_community/mimic/mimic_explainer.py | 4 ++-- test/test_mimic_explainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/interpret_community/mimic/mimic_explainer.py b/python/interpret_community/mimic/mimic_explainer.py index 8d68cd87..d63a2a98 100644 --- a/python/interpret_community/mimic/mimic_explainer.py +++ b/python/interpret_community/mimic/mimic_explainer.py @@ -319,13 +319,13 @@ def _validate_explainable_model_args(self, explainable_model, explainable_model_ if explainable_model_args is None: return - if isinstance(explainable_model, LGBMExplainableModel): + if explainable_model == LGBMExplainableModel: for linear_param in LinearExplainableModelParams.ALL: if linear_param in explainable_model_args: raise Exception(linear_param + " found in params for LightGBM explainable model") - if isinstance(explainable_model, LinearExplainableModel): + if explainable_model == LinearExplainableModel: for lightgbm_param in LightGBMParams.ALL: if lightgbm_param in explainable_model_args: raise Exception(lightgbm_param + diff --git a/test/test_mimic_explainer.py b/test/test_mimic_explainer.py index 61848188..11785570 100644 --- a/test/test_mimic_explainer.py +++ b/test/test_mimic_explainer.py @@ -528,7 +528,7 @@ def test_dense_wide_data(self, mimic_explainer): [(LGBMExplainableModel, {LinearExplainableModelParams.SPARSE_DATA: True}), (LinearExplainableModel, {LightGBMParams.N_JOBS: -1}), (LinearExplainableModel, {LightGBMParams.CATEGORICAL_FEATURE: []})]) - def test_validate_explainable_model_args(self, error_config): + def test_validate_explainable_model_args(self, error_config, mimic_explainer): num_features = 100 num_rows = 1000 test_size = 0.2 @@ -543,7 +543,7 @@ 
def test_validate_explainable_model_args(self, error_config): with pytest.raises(Exception): mimic_explainer(model, x_train, explainable_model, explainable_model_args=explainable_model_args, - transformations=transformations, augment_data=False) + augment_data=False) @property def iris_overall_expected_features(self): From cd75e886d67aed73d8665fcf3effdc06ed86e4ad Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Tue, 15 Dec 2020 13:04:43 -0800 Subject: [PATCH 3/3] Add more surrogate model params validation Signed-off-by: Gaurav Gupta --- python/interpret_community/mimic/mimic_explainer.py | 6 ++++++ test/test_mimic_explainer.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/interpret_community/mimic/mimic_explainer.py b/python/interpret_community/mimic/mimic_explainer.py index d63a2a98..d785e6e6 100644 --- a/python/interpret_community/mimic/mimic_explainer.py +++ b/python/interpret_community/mimic/mimic_explainer.py @@ -331,6 +331,12 @@ def _validate_explainable_model_args(self, explainable_model, explainable_model_ raise Exception(lightgbm_param + " found in params for Linear explainable model") + all_supported_explainable_model_args = LightGBMParams.ALL + LinearExplainableModelParams.ALL + for explainable_model_arg in explainable_model_args: + if explainable_model_arg not in all_supported_explainable_model_args: + raise Exception( + "Found unsupported explainable model argument " + explainable_model_arg) + def _supplement_explainable_model_args(self, explainable_model, explainable_model_args, categorical_features, shap_values_output): if explainable_model_args is None: diff --git a/test/test_mimic_explainer.py b/test/test_mimic_explainer.py index 11785570..84d532cf 100644 --- a/test/test_mimic_explainer.py +++ b/test/test_mimic_explainer.py @@ -527,7 +527,9 @@ def test_dense_wide_data(self, mimic_explainer): @pytest.mark.parametrize("error_config", [(LGBMExplainableModel, {LinearExplainableModelParams.SPARSE_DATA: True}), 
(LinearExplainableModel, {LightGBMParams.N_JOBS: -1}), - (LinearExplainableModel, {LightGBMParams.CATEGORICAL_FEATURE: []})]) + (LinearExplainableModel, {LightGBMParams.CATEGORICAL_FEATURE: []}), + (LGBMExplainableModel, {"unsupported": True}), + (LinearExplainableModel, {"unsupported": True})]) def test_validate_explainable_model_args(self, error_config, mimic_explainer): num_features = 100 num_rows = 1000