Skip to content

Commit 5be6f17

Browse files
esantorellafacebook-github-bot
authored andcommitted
Remove mocks from ax.models.torch.tests.test_model (#3242)
Summary: Pull Request resolved: #3242 * Updated some mocks to use 'wraps' * Narrowed the scope of some mocks -- it's generally good to scope context managers as narrowly as possible * Used `mock_botorch_optimize` wherever applicable * Removed some mocks entirely * In a case where a mock used 'returns' to test plumbing, also added a subtest with no mock so that we can test that it runs e2e Reviewed By: saitcakmak Differential Revision: D68232004 fbshipit-source-id: 1700a14aab81dfc4f738d6c7d639a8475c457f44
1 parent ca93faa commit 5be6f17

File tree

1 file changed

+108
-103
lines changed

1 file changed

+108
-103
lines changed

ax/models/torch/tests/test_model.py

+108-103
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
import dataclasses
10-
from collections import OrderedDict
10+
from contextlib import ExitStack
1111
from copy import deepcopy
1212
from itertools import product
1313
from unittest import mock
@@ -27,9 +27,10 @@
2727
from ax.models.torch.botorch_modular.utils import (
2828
choose_model_class,
2929
construct_acquisition_and_optimizer_options,
30+
fit_botorch_model,
3031
ModelConfig,
3132
)
32-
from ax.models.torch.utils import _filter_X_observed
33+
from ax.models.torch.utils import _filter_X_observed, predict_from_model
3334
from ax.models.torch_base import TorchOptConfig
3435
from ax.utils.common.constants import Keys
3536
from ax.utils.common.testutils import TestCase
@@ -306,6 +307,8 @@ def test__construct(self) -> None:
306307
search_space_digest=self.mf_search_space_digest,
307308
)
308309

310+
# This mock is hard to remove since it is mocks a method on a surrogate that
311+
# is only constructed when `model.fit` is called
309312
@mock.patch(f"{SURROGATE_PATH}.Surrogate._construct_model")
310313
def test_fit(self, mock_fit: Mock) -> None:
311314
# If surrogate is not yet set, initialize it with dispatcher functions.
@@ -350,11 +353,12 @@ def test_fit(self, mock_fit: Mock) -> None:
350353
refit=True,
351354
)
352355

353-
@mock.patch(f"{SURROGATE_PATH}.Surrogate.predict")
354-
def test_predict(self, mock_predict: Mock) -> None:
355-
self.model.predict(X=self.X_test)
356+
def test_predict(self) -> None:
357+
with mock.patch.object(self.model._surrogate, "predict") as mock_predict:
358+
self.model.predict(X=self.X_test)
356359
mock_predict.assert_called_with(X=self.X_test, use_posterior_predictive=False)
357-
self.model.predict(X=self.X_test, use_posterior_predictive=True)
360+
with mock.patch.object(self.model._surrogate, "predict") as mock_predict:
361+
self.model.predict(X=self.X_test, use_posterior_predictive=True)
358362
mock_predict.assert_called_with(X=self.X_test, use_posterior_predictive=True)
359363

360364
def test_with_surrogate_specs_input(self) -> None:
@@ -376,48 +380,56 @@ def test_with_surrogate_specs_input(self) -> None:
376380
model = BoTorchModel(surrogate_specs={"s": spec1})
377381
self.assertIs(model.surrogate_spec, spec1)
378382

379-
@mock.patch(f"{MODEL_PATH}.BoTorchModel.fit")
380-
def test_cross_validate(self, mock_fit: Mock) -> None:
383+
@mock_botorch_optimize
384+
def test_cross_validate(self) -> None:
381385
self.model.fit(
382386
datasets=self.block_design_training_data,
383-
search_space_digest=self.mf_search_space_digest,
387+
search_space_digest=self.search_space_digest,
384388
candidate_metadata=self.candidate_metadata,
385389
)
386390

387391
old_surrogate = self.model.surrogate
388-
old_surrogate._model = mock.MagicMock()
389-
old_surrogate._model.state_dict.return_value = OrderedDict({"key": "val"})
390392

