Skip to content

Commit 5a8e902

Browse files
Autoemulate.compare: proposed fix for issue #222 (#224)
* Autoemulate.compare: proposed fix for issue #222 * Update autoemulate/compare.py comment re indentation Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update autoemulate/compare.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Autoemulate.compare: response to comment - removed unnecessary block. * tests/test_compare.py: added test to check hyperparameters are correctly updated * Update tests/test_compare.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * tests: applying automatic suggestion * test passing pre-commit locally --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 2f1d439 commit 5a8e902

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

autoemulate/compare.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def compare(self):
230230
X=self.X[self.train_idxs],
231231
y=self.y[self.train_idxs],
232232
cv=self.cross_validator,
233-
model=model,
233+
model=self.models[i],
234234
metrics=self.metrics,
235235
n_jobs=self.n_jobs,
236236
logger=self.logger,

tests/test_compare.py

+21
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,24 @@ def test_refit_models(ae_run):
162162
models = ae_run.refit_models()
163163
assert models is not None
164164
assert len(models) == len(ae_run.models)
165+
166+
167+
# --------------- test correct hyperparameter updating ------------------
168+
def test_param_search_updates_models(ae, Xy):
169+
X, y = Xy
170+
ae.setup(
171+
X, y, model_subset=["RandomForest"], param_search=True, param_search_iters=5
172+
)
173+
params_before = ae.models[0].get_params() # just one model, so index with 0
174+
ae.compare()
175+
params_after = ae.models[0].get_params()
176+
assert params_before != params_after
177+
178+
179+
def test_model_params_equal_wo_param_search(ae, Xy):
180+
X, y = Xy
181+
ae.setup(X, y, model_subset=["RandomForest"])
182+
params_before = ae.models[0].get_params()
183+
ae.compare()
184+
params_after = ae.models[0].get_params()
185+
assert params_before == params_after

0 commit comments

Comments
 (0)