diff --git a/python/interpret_community/common/constants.py b/python/interpret_community/common/constants.py
index 09247b26..c1987fe5 100644
--- a/python/interpret_community/common/constants.py
+++ b/python/interpret_community/common/constants.py
@@ -159,6 +159,14 @@ class LightGBMParams(object):
     """Provide constants for LightGBM."""
 
     CATEGORICAL_FEATURE = 'categorical_feature'
+    N_JOBS = 'n_jobs'
+    ALL = [CATEGORICAL_FEATURE, N_JOBS]
+
+
+class LinearExplainableModelParams(object):
+    """Provide constants for LinearExplainableModel."""
+    SPARSE_DATA = 'sparse_data'
+    ALL = [SPARSE_DATA]
 
 
 class ShapValuesOutput(str, Enum):
diff --git a/python/interpret_community/mimic/mimic_explainer.py b/python/interpret_community/mimic/mimic_explainer.py
index 1b8d08f2..8fdf1a3e 100644
--- a/python/interpret_community/mimic/mimic_explainer.py
+++ b/python/interpret_community/mimic/mimic_explainer.py
@@ -19,9 +19,8 @@
     transform_with_datamapper
 
 from ..common.blackbox_explainer import BlackBoxExplainer
-
 from .model_distill import _model_distill, _inverse_soft_logit
-from .models import LGBMExplainableModel
+from .models import LGBMExplainableModel, LinearExplainableModel
 from ..explanation.explanation import _create_local_explanation, _create_global_explanation, \
     _aggregate_global_from_local_explanation, _aggregate_streamed_local_explanations, \
     _create_raw_feats_global_explanation, _create_raw_feats_local_explanation, \
@@ -30,7 +29,7 @@
 from ..dataset.dataset_wrapper import DatasetWrapper
 from ..common.constants import ExplainParams, ExplainType, ModelTask, \
     ShapValuesOutput, MimicSerializationConstants, ExplainableModelType, \
-    LightGBMParams, Defaults, Extension, ResetIndex
+    LightGBMParams, Defaults, Extension, ResetIndex, LinearExplainableModelParams
 import logging
 import json
 
@@ -236,6 +235,8 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
         """
         if transformations is not None and explain_subset is not None:
             raise ValueError("explain_subset not supported with transformations")
+        self._validate_explainable_model_args(explainable_model=explainable_model,
+                                              explainable_model_args=explainable_model_args)
         self.reset_index = reset_index
         self._datamapper = None
         if transformations is not None:
@@ -250,8 +251,7 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
         wrapped_model, eval_ml_domain = _wrap_model(model, initialization_examples, model_task, is_function)
         super(MimicExplainer, self).__init__(wrapped_model, is_function=is_function,
                                              model_task=eval_ml_domain, **kwargs)
-        if explainable_model_args is None:
-            explainable_model_args = {}
+
         if categorical_features is None:
             categorical_features = []
         self._logger.debug('Initializing MimicExplainer')
@@ -288,7 +288,6 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
             # Index the categorical string columns for training data
             self._column_indexer = initialization_examples.string_index(columns=categorical_features)
             self._one_hot_encoder = None
-            explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features
         else:
             # One-hot-encode categoricals for models that don't support categoricals natively
             self._column_indexer = initialization_examples.string_index(columns=categorical_features)
@@ -304,15 +303,57 @@ def __init__(self, model, initialization_examples, explainable_model, explainabl
         if isinstance(training_data, DenseData):
             training_data = training_data.data
 
-        explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag
-        if self._supports_shap_values_output(explainable_model):
-            explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output
+        explainable_model_args = self._supplement_explainable_model_args(
+            explainable_model=explainable_model,
+            explainable_model_args=explainable_model_args,
+            categorical_features=categorical_features,
+            shap_values_output=shap_values_output)
 
         self.surrogate_model = _model_distill(self.function, explainable_model,
                                               training_data, original_training_data, explainable_model_args)
         self._method = self.surrogate_model._method
         self._original_eval_examples = None
         self._allow_all_transformations = allow_all_transformations
 
+    def _validate_explainable_model_args(self, explainable_model, explainable_model_args):
+        if explainable_model_args is None:
+            return
+
+        if explainable_model == LGBMExplainableModel:
+            for linear_param in LinearExplainableModelParams.ALL:
+                if linear_param in explainable_model_args:
+                    raise Exception(linear_param +
+                                    " found in params for LightGBM explainable model")
+
+        if explainable_model == LinearExplainableModel:
+            for lightgbm_param in LightGBMParams.ALL:
+                if lightgbm_param in explainable_model_args:
+                    raise Exception(lightgbm_param +
+                                    " found in params for Linear explainable model")
+
+        # NOTE(fix): concatenate the per-model lists; a list of lists would
+        # reject every argument, including supported ones like 'n_jobs'.
+        all_supported_explainable_model_args = LightGBMParams.ALL + LinearExplainableModelParams.ALL
+        for explainable_model_arg in explainable_model_args:
+            if explainable_model_arg not in all_supported_explainable_model_args:
+                raise Exception(
+                    "Found unsupported explainable model argument " + explainable_model_arg)
+
+    def _supplement_explainable_model_args(self, explainable_model, explainable_model_args,
+                                           categorical_features, shap_values_output):
+        if explainable_model_args is None:
+            explainable_model_args = {}
+
+        if explainable_model.explainable_model_type == ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE and \
+                self._supports_categoricals(explainable_model):
+            explainable_model_args[LightGBMParams.CATEGORICAL_FEATURE] = categorical_features
+
+        explainable_model_args[ExplainParams.CLASSIFICATION] = self.predict_proba_flag
+
+        if self._supports_shap_values_output(explainable_model):
+            explainable_model_args[ExplainParams.SHAP_VALUES_OUTPUT] = shap_values_output
+
+        return explainable_model_args
+
     def _get_surrogate_model_predictions(self, evaluation_examples):
         """Return the predictions given by the surrogate model.