391393
for refit_on_cv, warm_start_refit, use_posterior_predictive in product(
392394
(True, False), (True, False), (True, False)
393395
):
394396
self.model.refit_on_cv = refit_on_cv
395397
self.model.warm_start_refit = warm_start_refit
396-
with mock.patch(
397-
f"{SURROGATE_PATH}.Surrogate.clone_reset",
398-
return_value=mock.MagicMock(spec=Surrogate),
399-
) as mock_clone_reset:
398+
with ExitStack() as es:
399+
mock_fit = es.enter_context(
400+
mock.patch.object(self.model, "fit", wraps=self.model.fit)
401+
)
402+
mock_predict_orig_surrogate = es.enter_context(
403+
mock.patch.object(
404+
self.model.surrogate,
405+
"predict",
406+
wraps=self.model.surrogate.predict,
407+
)
408+
)
409+
mock_predict_any_surrogate = es.enter_context(
410+
mock.patch(
411+
f"{SURROGATE_PATH}.predict_from_model", wraps=predict_from_model
412+
)
413+
)
400414
self.model.cross_validate(
401415
datasets=self.block_design_training_data,
402416
X_test=self.X_test,
403-
search_space_digest=self.mf_search_space_digest,
417+
search_space_digest=self.search_space_digest,
404418
use_posterior_predictive=use_posterior_predictive,
405419
)
406-
# Check that `predict` is called on the cloned surrogate, not
407-
# on the original one.
408-
mock_predict = mock_clone_reset.return_value.predict
409-
mock_predict.assert_called_once()
410-
411-
# Check correct X_test.
412-
kwargs = mock_predict.call_args.kwargs
413-
self.assertTrue(torch.equal(kwargs["X"], self.X_test))
414-
self.assertIs(
415-
kwargs["use_posterior_predictive"], use_posterior_predictive
416-
)
420+
# Check that `predict` is called on the cloned surrogate, not
421+
# on the original one.
422+
mock_predict_orig_surrogate.assert_not_called()
423+
mock_predict_any_surrogate.assert_called_once()
424+
425+
# Check correct X_test.
426+
kwargs = mock_predict_any_surrogate.call_args.kwargs
427+
self.assertTrue(torch.equal(kwargs["X"], self.X_test))
428+
self.assertIs(kwargs["use_posterior_predictive"], use_posterior_predictive)
417429

418430
# Check that surrogate is reset back to `old_surrogate` at the
419431
# end of cross-validation.
420-
self.assertTrue(self.model.surrogate is old_surrogate)
432+
self.assertIs(self.model.surrogate, old_surrogate)
421433

422434
expected_state_dict = (
423435
None
@@ -435,6 +447,7 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
435447
kwargs["state_dict"].keys(), expected_state_dict.keys()
436448
)
437449

450+
@mock_botorch_optimize
438451
def test_cross_validate_multiple_configs(self) -> None:
439452
"""Test cross-validation with multiple configs."""
440453
for refit_on_cv in (True, False):
@@ -459,31 +472,31 @@ def test_cross_validate_multiple_configs(self) -> None:
459472
search_space_digest=self.search_space_digest,
460473
candidate_metadata=self.candidate_metadata,
461474
)
462-
with patch(f"{Surrogate.__module__}.fit_botorch_model") as mock_fit:
475+
with patch(
476+
f"{Surrogate.__module__}.fit_botorch_model", wraps=fit_botorch_model
477+
) as mock_fit:
463478
self.model.cross_validate(
464479
datasets=self.block_design_training_data,
465480
X_test=self.X_test,
466481
search_space_digest=self.search_space_digest,
467482
)
468-
# check that we don't fit the model during cross_validation
469-
if refit_on_cv:
470-
mock_fit.assert_called()
471-
else:
472-
mock_fit.assert_not_called()
483+
# check that we don't fit the model during cross_validation
484+
if refit_on_cv:
485+
mock_fit.assert_called()
486+
else:
487+
mock_fit.assert_not_called()
473488

