Skip to content

Commit 48c52b5

Browse files
committed
Fixing sklearn error when using RandomizedSearchCV
1 parent 2a76307 commit 48c52b5

File tree

2 files changed

+76
-2
lines changed

2 files changed

+76
-2
lines changed

python/interpret-core/interpret/glassbox/linear.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from abc import abstractmethod
1111
from sklearn.base import is_classifier
1212
import numpy as np
13-
from sklearn.base import ClassifierMixin, RegressorMixin
13+
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
1414
from sklearn.linear_model import LogisticRegression as SKLogistic
1515
from sklearn.linear_model import Lasso as SKLinear
1616

1717

18-
class BaseLinear:
18+
class BaseLinear(BaseEstimator):
1919
""" Base linear model.
2020
2121
Currently wrapper around linear models in scikit-learn.
@@ -43,11 +43,26 @@ def __init__(
4343
self.linear_class = linear_class
4444
self.kwargs = kwargs
4545

46+
for key, value in self.kwargs.items():
47+
setattr(self, key, value)
48+
4649
@abstractmethod
4750
def _model(self):
4851
# This method should be overridden.
4952
return None
5053

54+
# get_params and set_params are usually inherited from BaseEstimator, but they will
55+
# fail here due to the **kwargs in the __init__. Therefore, we implement them.
56+
def get_params(self, deep=True):
    """ Returns the keyword hyperparameters of this estimator.

    Only the names held in ``self.kwargs`` are reported. Each value is read
    back from the instance attribute of the same name, so the current
    attribute value is what gets returned.

    Args:
        deep: Accepted for scikit-learn API compatibility; not used here.

    Returns:
        Dictionary mapping each name in ``self.kwargs`` to the value of the
        corresponding instance attribute.
    """
    params = {}
    for name in self.kwargs:
        params[name] = getattr(self, name)
    return params
59+
60+
def set_params(self, **parameters):
    """ Sets parameters on this estimator (scikit-learn estimator API).

    Every parameter is stored as an instance attribute. Keyword
    hyperparameters are additionally recorded in ``self.kwargs``, because
    ``get_params`` reports exactly the names found there. Without this, a
    hyperparameter that was not passed to the constructor would be
    invisible to ``get_params`` after being set here.

    Args:
        **parameters: Parameter names mapped to their new values.

    Returns:
        Itself.
    """
    for parameter, value in parameters.items():
        # A name is treated as a keyword hyperparameter when it is already
        # tracked in kwargs, or when no attribute of that name exists yet.
        # Explicit constructor arguments (e.g. linear_class) already exist
        # as plain attributes and must stay out of kwargs.
        # The check has to happen before setattr creates the attribute.
        is_hyperparameter = (
            parameter in self.kwargs or not hasattr(self, parameter)
        )
        setattr(self, parameter, value)
        if is_hyperparameter:
            self.kwargs[parameter] = value

    return self
65+
5166
def fit(self, X, y):
5267
""" Fits model to provided instances.
5368

python/interpret-core/interpret/glassbox/test/test_linear.py

+59
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sklearn.datasets import load_breast_cancer, load_boston
66
from sklearn.linear_model import LogisticRegression as SKLogistic
77
from sklearn.linear_model import Lasso as SKLinear
8+
from sklearn.model_selection import RandomizedSearchCV
89
import numpy as np
910

1011

@@ -38,6 +39,35 @@ def test_linear_regression():
3839
assert global_viz is not None
3940

4041

42+
def test_linear_regression_sklearn_compatibility():
    """ Checks that LinearRegression can be tuned by RandomizedSearchCV.

    The same search (same candidate values, same random_state) is run on
    scikit-learn's Lasso and on our LinearRegression; the predictions of
    the two fitted searches on the training data must agree.
    """
    boston = load_boston()
    X, y = boston.data, boston.target

    distributions = {
        "max_iter": [250, 500],
        "alpha": [0.1, 0.25, 0.5, 1],
    }

    def _search_and_predict(estimator):
        # Identical search configuration for both estimators.
        search = RandomizedSearchCV(
            estimator=estimator,
            param_distributions=distributions,
            random_state=2022,
        )
        search.fit(X, y)
        return search.predict(X)

    sk_pred = _search_and_predict(SKLinear())
    our_pred = _search_and_predict(LinearRegression())

    assert np.allclose(sk_pred, our_pred)
69+
70+
4171
def test_logistic_regression():
4272
cancer = load_breast_cancer()
4373
X, y = cancer.data, cancer.target
@@ -72,6 +102,35 @@ def test_logistic_regression():
72102
assert global_viz is not None
73103

74104

105+
def test_logistic_regression_sklearn_compatibility():
    """ Checks that LogisticRegression can be tuned by RandomizedSearchCV.

    The same search (same candidate values, same random_state) is run on
    scikit-learn's LogisticRegression and on ours; the class probabilities
    predicted by the two fitted searches on the training data must agree.
    """
    cancer = load_breast_cancer()
    X, y = cancer.data, cancer.target

    # NOTE(review): whether the 'l1' candidates can actually be fitted
    # depends on the solver the estimator uses by default -- confirm these
    # candidates are not silently failing inside the search.
    distributions = {
        "penalty": ["l1", "l2"],
        "C": [1, 0.5, 0.1, 0.05, 0.01],
    }

    def _search_and_predict_proba(estimator):
        # Identical search configuration for both estimators.
        search = RandomizedSearchCV(
            estimator=estimator,
            param_distributions=distributions,
            random_state=2022,
        )
        search.fit(X, y)
        return search.predict_proba(X)

    sk_pred = _search_and_predict_proba(SKLogistic())
    our_pred = _search_and_predict_proba(LogisticRegression())

    assert np.allclose(sk_pred, our_pred)
132+
133+
75134
def test_sorting():
76135
cancer = load_breast_cancer()
77136
X, y = cancer.data, cancer.target

0 commit comments

Comments
 (0)