diff --git a/test/test_mimic_explainer.py b/test/test_mimic_explainer.py
index 412299de..263c54d0 100644
--- a/test/test_mimic_explainer.py
+++ b/test/test_mimic_explainer.py
@@ -18,7 +18,8 @@
 from sklearn.datasets import make_regression
 from sklearn.model_selection import train_test_split
 from sys import platform
-from interpret_community.common.constants import ShapValuesOutput, ModelTask
+from interpret_community.common.constants import ShapValuesOutput, ModelTask, \
+    LinearExplainableModelParams, LightGBMParams
 from interpret_community.mimic.models.lightgbm_model import LGBMExplainableModel
 from interpret_community.mimic.models.linear_model import LinearExplainableModel
 from common_utils import create_timeseries_data, LIGHTGBM_METHOD, \
@@ -540,6 +541,29 @@
         global_explanation = explainer.explain_global(df_X)
         assert global_explanation.method == LIGHTGBM_METHOD
 
+    @pytest.mark.parametrize("error_config",
+                             [(LGBMExplainableModel, {LinearExplainableModelParams.SPARSE_DATA: True}),
+                              (LinearExplainableModel, {LightGBMParams.N_JOBS: -1}),
+                              (LinearExplainableModel, {LightGBMParams.CATEGORICAL_FEATURE: []}),
+                              (LGBMExplainableModel, {"unsupported": True}),
+                              (LinearExplainableModel, {"unsupported": True})])
+    def test_validate_explainable_model_args(self, error_config, mimic_explainer):
+        num_features = 100
+        num_rows = 1000
+        test_size = 0.2
+        X, y = make_regression(n_samples=num_rows, n_features=num_features)
+        x_train, x_test, y_train, _ = train_test_split(X, y, test_size=test_size, random_state=42)
+
+        model = LinearRegression(normalize=True)
+        model.fit(x_train, y_train)
+
+        explainable_model = error_config[0]
+        explainable_model_args = error_config[1]
+        with pytest.raises(Exception):
+            mimic_explainer(model, x_train, explainable_model,
+                            explainable_model_args=explainable_model_args,
+                            augment_data=False)
+
     @property
     def iris_overall_expected_features(self):
         return [['petal length', 'petal width', 'sepal width',
                  'sepal length'],