|
5 | 5 | from sklearn.datasets import load_breast_cancer, load_boston
|
6 | 6 | from sklearn.linear_model import LogisticRegression as SKLogistic
|
7 | 7 | from sklearn.linear_model import Lasso as SKLinear
|
| 8 | +from sklearn.model_selection import RandomizedSearchCV |
8 | 9 | import numpy as np
|
9 | 10 |
|
10 | 11 |
|
@@ -38,6 +39,35 @@ def test_linear_regression():
|
38 | 39 | assert global_viz is not None
|
39 | 40 |
|
40 | 41 |
|
| 42 | +def test_linear_regression_sklearn_compatibility(): |
| 43 | + boston = load_boston() |
| 44 | + X, y = boston.data, boston.target |
| 45 | + |
| 46 | + distributions = { |
| 47 | + 'max_iter': [250, 500], |
| 48 | + 'alpha': [0.1 , 0.25, 0.5, 1] |
| 49 | + } |
| 50 | + |
| 51 | + sk_lr = SKLinear() |
| 52 | + our_lr = LinearRegression() |
| 53 | + |
| 54 | + search_sk = RandomizedSearchCV(estimator = sk_lr, |
| 55 | + param_distributions = distributions, |
| 56 | + random_state = 2022) |
| 57 | + |
| 58 | + search_our = RandomizedSearchCV(estimator = our_lr, |
| 59 | + param_distributions = distributions, |
| 60 | + random_state = 2022) |
| 61 | + |
| 62 | + search_sk.fit(X, y) |
| 63 | + search_our.fit(X, y) |
| 64 | + |
| 65 | + sk_pred = search_sk.predict(X) |
| 66 | + our_pred = search_our.predict(X) |
| 67 | + |
| 68 | + assert np.allclose(sk_pred, our_pred) |
| 69 | + |
| 70 | + |
41 | 71 | def test_logistic_regression():
|
42 | 72 | cancer = load_breast_cancer()
|
43 | 73 | X, y = cancer.data, cancer.target
|
@@ -72,6 +102,35 @@ def test_logistic_regression():
|
72 | 102 | assert global_viz is not None
|
73 | 103 |
|
74 | 104 |
|
| 105 | +def test_logistic_regression_sklearn_compatibility(): |
| 106 | + cancer = load_breast_cancer() |
| 107 | + X, y = cancer.data, cancer.target |
| 108 | + |
| 109 | + distributions = { |
| 110 | + 'penalty': ['l1', 'l2'], |
| 111 | + 'C': [1 , 0.5, 0.1, 0.05, 0.01] |
| 112 | + } |
| 113 | + |
| 114 | + sk_lr = SKLogistic() |
| 115 | + our_lr = LogisticRegression() |
| 116 | + |
| 117 | + search_sk = RandomizedSearchCV(estimator = sk_lr, |
| 118 | + param_distributions = distributions, |
| 119 | + random_state = 2022) |
| 120 | + |
| 121 | + search_our = RandomizedSearchCV(estimator = our_lr, |
| 122 | + param_distributions = distributions, |
| 123 | + random_state = 2022) |
| 124 | + |
| 125 | + search_sk.fit(X, y) |
| 126 | + search_our.fit(X, y) |
| 127 | + |
| 128 | + sk_pred = search_sk.predict_proba(X) |
| 129 | + our_pred = search_our.predict_proba(X) |
| 130 | + |
| 131 | + assert np.allclose(sk_pred, our_pred) |
| 132 | + |
| 133 | + |
75 | 134 | def test_sorting():
|
76 | 135 | cancer = load_breast_cancer()
|
77 | 136 | X, y = cancer.data, cancer.target
|
|
0 commit comments