Skip to content

Commit 5cc7a0f

Browse files
authored
Merge pull request #265 from alan-turing-institute/cv-fold-standardisation
same cv splits across models
2 parents 3daa3ed + 421b56e commit 5cc7a0f

File tree

4 files changed

+36
-3
lines changed

4 files changed

+36
-3
lines changed

autoemulate/compare.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from autoemulate.save import ModelSerialiser
2323
from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
2424
from autoemulate.sensitivity_analysis import sensitivity_analysis
25+
from autoemulate.utils import _check_cv
2526
from autoemulate.utils import _ensure_2d
2627
from autoemulate.utils import _get_full_model_name
2728
from autoemulate.utils import _redirect_warnings
@@ -54,7 +55,9 @@ def setup(
5455
scaler=StandardScaler(),
5556
reduce_dim=False,
5657
dim_reducer=PCA(),
57-
cross_validator=KFold(n_splits=5, shuffle=True),
58+
cross_validator=KFold(
59+
n_splits=5, shuffle=True, random_state=np.random.randint(1e5)
60+
),
5861
n_jobs=None,
5962
models=None,
6063
verbose=0,
@@ -121,7 +124,7 @@ def setup(
121124
dim_reducer=dim_reducer,
122125
)
123126
self.metrics = self._get_metrics(METRIC_REGISTRY)
124-
self.cross_validator = cross_validator
127+
self.cross_validator = _check_cv(cross_validator)
125128
self.param_search = param_search
126129
self.search_type = param_search_type
127130
self.param_search_iters = param_search_iters

autoemulate/utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from sklearn.base import RegressorMixin
1111
from sklearn.exceptions import ConvergenceWarning
12+
from sklearn.model_selection import KFold
1213
from sklearn.multioutput import MultiOutputRegressor
1314
from sklearn.pipeline import Pipeline
1415

@@ -370,3 +371,18 @@ def _ensure_2d(arr):
370371
if arr.ndim == 1:
371372
arr = arr.reshape(-1, 1)
372373
return arr
374+
375+
376+
# checkers for scikit-learn objects --------------------------------------------
377+
378+
379+
def _check_cv(cv):
    """Validate a cross-validation splitter and hand it back unchanged.

    Parameters
    ----------
    cv : object
        Candidate cross-validator. Only ``KFold`` instances are accepted.

    Returns
    -------
    KFold
        The very same object that was passed in.

    Raises
    ------
    ValueError
        If ``cv`` is ``None`` or is not an instance of ``KFold``.
    """
    # None gets its own message so a forgotten argument is easy to spot.
    if cv is None:
        raise ValueError("cross_validator cannot be None")
    if isinstance(cv, KFold):
        return cv
    raise ValueError(
        "cross_validator should be an instance of KFold cross-validation. We do not "
        "currently support other cross-validation methods."
    )

tests/test_ui.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_cross_validators():
4444
X = np.random.rand(100, 5)
4545
y = np.random.rand(100, 1)
4646

47-
cross_validators = [KFold(n_splits=5), TimeSeriesSplit(n_splits=5)]
47+
cross_validators = [KFold(n_splits=5)]
4848

4949
for cross_validator in cross_validators:
5050
ae = AutoEmulate()

tests/test_utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import pytest
44
from sklearn.ensemble import GradientBoostingRegressor
55
from sklearn.ensemble import RandomForestRegressor
6+
from sklearn.model_selection import KFold
7+
from sklearn.model_selection import LeaveOneOut
68
from sklearn.multioutput import MultiOutputRegressor
79
from sklearn.pipeline import Pipeline
810
from sklearn.preprocessing import StandardScaler
@@ -12,6 +14,7 @@
1214
from autoemulate.utils import _add_prefix_to_param_space
1315
from autoemulate.utils import _add_prefix_to_single_grid
1416
from autoemulate.utils import _adjust_param_space
17+
from autoemulate.utils import _check_cv
1518
from autoemulate.utils import _denormalise_y
1619
from autoemulate.utils import _ensure_2d
1720
from autoemulate.utils import _get_full_model_name
@@ -340,3 +343,14 @@ def test_ensure_2d_2d():
340343
y = np.array([[1, 2], [3, 4], [5, 6]])
341344
y_2d = _ensure_2d(y)
342345
assert y_2d.ndim == 2
346+
347+
348+
# test checkers for scikit-learn objects --------------------------------------
349+
def test_check_cv():
    """A valid KFold splitter passes validation and is returned unchanged."""
    # A fixed seed keeps the test reproducible; the seed value itself is
    # irrelevant to the validation being exercised.
    cv = KFold(n_splits=5, shuffle=True, random_state=0)
    # _check_cv returns its argument on success, so assert identity rather
    # than merely relying on the absence of an exception.
    assert _check_cv(cv) is cv
352+
353+
354+
def test_check_cv_error():
    """Unsupported or missing cross-validators are rejected with ValueError."""
    # Wrong splitter type: anything other than KFold is refused.
    with pytest.raises(ValueError, match="KFold"):
        _check_cv(LeaveOneOut())
    # Missing splitter: None is rejected with its own dedicated message.
    with pytest.raises(ValueError, match="cannot be None"):
        _check_cv(None)

0 commit comments

Comments
 (0)