|
9 | 9 | from scipy.stats import chi2, normaltest |
10 | 10 | from sklearn.linear_model import LinearRegression, LogisticRegression |
11 | 11 | from sklearn import metrics |
| 12 | +from sklearn.dummy import DummyClassifier |
12 | 13 |
|
13 | | -from .independence_tests_base import CondIndTest |
| 14 | +import tigramite |
| 15 | +from tigramite.independence_tests.independence_tests_base import CondIndTest |
14 | 16 |
|
15 | 17 |
|
16 | 18 | class RegressionCI(CondIndTest): |
@@ -146,12 +148,19 @@ def calc_deviance_logistic(X, y, var_type): |
146 | 148 | # 1-hot-encode all categorical columns |
147 | 149 | X = do_componentwise_one_hot_encoding(X, var_type=var_type) |
148 | 150 | y = np.ravel(y).astype('int') |
149 | | - # do logistic regression |
150 | | - model = LogisticRegression(solver='lbfgs') |
151 | | - model.fit(X, y) |
152 | | - deviance = 2*metrics.log_loss(y, model.predict_proba(X), normalize=False) |
| 151 | + # do logistic regression, if y only contains one class, return zero. |
| 152 | + if len(np.unique(y)) < 2: |
| 153 | + model = DummyClassifier(strategy="constant", constant=y[0]) |
| 154 | + model.fit(X, y) |
| 155 | + deviance = 0. |
| 156 | + else: |
| 157 | + model = LogisticRegression(solver='lbfgs') |
| 158 | + model.fit(X, y) |
| 159 | + deviance = 2 * metrics.log_loss(y, model.predict_proba(X), normalize=False) |
| 160 | + |
153 | 161 | # dofs: +2 for intercept (+1) (not too important, cancels out later anyway) |
154 | 162 | dof = model.n_features_in_ + 1 |
| 163 | + |
155 | 164 | return deviance, dof |
156 | 165 |
|
157 | 166 | def calc_deviance_linear(X, y, var_type): |
@@ -285,6 +294,15 @@ def get_analytic_significance(self, value, T, dim, xyz): |
285 | 294 |
|
286 | 295 | if __name__ == '__main__': |
287 | 296 |
|
| 297 | + |
| 298 | + import pandas as pd |
| 299 | + import numpy as np |
| 300 | + import matplotlib.pyplot as plt |
| 301 | + from tigramite.pcmci import PCMCI |
| 302 | + # from tigramite.independence_tests.regressionCI import RegressionCI |
| 303 | + import tigramite.plotting as tp |
| 304 | + import tigramite.data_processing as pp |
| 305 | + |
288 | 306 | import tigramite |
289 | 307 | from tigramite.data_processing import DataFrame |
290 | 308 | import tigramite.data_processing as pp |
@@ -378,3 +396,4 @@ def get_analytic_significance(self, value, T, dim, xyz): |
378 | 396 | print((rate <= 0.05).mean()) |
379 | 397 |
|
380 | 398 |
|
| 399 | +# dummy |
0 commit comments