diff --git a/tests/test_sklearn_permutation_importance.py b/tests/test_sklearn_permutation_importance.py index cd6b1dbe..33825045 100644 --- a/tests/test_sklearn_permutation_importance.py +++ b/tests/test_sklearn_permutation_importance.py @@ -8,6 +8,8 @@ from sklearn.pipeline import make_pipeline from sklearn.feature_selection import SelectFromModel from sklearn.linear_model import LogisticRegression +import pandas as pd +from xgboost import XGBClassifier import eli5 from eli5.sklearn import PermutationImportance @@ -164,3 +166,15 @@ def test_explain_weights(iris_train): res = format_as_all(expl, perm.wrapped_estimator_) for _expl in res: assert "petal width (cm)" in _expl + + +def test_dataframe_input_to_xgbclassifier(): + # 30 items of data, pairs of a useless feature and a predictive feature + X_np = np.array([[0,1]]*15 + [[0,2]]*15) + y_np = np.array([0]*15 + [1]*15) + X = pd.DataFrame(X_np) + y = pd.Series(y_np) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42) + est = XGBClassifier() + est.fit(X_train, y_train) + perm = PermutationImportance(est).fit(X_test, y_test) \ No newline at end of file