 # pyre-strict
 
 import dataclasses
-from collections import OrderedDict
+from contextlib import ExitStack
 from copy import deepcopy
 from itertools import product
 from unittest import mock
 from ax.models.torch.botorch_modular.utils import (
     choose_model_class,
     construct_acquisition_and_optimizer_options,
+    fit_botorch_model,
     ModelConfig,
 )
-from ax.models.torch.utils import _filter_X_observed
+from ax.models.torch.utils import _filter_X_observed, predict_from_model
 from ax.models.torch_base import TorchOptConfig
 from ax.utils.common.constants import Keys
 from ax.utils.common.testutils import TestCase
@@ -306,6 +307,8 @@ def test__construct(self) -> None:
             search_space_digest=self.mf_search_space_digest,
         )
 
+    # This mock is hard to remove since it mocks a method on a surrogate that
+    # is only constructed when `model.fit` is called.
     @mock.patch(f"{SURROGATE_PATH}.Surrogate._construct_model")
     def test_fit(self, mock_fit: Mock) -> None:
         # If surrogate is not yet set, initialize it with dispatcher functions.
@@ -350,11 +353,12 @@ def test_fit(self, mock_fit: Mock) -> None:
             refit=True,
         )
 
-    @mock.patch(f"{SURROGATE_PATH}.Surrogate.predict")
-    def test_predict(self, mock_predict: Mock) -> None:
-        self.model.predict(X=self.X_test)
+    def test_predict(self) -> None:
+        with mock.patch.object(self.model._surrogate, "predict") as mock_predict:
+            self.model.predict(X=self.X_test)
         mock_predict.assert_called_with(X=self.X_test, use_posterior_predictive=False)
-        self.model.predict(X=self.X_test, use_posterior_predictive=True)
+        with mock.patch.object(self.model._surrogate, "predict") as mock_predict:
+            self.model.predict(X=self.X_test, use_posterior_predictive=True)
         mock_predict.assert_called_with(X=self.X_test, use_posterior_predictive=True)
 
     def test_with_surrogate_specs_input(self) -> None:
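Aside (not part of the commit): a minimal, self-contained sketch of the instance-level patching style the reworked `test_predict` uses. The `Pipeline`/`Stage` classes below are hypothetical stand-ins; the point is that `mock.patch.object` swaps the method on one specific object, whereas a module-path `mock.patch` would replace it for every instance of the class.

# Illustrative only -- hypothetical classes, not Ax code.
from unittest import mock

class Stage:
    def predict(self, x: float) -> float:
        return 2 * x

class Pipeline:
    def __init__(self) -> None:
        self._stage = Stage()

    def predict(self, x: float) -> float:
        return self._stage.predict(x)

pipe = Pipeline()
# Patch `predict` on this one instance; other Stage objects keep the real method.
with mock.patch.object(pipe._stage, "predict", return_value=0.0) as mock_predict:
    assert pipe.predict(3.0) == 0.0
    mock_predict.assert_called_once_with(3.0)
assert pipe.predict(3.0) == 6.0  # the patch is undone on exit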
@@ -376,48 +380,56 @@ def test_with_surrogate_specs_input(self) -> None:
         model = BoTorchModel(surrogate_specs={"s": spec1})
         self.assertIs(model.surrogate_spec, spec1)
 
-    @mock.patch(f"{MODEL_PATH}.BoTorchModel.fit")
-    def test_cross_validate(self, mock_fit: Mock) -> None:
+    @mock_botorch_optimize
+    def test_cross_validate(self) -> None:
         self.model.fit(
             datasets=self.block_design_training_data,
-            search_space_digest=self.mf_search_space_digest,
+            search_space_digest=self.search_space_digest,
             candidate_metadata=self.candidate_metadata,
         )
 
         old_surrogate = self.model.surrogate
-        old_surrogate._model = mock.MagicMock()
-        old_surrogate._model.state_dict.return_value = OrderedDict({"key": "val"})
 
         for refit_on_cv, warm_start_refit, use_posterior_predictive in product(
             (True, False), (True, False), (True, False)
         ):
             self.model.refit_on_cv = refit_on_cv
             self.model.warm_start_refit = warm_start_refit
-            with mock.patch(
-                f"{SURROGATE_PATH}.Surrogate.clone_reset",
-                return_value=mock.MagicMock(spec=Surrogate),
-            ) as mock_clone_reset:
+            with ExitStack() as es:
+                mock_fit = es.enter_context(
+                    mock.patch.object(self.model, "fit", wraps=self.model.fit)
+                )
+                mock_predict_orig_surrogate = es.enter_context(
+                    mock.patch.object(
+                        self.model.surrogate,
+                        "predict",
+                        wraps=self.model.surrogate.predict,
+                    )
+                )
+                mock_predict_any_surrogate = es.enter_context(
+                    mock.patch(
+                        f"{SURROGATE_PATH}.predict_from_model", wraps=predict_from_model
+                    )
+                )
                 self.model.cross_validate(
                     datasets=self.block_design_training_data,
                     X_test=self.X_test,
-                    search_space_digest=self.mf_search_space_digest,
+                    search_space_digest=self.search_space_digest,
                     use_posterior_predictive=use_posterior_predictive,
                 )
-                # Check that `predict` is called on the cloned surrogate, not
-                # on the original one.
-                mock_predict = mock_clone_reset.return_value.predict
-                mock_predict.assert_called_once()
-
-                # Check correct X_test.
-                kwargs = mock_predict.call_args.kwargs
-                self.assertTrue(torch.equal(kwargs["X"], self.X_test))
-                self.assertIs(
-                    kwargs["use_posterior_predictive"], use_posterior_predictive
-                )
+            # Check that `predict` is called on the cloned surrogate, not
+            # on the original one.
+            mock_predict_orig_surrogate.assert_not_called()
+            mock_predict_any_surrogate.assert_called_once()
+
+            # Check correct X_test.
+            kwargs = mock_predict_any_surrogate.call_args.kwargs
+            self.assertTrue(torch.equal(kwargs["X"], self.X_test))
+            self.assertIs(kwargs["use_posterior_predictive"], use_posterior_predictive)
 
             # Check that surrogate is reset back to `old_surrogate` at the
             # end of cross-validation.
-            self.assertTrue(self.model.surrogate is old_surrogate)
+            self.assertIs(self.model.surrogate, old_surrogate)
 
             expected_state_dict = (
                 None
@@ -435,6 +447,7 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
                 kwargs["state_dict"].keys(), expected_state_dict.keys()
             )
 
+    @mock_botorch_optimize
     def test_cross_validate_multiple_configs(self) -> None:
         """Test cross-validation with multiple configs."""
         for refit_on_cv in (True, False):
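Aside (not part of the commit): `test_cross_validate` above combines `contextlib.ExitStack` with `wraps=` spies. A small sketch under made-up names (`Greeter` is hypothetical): `ExitStack` enters several patches in one `with` block and unwinds them all on exit, and `wraps=` lets each patch delegate to the real implementation while still recording calls, so the genuine code path is exercised rather than a stub.

# Illustrative only -- `Greeter` is a hypothetical class, not Ax code.
from contextlib import ExitStack
from unittest import mock

class Greeter:
    def greet(self, name: str) -> str:
        return f"hello {name}"

    def shout(self, name: str) -> str:
        return self.greet(name).upper()

g = Greeter()
with ExitStack() as es:
    # Enter any number of patches; all of them are removed when the block exits.
    spy_greet = es.enter_context(mock.patch.object(g, "greet", wraps=g.greet))
    spy_shout = es.enter_context(mock.patch.object(g, "shout", wraps=g.shout))
    # `wraps=` delegates to the real methods, so behavior is unchanged ...
    assert g.shout("ax") == "HELLO AX"
    # ... while the calls and their arguments are still recorded.
    spy_shout.assert_called_once_with("ax")
    spy_greet.assert_called_once_with("ax")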
@@ -459,31 +472,31 @@ def test_cross_validate_multiple_configs(self) -> None:
                 search_space_digest=self.search_space_digest,
                 candidate_metadata=self.candidate_metadata,
             )
-            with patch(f"{Surrogate.__module__}.fit_botorch_model") as mock_fit:
+            with patch(
+                f"{Surrogate.__module__}.fit_botorch_model", wraps=fit_botorch_model
+            ) as mock_fit:
                 self.model.cross_validate(
                     datasets=self.block_design_training_data,
                     X_test=self.X_test,
                     search_space_digest=self.search_space_digest,
                 )
-            # check that we don't fit the model during cross_validation
-            if refit_on_cv:
-                mock_fit.assert_called()
-            else:
-                mock_fit.assert_not_called()
+                # check that we don't fit the model during cross_validation
+                if refit_on_cv:
+                    mock_fit.assert_called()
+                else:
+                    mock_fit.assert_not_called()
 
     @mock_botorch_optimize
     @mock.patch(
         f"{MODEL_PATH}.construct_acquisition_and_optimizer_options",
         wraps=construct_acquisition_and_optimizer_options,
     )
-    @mock.patch(f"{CURRENT_PATH}.Acquisition.optimize")
     @mock.patch(
         f"{MODEL_PATH}.choose_botorch_acqf_class", wraps=choose_botorch_acqf_class
     )
     def _test_gen(
         self,
         mock_choose_botorch_acqf_class: Mock,
-        mock_optimize: Mock,
         mock_construct_options: Mock,
         botorch_model_class: type[Model],
         search_space_digest: SearchSpaceDigest,
@@ -498,7 +511,7 @@ def _test_gen(
             acqf_cls=qLogNoisyExpectedImprovement,
             input_constructor=mock_input_constructor,
         )
-        mock_optimize.return_value = (
+        mock_optimize_return_value = (
             torch.tensor([[1.0]]),
             torch.tensor([2.0]),
             torch.tensor([1.0]),
@@ -509,25 +522,32 @@ def _test_gen(
             acquisition_class=Acquisition,
             acquisition_options=self.acquisition_options,
         )
-        model.surrogate.fit(
-            datasets=self.block_design_training_data,
-            search_space_digest=search_space_digest,
-        )
-        model._botorch_acqf_class = None
         # Assert that error is raised if we haven't fit the model
         with self.assertRaises(RuntimeError):
             model.gen(
                 n=1,
                 search_space_digest=search_space_digest,
                 torch_opt_config=self.torch_opt_config,
             )
-        # Add search space digest reference to make the model think it's been fit
-        model._search_space_digest = search_space_digest
-        with mock.patch.object(
-            BoTorchModel,
-            "_instantiate_acquisition",
-            wraps=model._instantiate_acquisition,
-        ) as mock_init_acqf:
+        model.fit(
+            datasets=self.block_design_training_data,
+            search_space_digest=search_space_digest,
+        )
+        with ExitStack() as es:
+            mock_init_acqf = es.enter_context(
+                mock.patch.object(
+                    BoTorchModel,
+                    "_instantiate_acquisition",
+                    wraps=model._instantiate_acquisition,
+                )
+            )
+            mock_optimize = es.enter_context(
+                mock.patch(
+                    f"{CURRENT_PATH}.Acquisition.optimize",
+                    return_value=mock_optimize_return_value,
+                )
+            )
+
             gen_results = model.gen(
                 n=1,
                 search_space_digest=search_space_digest,
@@ -612,6 +632,15 @@ def _test_gen(
             input_constructor=qLogNEI_input_constructor,
         )
 
+        # Make sure `gen` runs without mocking out Acquisition.optimize
+        with self.subTest("No mocks"):
+            gen_results = model.gen(
+                n=1,
+                search_space_digest=search_space_digest,
+                torch_opt_config=self.torch_opt_config,
+            )
+            self.assertTrue(torch.isfinite(gen_results.points).all())
+
     def test_gen_SingleTaskGP(self) -> None:
         self._test_gen(
             botorch_model_class=SingleTaskGP,
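Aside (not part of the commit): the "No mocks" block added above relies on `unittest.TestCase.subTest`. A standalone sketch (`SquareTest` is made up): each subtest is reported separately, and a failure in one iteration does not abort the remaining ones.

# Illustrative only -- a minimal `subTest` example, unrelated to the Ax tests.
import unittest

class SquareTest(unittest.TestCase):
    def test_squares(self) -> None:
        for x, expected in [(2, 4), (3, 9), (4, 16)]:
            with self.subTest(x=x):
                self.assertEqual(x * x, expected)

if __name__ == "__main__":
    unittest.main()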
@@ -726,14 +755,7 @@ def test_best_point(self) -> None:
         )
 
     @mock_botorch_optimize
-    @mock.patch(
-        f"{MODEL_PATH}.construct_acquisition_and_optimizer_options",
-        return_value=({"num_fantasies": 64}, {"num_restarts": 40, "raw_samples": 1024}),
-    )
-    @mock.patch(f"{CURRENT_PATH}.Acquisition", autospec=True)
-    def test_evaluate_acquisition_function(
-        self, mock_acquisition: Mock, _mock_construct_options: Mock
-    ) -> None:
+    def test_evaluate_acquisition_function(self) -> None:
         model = BoTorchModel(
             surrogate=self.surrogate,
             acquisition_class=Acquisition,
@@ -743,57 +765,42 @@ def test_evaluate_acquisition_function(
             datasets=self.block_design_training_data,
             search_space_digest=self.search_space_digest,
         )
-        model.evaluate_acquisition_function(
+        points = model.evaluate_acquisition_function(
             X=self.X_test,
             search_space_digest=self.search_space_digest,
             torch_opt_config=self.torch_opt_config,
             acq_options=self.acquisition_options,
         )
+        self.assertEqual(points.shape, torch.Size([1]))
         # testing that the new setup chooses qLogNEI by default
         self.assertEqual(model._botorch_acqf_class, qLogNoisyExpectedImprovement)
-        # `mock_acquisition` is a mock of the Acquisition class, so to check the mock's
-        # `evaluate` on an instance of that class, we use
-        # `mock_acquisition.return_value.evaluate`.
-        mock_acquisition.return_value.evaluate.assert_called()
-
-    @mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate)
-    @mock.patch(f"{SURROGATE_PATH}.Surrogate._construct_model", return_value=None)
-    @mock.patch(f"{SURROGATE_PATH}.ModelListGP")
-    def test_surrogate_model_options_propagation(
-        self, _m1: Mock, _m2: Mock, mock_init: Mock
-    ) -> None:
-        surrogate_spec = SurrogateSpec(
-            botorch_model_kwargs={"some_option": "some_value"}
-        )
+
+    @mock_botorch_optimize
+    def test_surrogate_model_options_propagation(self) -> None:
+        surrogate_spec = SurrogateSpec()
         model = BoTorchModel(surrogate_spec=surrogate_spec)
-        model.fit(
-            datasets=self.non_block_design_training_data,
-            search_space_digest=self.mf_search_space_digest,
-            candidate_metadata=self.candidate_metadata,
-        )
+        with mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) as mock_init:
+            model.fit(
+                datasets=self.non_block_design_training_data,
+                search_space_digest=self.mf_search_space_digest,
+                candidate_metadata=self.candidate_metadata,
+            )
         mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
 
-    @mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate)
-    @mock.patch(f"{SURROGATE_PATH}.Surrogate._construct_model", return_value=None)
-    @mock.patch(f"{SURROGATE_PATH}.ModelListGP")
-    def test_surrogate_options_propagation(
-        self, _m1: Mock, _m2: Mock, mock_init: Mock
-    ) -> None:
+    @mock_botorch_optimize
+    def test_surrogate_options_propagation(self) -> None:
         surrogate_spec = SurrogateSpec(allow_batched_models=False)
         model = BoTorchModel(surrogate_spec=surrogate_spec)
-        model.fit(
-            datasets=self.non_block_design_training_data,
-            search_space_digest=self.mf_search_space_digest,
-            candidate_metadata=self.candidate_metadata,
-        )
+        with mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) as mock_init:
+            model.fit(
+                datasets=self.non_block_design_training_data,
+                search_space_digest=self.mf_search_space_digest,
+                candidate_metadata=self.candidate_metadata,
+            )
         mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
 
-    @mock.patch(
-        f"{ACQUISITION_PATH}.Acquisition.optimize",
-        # Dummy candidates and acquisition function value.
-        return_value=(torch.tensor([[2.0]]), torch.tensor([1.0])),
-    )
-    def test_model_list_choice(self, _) -> None:  # , mock_extract_training_data):
+    @mock_botorch_optimize
+    def test_model_list_choice(self) -> None:
         model = BoTorchModel()
         model.fit(
             datasets=self.non_block_design_training_data,
@@ -807,12 +814,8 @@ def test_model_list_choice(self, _) -> None:  # , mock_extract_training_data):
         # MFGP should be chosen.
         self.assertIsInstance(submodel, SingleTaskMultiFidelityGP)
 
-    @mock.patch(
-        f"{ACQUISITION_PATH}.Acquisition.optimize",
-        # Dummy candidates, acquisition value, and weights
-        return_value=(torch.tensor([[2.0]]), torch.tensor([1.0]), torch.tensor([1.0])),
-    )
-    def test_MOO(self, _) -> None:
+    @mock_botorch_optimize
+    def test_MOO(self) -> None:
         # Add mock for qLogNEHVI input constructor to catch arguments passed to it.
         qLogNEHVI_input_constructor = get_acqf_input_constructor(
             qLogNoisyExpectedHypervolumeImprovement
@@ -843,6 +846,7 @@ def test_MOO(self, _) -> None:
         with mock.patch(
             f"{ACQUISITION_PATH}.get_outcome_constraint_transforms",
             # Dummy candidates and acquisition function value.
+            # This will return the same value as the original
             return_value=constraints,
         ) as mock_get_outcome_constraint_transforms:
             gen_results = model.gen(
@@ -940,9 +944,10 @@ def test_MOO(self, _) -> None:
             linear_constraints=linear_constraints,
         )
 
+        objective_thresholds = torch.tensor([9.9, 3.3, float("nan")])
         with mock.patch(
-            "ax.models.torch.botorch_modular.acquisition." "infer_objective_thresholds",
-            return_value=torch.tensor([9.9, 3.3, float("nan")]),
+            "ax.models.torch.botorch_modular.acquisition.infer_objective_thresholds",
+            return_value=objective_thresholds,
         ) as _mock_model_infer_objective_thresholds:
             gen_results = model.gen(
                 n=1,
@@ -973,7 +978,7 @@ def test_MOO(self, _) -> None:
         self.assertEqual(m.num_outputs, 2)
         self.assertIn("objective_thresholds", gen_results.gen_metadata)
         obj_t = gen_results.gen_metadata["objective_thresholds"]
-        self.assertTrue(torch.equal(obj_t[:2], torch.tensor([9.9, 3.3])))
+        self.assertTrue(torch.equal(obj_t[:2], objective_thresholds[:2]))
         self.assertTrue(np.isnan(obj_t[2].item()))
 
         # Avoid polluting the registry for other tests; re-register correct input