Skip to content

Commit 983a012

Browse files
committed
add end to end tests
1 parent 7efbb94 commit 983a012

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

tests/test_end_to_end.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import numpy as np
2+
import pytest
3+
from sklearn.model_selection import KFold
4+
5+
from autoemulate.compare import AutoEmulate
6+
7+
8+
@pytest.fixture()
9+
def kfold():
10+
return KFold(n_splits=2)
11+
12+
13+
@pytest.fixture()
14+
def Xy_single():
15+
X = np.random.rand(15, 2)
16+
y = np.random.rand(15)
17+
return X, y
18+
19+
20+
@pytest.fixture()
21+
def Xy_multi():
22+
X = np.random.rand(10, 2)
23+
y = np.random.rand(10, 2)
24+
return X, y
25+
26+
27+
@pytest.mark.parametrize("Xy", ["Xy_single", "Xy_multi"])
28+
def test_run(Xy, request):
29+
X, y = request.getfixturevalue(Xy)
30+
em = AutoEmulate()
31+
em.setup(X, y, print_setup=False)
32+
em.compare()
33+
assert em.best_model is not None
34+
assert em.cv_results is not None
35+
36+
37+
def test_run_param_search(Xy_single, kfold):
38+
X, y = Xy_single
39+
em = AutoEmulate()
40+
em.setup(
41+
X,
42+
y,
43+
print_setup=False,
44+
param_search=True,
45+
param_search_iters=2,
46+
cross_validator=kfold,
47+
)
48+
em.compare()
49+
assert em.best_model is not None
50+
assert em.cv_results is not None
51+
52+
53+
def test_run_parallel(Xy_single, kfold):
54+
X, y = Xy_single
55+
em = AutoEmulate()
56+
em.setup(X, y, print_setup=False, cross_validator=kfold, n_jobs=2)
57+
em.compare()
58+
assert em.best_model is not None
59+
assert em.cv_results is not None

0 commit comments

Comments
 (0)