7
7
# pyre-strict
8
8
9
9
import dataclasses
10
- from collections import OrderedDict
10
+ from contextlib import ExitStack
11
11
from copy import deepcopy
12
12
from itertools import product
13
13
from unittest import mock
27
27
from ax .models .torch .botorch_modular .utils import (
28
28
choose_model_class ,
29
29
construct_acquisition_and_optimizer_options ,
30
+ fit_botorch_model ,
30
31
ModelConfig ,
31
32
)
32
- from ax .models .torch .utils import _filter_X_observed
33
+ from ax .models .torch .utils import _filter_X_observed , predict_from_model
33
34
from ax .models .torch_base import TorchOptConfig
34
35
from ax .utils .common .constants import Keys
35
36
from ax .utils .common .testutils import TestCase
@@ -306,6 +307,8 @@ def test__construct(self) -> None:
306
307
search_space_digest = self .mf_search_space_digest ,
307
308
)
308
309
310
+ # This mock is hard to remove since it is mocks a method on a surrogate that
311
+ # is only constructed when `model.fit` is called
309
312
@mock .patch (f"{ SURROGATE_PATH } .Surrogate._construct_model" )
310
313
def test_fit (self , mock_fit : Mock ) -> None :
311
314
# If surrogate is not yet set, initialize it with dispatcher functions.
@@ -350,11 +353,12 @@ def test_fit(self, mock_fit: Mock) -> None:
350
353
refit = True ,
351
354
)
352
355
353
- @ mock . patch ( f" { SURROGATE_PATH } .Surrogate.predict" )
354
- def test_predict (self , mock_predict : Mock ) -> None :
355
- self .model .predict (X = self .X_test )
356
+ def test_predict ( self ) -> None :
357
+ with mock . patch . object (self . model . _surrogate , "predict" ) as mock_predict :
358
+ self .model .predict (X = self .X_test )
356
359
mock_predict .assert_called_with (X = self .X_test , use_posterior_predictive = False )
357
- self .model .predict (X = self .X_test , use_posterior_predictive = True )
360
+ with mock .patch .object (self .model ._surrogate , "predict" ) as mock_predict :
361
+ self .model .predict (X = self .X_test , use_posterior_predictive = True )
358
362
mock_predict .assert_called_with (X = self .X_test , use_posterior_predictive = True )
359
363
360
364
def test_with_surrogate_specs_input (self ) -> None :
@@ -376,48 +380,56 @@ def test_with_surrogate_specs_input(self) -> None:
376
380
model = BoTorchModel (surrogate_specs = {"s" : spec1 })
377
381
self .assertIs (model .surrogate_spec , spec1 )
378
382
379
- @mock . patch ( f" { MODEL_PATH } .BoTorchModel.fit" )
380
- def test_cross_validate (self , mock_fit : Mock ) -> None :
383
+ @mock_botorch_optimize
384
+ def test_cross_validate (self ) -> None :
381
385
self .model .fit (
382
386
datasets = self .block_design_training_data ,
383
- search_space_digest = self .mf_search_space_digest ,
387
+ search_space_digest = self .search_space_digest ,
384
388
candidate_metadata = self .candidate_metadata ,
385
389
)
386
390
387
391
old_surrogate = self .model .surrogate
388
- old_surrogate ._model = mock .MagicMock ()
389
- old_surrogate ._model .state_dict .return_value = OrderedDict ({"key" : "val" })
390
392
391
393
for refit_on_cv , warm_start_refit , use_posterior_predictive in product (
392
394
(True , False ), (True , False ), (True , False )
393
395
):
394
396
self .model .refit_on_cv = refit_on_cv
395
397
self .model .warm_start_refit = warm_start_refit
396
- with mock .patch (
397
- f"{ SURROGATE_PATH } .Surrogate.clone_reset" ,
398
- return_value = mock .MagicMock (spec = Surrogate ),
399
- ) as mock_clone_reset :
398
+ with ExitStack () as es :
399
+ mock_fit = es .enter_context (
400
+ mock .patch .object (self .model , "fit" , wraps = self .model .fit )
401
+ )
402
+ mock_predict_orig_surrogate = es .enter_context (
403
+ mock .patch .object (
404
+ self .model .surrogate ,
405
+ "predict" ,
406
+ wraps = self .model .surrogate .predict ,
407
+ )
408
+ )
409
+ mock_predict_any_surrogate = es .enter_context (
410
+ mock .patch (
411
+ f"{ SURROGATE_PATH } .predict_from_model" , wraps = predict_from_model
412
+ )
413
+ )
400
414
self .model .cross_validate (
401
415
datasets = self .block_design_training_data ,
402
416
X_test = self .X_test ,
403
- search_space_digest = self .mf_search_space_digest ,
417
+ search_space_digest = self .search_space_digest ,
404
418
use_posterior_predictive = use_posterior_predictive ,
405
419
)
406
- # Check that `predict` is called on the cloned surrogate, not
407
- # on the original one.
408
- mock_predict = mock_clone_reset .return_value .predict
409
- mock_predict .assert_called_once ()
410
-
411
- # Check correct X_test.
412
- kwargs = mock_predict .call_args .kwargs
413
- self .assertTrue (torch .equal (kwargs ["X" ], self .X_test ))
414
- self .assertIs (
415
- kwargs ["use_posterior_predictive" ], use_posterior_predictive
416
- )
420
+ # Check that `predict` is called on the cloned surrogate, not
421
+ # on the original one.
422
+ mock_predict_orig_surrogate .assert_not_called ()
423
+ mock_predict_any_surrogate .assert_called_once ()
424
+
425
+ # Check correct X_test.
426
+ kwargs = mock_predict_any_surrogate .call_args .kwargs
427
+ self .assertTrue (torch .equal (kwargs ["X" ], self .X_test ))
428
+ self .assertIs (kwargs ["use_posterior_predictive" ], use_posterior_predictive )
417
429
418
430
# Check that surrogate is reset back to `old_surrogate` at the
419
431
# end of cross-validation.
420
- self .assertTrue (self .model .surrogate is old_surrogate )
432
+ self .assertIs (self .model .surrogate , old_surrogate )
421
433
422
434
expected_state_dict = (
423
435
None
@@ -435,6 +447,7 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
435
447
kwargs ["state_dict" ].keys (), expected_state_dict .keys ()
436
448
)
437
449
450
+ @mock_botorch_optimize
438
451
def test_cross_validate_multiple_configs (self ) -> None :
439
452
"""Test cross-validation with multiple configs."""
440
453
for refit_on_cv in (True , False ):
@@ -459,31 +472,31 @@ def test_cross_validate_multiple_configs(self) -> None:
459
472
search_space_digest = self .search_space_digest ,
460
473
candidate_metadata = self .candidate_metadata ,
461
474
)
462
- with patch (f"{ Surrogate .__module__ } .fit_botorch_model" ) as mock_fit :
475
+ with patch (
476
+ f"{ Surrogate .__module__ } .fit_botorch_model" , wraps = fit_botorch_model
477
+ ) as mock_fit :
463
478
self .model .cross_validate (
464
479
datasets = self .block_design_training_data ,
465
480
X_test = self .X_test ,
466
481
search_space_digest = self .search_space_digest ,
467
482
)
468
- # check that we don't fit the model during cross_validation
469
- if refit_on_cv :
470
- mock_fit .assert_called ()
471
- else :
472
- mock_fit .assert_not_called ()
483
+ # check that we don't fit the model during cross_validation
484
+ if refit_on_cv :
485
+ mock_fit .assert_called ()
486
+ else :
487
+ mock_fit .assert_not_called ()
473
488
474
489
@mock_botorch_optimize
475
490
@mock .patch (
476
491
f"{ MODEL_PATH } .construct_acquisition_and_optimizer_options" ,
477
492
wraps = construct_acquisition_and_optimizer_options ,
478
493
)
479
- @mock .patch (f"{ CURRENT_PATH } .Acquisition.optimize" )
480
494
@mock .patch (
481
495
f"{ MODEL_PATH } .choose_botorch_acqf_class" , wraps = choose_botorch_acqf_class
482
496
)
483
497
def _test_gen (
484
498
self ,
485
499
mock_choose_botorch_acqf_class : Mock ,
486
- mock_optimize : Mock ,
487
500
mock_construct_options : Mock ,
488
501
botorch_model_class : type [Model ],
489
502
search_space_digest : SearchSpaceDigest ,
@@ -498,7 +511,7 @@ def _test_gen(
498
511
acqf_cls = qLogNoisyExpectedImprovement ,
499
512
input_constructor = mock_input_constructor ,
500
513
)
501
- mock_optimize . return_value = (
514
+ mock_optimize_return_value = (
502
515
torch .tensor ([[1.0 ]]),
503
516
torch .tensor ([2.0 ]),
504
517
torch .tensor ([1.0 ]),
@@ -509,25 +522,32 @@ def _test_gen(
509
522
acquisition_class = Acquisition ,
510
523
acquisition_options = self .acquisition_options ,
511
524
)
512
- model .surrogate .fit (
513
- datasets = self .block_design_training_data ,
514
- search_space_digest = search_space_digest ,
515
- )
516
- model ._botorch_acqf_class = None
517
525
# Assert that error is raised if we haven't fit the model
518
526
with self .assertRaises (RuntimeError ):
519
527
model .gen (
520
528
n = 1 ,
521
529
search_space_digest = search_space_digest ,
522
530
torch_opt_config = self .torch_opt_config ,
523
531
)
524
- # Add search space digest reference to make the model think it's been fit
525
- model ._search_space_digest = search_space_digest
526
- with mock .patch .object (
527
- BoTorchModel ,
528
- "_instantiate_acquisition" ,
529
- wraps = model ._instantiate_acquisition ,
530
- ) as mock_init_acqf :
532
+ model .fit (
533
+ datasets = self .block_design_training_data ,
534
+ search_space_digest = search_space_digest ,
535
+ )
536
+ with ExitStack () as es :
537
+ mock_init_acqf = es .enter_context (
538
+ mock .patch .object (
539
+ BoTorchModel ,
540
+ "_instantiate_acquisition" ,
541
+ wraps = model ._instantiate_acquisition ,
542
+ )
543
+ )
544
+ mock_optimize = es .enter_context (
545
+ mock .patch (
546
+ f"{ CURRENT_PATH } .Acquisition.optimize" ,
547
+ return_value = mock_optimize_return_value ,
548
+ )
549
+ )
550
+
531
551
gen_results = model .gen (
532
552
n = 1 ,
533
553
search_space_digest = search_space_digest ,
@@ -612,6 +632,15 @@ def _test_gen(
612
632
input_constructor = qLogNEI_input_constructor ,
613
633
)
614
634
635
+ # Make sure `gen` runs without mocking out Acquisition.optimize
636
+ with self .subTest ("No mocks" ):
637
+ gen_results = model .gen (
638
+ n = 1 ,
639
+ search_space_digest = search_space_digest ,
640
+ torch_opt_config = self .torch_opt_config ,
641
+ )
642
+ self .assertTrue (torch .isfinite (gen_results .points ).all ())
643
+
615
644
def test_gen_SingleTaskGP (self ) -> None :
616
645
self ._test_gen (
617
646
botorch_model_class = SingleTaskGP ,
@@ -726,14 +755,7 @@ def test_best_point(self) -> None:
726
755
)
727
756
728
757
@mock_botorch_optimize
729
- @mock .patch (
730
- f"{ MODEL_PATH } .construct_acquisition_and_optimizer_options" ,
731
- return_value = ({"num_fantasies" : 64 }, {"num_restarts" : 40 , "raw_samples" : 1024 }),
732
- )
733
- @mock .patch (f"{ CURRENT_PATH } .Acquisition" , autospec = True )
734
- def test_evaluate_acquisition_function (
735
- self , mock_acquisition : Mock , _mock_construct_options : Mock
736
- ) -> None :
758
+ def test_evaluate_acquisition_function (self ) -> None :
737
759
model = BoTorchModel (
738
760
surrogate = self .surrogate ,
739
761
acquisition_class = Acquisition ,
@@ -743,57 +765,42 @@ def test_evaluate_acquisition_function(
743
765
datasets = self .block_design_training_data ,
744
766
search_space_digest = self .search_space_digest ,
745
767
)
746
- model .evaluate_acquisition_function (
768
+ points = model .evaluate_acquisition_function (
747
769
X = self .X_test ,
748
770
search_space_digest = self .search_space_digest ,
749
771
torch_opt_config = self .torch_opt_config ,
750
772
acq_options = self .acquisition_options ,
751
773
)
774
+ self .assertEqual (points .shape , torch .Size ([1 ]))
752
775
# testing that the new setup chooses qLogNEI by default
753
776
self .assertEqual (model ._botorch_acqf_class , qLogNoisyExpectedImprovement )
754
- # `mock_acquisition` is a mock of the Acquisition class, so to check the mock's
755
- # `evaluate` on an instance of that class, we use
756
- # `mock_acquisition.return_value.evaluate`.
757
- mock_acquisition .return_value .evaluate .assert_called ()
758
-
759
- @mock .patch (f"{ MODEL_PATH } .Surrogate" , wraps = Surrogate )
760
- @mock .patch (f"{ SURROGATE_PATH } .Surrogate._construct_model" , return_value = None )
761
- @mock .patch (f"{ SURROGATE_PATH } .ModelListGP" )
762
- def test_surrogate_model_options_propagation (
763
- self , _m1 : Mock , _m2 : Mock , mock_init : Mock
764
- ) -> None :
765
- surrogate_spec = SurrogateSpec (
766
- botorch_model_kwargs = {"some_option" : "some_value" }
767
- )
777
+
778
+ @mock_botorch_optimize
779
+ def test_surrogate_model_options_propagation (self ) -> None :
780
+ surrogate_spec = SurrogateSpec ()
768
781
model = BoTorchModel (surrogate_spec = surrogate_spec )
769
- model .fit (
770
- datasets = self .non_block_design_training_data ,
771
- search_space_digest = self .mf_search_space_digest ,
772
- candidate_metadata = self .candidate_metadata ,
773
- )
782
+ with mock .patch (f"{ MODEL_PATH } .Surrogate" , wraps = Surrogate ) as mock_init :
783
+ model .fit (
784
+ datasets = self .non_block_design_training_data ,
785
+ search_space_digest = self .mf_search_space_digest ,
786
+ candidate_metadata = self .candidate_metadata ,
787
+ )
774
788
mock_init .assert_called_with (surrogate_spec = surrogate_spec , refit_on_cv = False )
775
789
776
- @mock .patch (f"{ MODEL_PATH } .Surrogate" , wraps = Surrogate )
777
- @mock .patch (f"{ SURROGATE_PATH } .Surrogate._construct_model" , return_value = None )
778
- @mock .patch (f"{ SURROGATE_PATH } .ModelListGP" )
779
- def test_surrogate_options_propagation (
780
- self , _m1 : Mock , _m2 : Mock , mock_init : Mock
781
- ) -> None :
790
+ @mock_botorch_optimize
791
+ def test_surrogate_options_propagation (self ) -> None :
782
792
surrogate_spec = SurrogateSpec (allow_batched_models = False )
783
793
model = BoTorchModel (surrogate_spec = surrogate_spec )
784
- model .fit (
785
- datasets = self .non_block_design_training_data ,
786
- search_space_digest = self .mf_search_space_digest ,
787
- candidate_metadata = self .candidate_metadata ,
788
- )
794
+ with mock .patch (f"{ MODEL_PATH } .Surrogate" , wraps = Surrogate ) as mock_init :
795
+ model .fit (
796
+ datasets = self .non_block_design_training_data ,
797
+ search_space_digest = self .mf_search_space_digest ,
798
+ candidate_metadata = self .candidate_metadata ,
799
+ )
789
800
mock_init .assert_called_with (surrogate_spec = surrogate_spec , refit_on_cv = False )
790
801
791
- @mock .patch (
792
- f"{ ACQUISITION_PATH } .Acquisition.optimize" ,
793
- # Dummy candidates and acquisition function value.
794
- return_value = (torch .tensor ([[2.0 ]]), torch .tensor ([1.0 ])),
795
- )
796
- def test_model_list_choice (self , _ ) -> None : # , mock_extract_training_data):
802
+ @mock_botorch_optimize
803
+ def test_model_list_choice (self ) -> None :
797
804
model = BoTorchModel ()
798
805
model .fit (
799
806
datasets = self .non_block_design_training_data ,
@@ -807,12 +814,8 @@ def test_model_list_choice(self, _) -> None: # , mock_extract_training_data):
807
814
# MFGP should be chosen.
808
815
self .assertIsInstance (submodel , SingleTaskMultiFidelityGP )
809
816
810
- @mock .patch (
811
- f"{ ACQUISITION_PATH } .Acquisition.optimize" ,
812
- # Dummy candidates, acquisition value, and weights
813
- return_value = (torch .tensor ([[2.0 ]]), torch .tensor ([1.0 ]), torch .tensor ([1.0 ])),
814
- )
815
- def test_MOO (self , _ ) -> None :
817
+ @mock_botorch_optimize
818
+ def test_MOO (self ) -> None :
816
819
# Add mock for qLogNEHVI input constructor to catch arguments passed to it.
817
820
qLogNEHVI_input_constructor = get_acqf_input_constructor (
818
821
qLogNoisyExpectedHypervolumeImprovement
@@ -843,6 +846,7 @@ def test_MOO(self, _) -> None:
843
846
with mock .patch (
844
847
f"{ ACQUISITION_PATH } .get_outcome_constraint_transforms" ,
845
848
# Dummy candidates and acquisition function value.
849
+ # This will return the same value as the original
846
850
return_value = constraints ,
847
851
) as mock_get_outcome_constraint_transforms :
848
852
gen_results = model .gen (
@@ -940,9 +944,10 @@ def test_MOO(self, _) -> None:
940
944
linear_constraints = linear_constraints ,
941
945
)
942
946
947
+ objective_thresholds = torch .tensor ([9.9 , 3.3 , float ("nan" )])
943
948
with mock .patch (
944
- "ax.models.torch.botorch_modular.acquisition." " infer_objective_thresholds" ,
945
- return_value = torch . tensor ([ 9.9 , 3.3 , float ( "nan" )]) ,
949
+ "ax.models.torch.botorch_modular.acquisition.infer_objective_thresholds" ,
950
+ return_value = objective_thresholds ,
946
951
) as _mock_model_infer_objective_thresholds :
947
952
gen_results = model .gen (
948
953
n = 1 ,
@@ -973,7 +978,7 @@ def test_MOO(self, _) -> None:
973
978
self .assertEqual (m .num_outputs , 2 )
974
979
self .assertIn ("objective_thresholds" , gen_results .gen_metadata )
975
980
obj_t = gen_results .gen_metadata ["objective_thresholds" ]
976
- self .assertTrue (torch .equal (obj_t [:2 ], torch . tensor ([ 9.9 , 3.3 ]) ))
981
+ self .assertTrue (torch .equal (obj_t [:2 ], objective_thresholds [: 2 ] ))
977
982
self .assertTrue (np .isnan (obj_t [2 ].item ()))
978
983
979
984
# Avoid polluting the registry for other tests; re-register correct input
0 commit comments