474489
@mock_botorch_optimize
475490
@mock.patch(
476491
f"{MODEL_PATH}.construct_acquisition_and_optimizer_options",
477492
wraps=construct_acquisition_and_optimizer_options,
478493
)
479-
@mock.patch(f"{CURRENT_PATH}.Acquisition.optimize")
480494
@mock.patch(
481495
f"{MODEL_PATH}.choose_botorch_acqf_class", wraps=choose_botorch_acqf_class
482496
)
483497
def _test_gen(
484498
self,
485499
mock_choose_botorch_acqf_class: Mock,
486-
mock_optimize: Mock,
487500
mock_construct_options: Mock,
488501
botorch_model_class: type[Model],
489502
search_space_digest: SearchSpaceDigest,
@@ -498,7 +511,7 @@ def _test_gen(
498511
acqf_cls=qLogNoisyExpectedImprovement,
499512
input_constructor=mock_input_constructor,
500513
)
501-
mock_optimize.return_value = (
514+
mock_optimize_return_value = (
502515
torch.tensor([[1.0]]),
503516
torch.tensor([2.0]),
504517
torch.tensor([1.0]),
@@ -509,25 +522,32 @@ def _test_gen(
509522
acquisition_class=Acquisition,
510523
acquisition_options=self.acquisition_options,
511524
)
512-
model.surrogate.fit(
513-
datasets=self.block_design_training_data,
514-
search_space_digest=search_space_digest,
515-
)
516-
model._botorch_acqf_class = None
517525
# Assert that error is raised if we haven't fit the model
518526
with self.assertRaises(RuntimeError):
519527
model.gen(
520528
n=1,
521529
search_space_digest=search_space_digest,
522530
torch_opt_config=self.torch_opt_config,
523531
)
524-
# Add search space digest reference to make the model think it's been fit
525-
model._search_space_digest = search_space_digest
526-
with mock.patch.object(
527-
BoTorchModel,
528-
"_instantiate_acquisition",
529-
wraps=model._instantiate_acquisition,
530-
) as mock_init_acqf:
532+
model.fit(
533+
datasets=self.block_design_training_data,
534+
search_space_digest=search_space_digest,
535+
)
536+
with ExitStack() as es:
537+
mock_init_acqf = es.enter_context(
538+
mock.patch.object(
539+
BoTorchModel,
540+
"_instantiate_acquisition",
541+
wraps=model._instantiate_acquisition,
542+
)
543+
)
544+
mock_optimize = es.enter_context(
545+
mock.patch(
546+
f"{CURRENT_PATH}.Acquisition.optimize",
547+
return_value=mock_optimize_return_value,
548+
)
549+
)
550+
531551
gen_results = model.gen(
532552
n=1,
533553
search_space_digest=search_space_digest,
@@ -612,6 +632,15 @@ def _test_gen(
612632
input_constructor=qLogNEI_input_constructor,
613633
)
614634

635+
# Make sure `gen` runs without mocking out Acquisition.optimize
636+
with self.subTest("No mocks"):
637+
gen_results = model.gen(
638+
n=1,
639+
search_space_digest=search_space_digest,
640+
torch_opt_config=self.torch_opt_config,
641+
)
642+
self.assertTrue(torch.isfinite(gen_results.points).all())
643+
615644
def test_gen_SingleTaskGP(self) -> None:
616645
self._test_gen(
617646
botorch_model_class=SingleTaskGP,
@@ -726,14 +755,7 @@ def test_best_point(self) -> None:
726755
)
727756

728757
@mock_botorch_optimize
729-
@mock.patch(
730-
f"{MODEL_PATH}.construct_acquisition_and_optimizer_options",
731-
return_value=({"num_fantasies": 64}, {"num_restarts": 40, "raw_samples": 1024}),
732-
)
733-
@mock.patch(f"{CURRENT_PATH}.Acquisition", autospec=True)
734-
def test_evaluate_acquisition_function(
735-
self, mock_acquisition: Mock, _mock_construct_options: Mock
736-
) -> None:
758+
def test_evaluate_acquisition_function(self) -> None:
737759
model = BoTorchModel(
738760
surrogate=self.surrogate,
739761
acquisition_class=Acquisition,
@@ -743,57 +765,42 @@ def test_evaluate_acquisition_function(
743765
datasets=self.block_design_training_data,
744766
search_space_digest=self.search_space_digest,
745767
)
746-
model.evaluate_acquisition_function(
768+
points = model.evaluate_acquisition_function(
747769
X=self.X_test,
748770
search_space_digest=self.search_space_digest,
749771
torch_opt_config=self.torch_opt_config,
750772
acq_options=self.acquisition_options,
751773
)
774+
self.assertEqual(points.shape, torch.Size([1]))
752775
# testing that the new setup chooses qLogNEI by default
753776
self.assertEqual(model._botorch_acqf_class, qLogNoisyExpectedImprovement)
754-
# `mock_acquisition` is a mock of the Acquisition class, so to check the mock's
755-
# `evaluate` on an instance of that class, we use
756-
# `mock_acquisition.return_value.evaluate`.
757-
mock_acquisition.return_value.evaluate.assert_called()
758-
759-
@mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate)
760-
@mock.patch(f"{SURROGATE_PATH}.Surrogate._construct_model", return_value=None)
761-
@mock.patch(f"{SURROGATE_PATH}.ModelListGP")
762-
def test_surrogate_model_options_propagation(
763-
self, _m1: Mock, _m2: Mock, mock_init: Mock
764-
) -> None:
765-
surrogate_spec = SurrogateSpec(
766-
botorch_model_kwargs={"some_option": "some_value"}
767-
)
777+
778+
@mock_botorch_optimize
779+
def test_surrogate_model_options_propagation(self) -> None:
780+
surrogate_spec = SurrogateSpec()
768781
model = BoTorchModel(surrogate_spec=surrogate_spec)
769-
model.fit(
770-
datasets=self.non_block_design_training_data,
771-
search_space_digest=self.mf_search_space_digest,
772-
candidate_metadata=self.candidate_metadata,
773-
)
782+
with mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) as mock_init:
783+
model.fit(
784+
datasets=self.non_block_design_training_data,
785+
search_space_digest=self.mf_search_space_digest,
786+
candidate_metadata=self.candidate_metadata,
787+
)
774788
mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
775789

776-
@mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate)
777-
@mock.patch(f"{SURROGATE_PATH}.Surrogate._construct_model", return_value=None)
778-
@mock.patch(f"{SURROGATE_PATH}.ModelListGP")
779-
def test_surrogate_options_propagation(
780-
self, _m1: Mock, _m2: Mock, mock_init: Mock
781-
) -> None:
790+
@mock_botorch_optimize
791+
def test_surrogate_options_propagation(self) -> None:
782792
surrogate_spec = SurrogateSpec(allow_batched_models=False)
783793
model = BoTorchModel(surrogate_spec=surrogate_spec)
784-
model.fit(
785-
datasets=self.non_block_design_training_data,
786-
search_space_digest=self.mf_search_space_digest,
787-
candidate_metadata=self.candidate_metadata,
788-
)
794+
with mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) as mock_init:
795+
model.fit(
796+
datasets=self.non_block_design_training_data,
797+
search_space_digest=self.mf_search_space_digest,
798+
candidate_metadata=self.candidate_metadata,
799+
)
789800
mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
790801

791-
@mock.patch(
792-
f"{ACQUISITION_PATH}.Acquisition.optimize",
793-
# Dummy candidates and acquisition function value.
794-
return_value=(torch.tensor([[2.0]]), torch.tensor([1.0])),
795-
)
796-
def test_model_list_choice(self, _) -> None: # , mock_extract_training_data):
802+
@mock_botorch_optimize
803+
def test_model_list_choice(self) -> None:
797804
model = BoTorchModel()
798805
model.fit(
799806
datasets=self.non_block_design_training_data,
@@ -807,12 +814,8 @@ def test_model_list_choice(self, _) -> None: # , mock_extract_training_data):
807814
# MFGP should be chosen.
808815
self.assertIsInstance(submodel, SingleTaskMultiFidelityGP)
809816

810-
@mock.patch(
811-
f"{ACQUISITION_PATH}.Acquisition.optimize",
812-
# Dummy candidates, acquisition value, and weights
813-
return_value=(torch.tensor([[2.0]]), torch.tensor([1.0]), torch.tensor([1.0])),
814-
)
815-
def test_MOO(self, _) -> None:
817+
@mock_botorch_optimize
818+
def test_MOO(self) -> None:
816819
# Add mock for qLogNEHVI input constructor to catch arguments passed to it.
817820
qLogNEHVI_input_constructor = get_acqf_input_constructor(
818821
qLogNoisyExpectedHypervolumeImprovement
@@ -843,6 +846,7 @@ def test_MOO(self, _) -> None:
843846
with mock.patch(
844847
f"{ACQUISITION_PATH}.get_outcome_constraint_transforms",
845848
# Dummy candidates and acquisition function value.
849+
# This will return the same value as the original
846850
return_value=constraints,
847851
) as mock_get_outcome_constraint_transforms:
848852
gen_results = model.gen(
@@ -940,9 +944,10 @@ def test_MOO(self, _) -> None:
940944
linear_constraints=linear_constraints,
941945
)
942946

947+
objective_thresholds = torch.tensor([9.9, 3.3, float("nan")])
943948
with mock.patch(
944-
"ax.models.torch.botorch_modular.acquisition." "infer_objective_thresholds",
945-
return_value=torch.tensor([9.9, 3.3, float("nan")]),
949+
"ax.models.torch.botorch_modular.acquisition.infer_objective_thresholds",
950+
return_value=objective_thresholds,
946951
) as _mock_model_infer_objective_thresholds:
947952
gen_results = model.gen(
948953
n=1,
@@ -973,7 +978,7 @@ def test_MOO(self, _) -> None:
973978
self.assertEqual(m.num_outputs, 2)
974979
self.assertIn("objective_thresholds", gen_results.gen_metadata)
975980
obj_t = gen_results.gen_metadata["objective_thresholds"]
976-
self.assertTrue(torch.equal(obj_t[:2], torch.tensor([9.9, 3.3])))
981+
self.assertTrue(torch.equal(obj_t[:2], objective_thresholds[:2]))
977982
self.assertTrue(np.isnan(obj_t[2].item()))
978983

979984
# Avoid polluting the registry for other tests; re-register correct input

0 commit comments

Comments
 (0)