Skip to content

Commit ad98646

Browse files
authored
Merge pull request #460 from jakobrunge/developer
regressionCI fixes
2 parents 8b84dce + c5e5ff0 commit ad98646

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

tigramite/independence_tests/regressionCI.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
from scipy.stats import chi2, normaltest
1010
from sklearn.linear_model import LinearRegression, LogisticRegression
1111
from sklearn import metrics
12+
from sklearn.dummy import DummyClassifier
1213

13-
from .independence_tests_base import CondIndTest
14+
import tigramite
15+
from tigramite.independence_tests.independence_tests_base import CondIndTest
1416

1517

1618
class RegressionCI(CondIndTest):
@@ -146,12 +148,19 @@ def calc_deviance_logistic(X, y, var_type):
146148
# 1-hot-encode all categorical columns
147149
X = do_componentwise_one_hot_encoding(X, var_type=var_type)
148150
y = np.ravel(y).astype('int')
149-
# do logistic regression
150-
model = LogisticRegression(solver='lbfgs')
151-
model.fit(X, y)
152-
deviance = 2*metrics.log_loss(y, model.predict_proba(X), normalize=False)
151+
# do logistic regression, if y only contains one class, return zero.
152+
if len(np.unique(y)) < 2:
153+
model = DummyClassifier(strategy="constant", constant=y[0])
154+
model.fit(X, y)
155+
deviance = 0.
156+
else:
157+
model = LogisticRegression(solver='lbfgs')
158+
model.fit(X, y)
159+
deviance = 2 * metrics.log_loss(y, model.predict_proba(X), normalize=False)
160+
153161
# dofs: +2 for intercept (+1) (not too important, cancels out later anyway)
154162
dof = model.n_features_in_ + 1
163+
155164
return deviance, dof
156165

157166
def calc_deviance_linear(X, y, var_type):
@@ -285,6 +294,15 @@ def get_analytic_significance(self, value, T, dim, xyz):
285294

286295
if __name__ == '__main__':
287296

297+
298+
import pandas as pd
299+
import numpy as np
300+
import matplotlib.pyplot as plt
301+
from tigramite.pcmci import PCMCI
302+
# from tigramite.independence_tests.regressionCI import RegressionCI
303+
import tigramite.plotting as tp
304+
import tigramite.data_processing as pp
305+
288306
import tigramite
289307
from tigramite.data_processing import DataFrame
290308
import tigramite.data_processing as pp
@@ -378,3 +396,4 @@ def get_analytic_significance(self, value, T, dim, xyz):
378396
print((rate <= 0.05).mean())
379397

380398

399+
# dummy

0 commit comments

Comments
 (0)