Skip to content

Commit 242981f

Browse files
jasmainakpavanramkumar
authored andcommitted
FIX: simulate_glm in test_metrics
1 parent c3521e9 commit 242981f

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tests/test_metrics.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def test_deviance():
88
"""Test deviance."""
99
n_samples, n_features = 1000, 100
1010

11-
beta0 = np.random.normal(0.0, 1.0, 1)
11+
beta0 = np.random.rand()
1212
beta = np.random.normal(0.0, 1.0, n_features)
1313

1414
# sample train and test data
@@ -26,7 +26,7 @@ def test_pseudoR2():
2626
"""Test pseudo r2."""
2727
n_samples, n_features = 1000, 100
2828

29-
beta0 = np.random.normal(0.0, 1.0, 1)
29+
beta0 = np.random.rand()
3030
beta = np.random.normal(0.0, 1.0, n_features)
3131

3232
# sample train and test data
@@ -44,13 +44,15 @@ def test_accuracy():
4444
"""Testing accuracy."""
4545
n_samples, n_features, n_classes = 1000, 100, 2
4646

47-
beta0 = np.random.normal(0.0, 1.0, 1)
48-
beta = np.random.normal(0.0, 1.0, (n_features, n_classes))
47+
beta0 = np.random.rand()
48+
betas = np.random.normal(0.0, 1.0, (n_features, n_classes))
4949

5050
# sample train and test data
5151
glm_sim = GLM(distr='binomial', score_metric='accuracy')
5252
X = np.random.randn(n_samples, n_features)
53-
y = simulate_glm(glm_sim.distr, beta0, beta, X)
53+
y = np.zeros((n_samples, 2))
54+
for idx, beta in enumerate(betas.T):
55+
y[:, idx] = simulate_glm(glm_sim.distr, beta0, beta, X)
5456
y = np.argmax(y, axis=1)
5557
glm_sim.fit(X, y)
5658
score = glm_sim.score(X, y)

0 commit comments

Comments
 (0)