# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+import itertools
import math
from warnings import catch_warnings, simplefilter

import torch
from botorch.acquisition import qAnalyticProbabilityOfImprovement
from botorch.acquisition.analytic import (
+    _check_noisy_ei_model,
    _compute_log_prob_feas,
    _ei_helper,
    _log_ei_helper,
@@ -33,11 +35,19 @@
)
from botorch.exceptions import UnsupportedError
from botorch.exceptions.warnings import NumericsWarning
-from botorch.models import SingleTaskGP
+from botorch.models import ModelListGP, SingleTaskGP
+from botorch.models.transforms import ChainedOutcomeTransform, Normalize, Standardize
from botorch.posteriors import GPyTorchPosterior
+from botorch.sampling.pathwise.utils import get_train_inputs
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
-from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
+from gpytorch.kernels import RBFKernel, ScaleKernel
+from gpytorch.likelihoods.gaussian_likelihood import (
+    FixedNoiseGaussianLikelihood,
+    GaussianLikelihood,
+)
+from gpytorch.module import Module
+from gpytorch.priors.torch_priors import GammaPrior


NEI_NOISE = [
@@ -831,7 +841,15 @@ def _test_constrained_expected_improvement_batch(self, dtype: torch.dtype) -> No


class TestNoisyExpectedImprovement(BotorchTestCase):
-    def _get_model(self, dtype=torch.float):
+    def _get_model(
+        self,
+        dtype=torch.float,
+        outcome_transform=None,
+        input_transform=None,
+        low_x=0.0,
+        hi_x=1.0,
+        covar_module=None,
+    ) -> SingleTaskGP:
        state_dict = {
            "mean_module.raw_constant": torch.tensor([-0.0066]),
            "covar_module.raw_outputscale": torch.tensor(1.0143),
@@ -843,20 +861,31 @@ def _get_model(self, dtype=torch.float):
            "covar_module.outputscale_prior.concentration": torch.tensor(2.0),
            "covar_module.outputscale_prior.rate": torch.tensor(0.1500),
        }
-        train_x = torch.linspace(0, 1, 10, device=self.device, dtype=dtype).unsqueeze(
-            -1
-        )
+        train_x = torch.linspace(
+            0.0, 1.0, 10, device=self.device, dtype=dtype
+        ).unsqueeze(-1)
+        # Taking the sin of the *transformed* input to make the test equivalent
+        # to when there are no input transforms
        train_y = torch.sin(train_x * (2 * math.pi))
+        # Now transform the input to be passed into SingleTaskGP constructor
+        train_x = train_x * (hi_x - low_x) + low_x
        noise = torch.tensor(NEI_NOISE, device=self.device, dtype=dtype)
        train_y += noise
        train_yvar = torch.full_like(train_y, 0.25**2)
-        model = SingleTaskGP(train_X=train_x, train_Y=train_y, train_Yvar=train_yvar)
-        model.load_state_dict(state_dict)
+        model = SingleTaskGP(
+            train_X=train_x,
+            train_Y=train_y,
+            train_Yvar=train_yvar,
+            outcome_transform=outcome_transform,
+            input_transform=input_transform,
+            covar_module=covar_module,
+        )
+        model.load_state_dict(state_dict, strict=False)
        model.to(train_x)
        model.eval()
        return model

-    def test_noisy_expected_improvement(self):
+    def test_noisy_expected_improvement(self) -> None:
        model = self._get_model(dtype=torch.float64)
        X_observed = model.train_inputs[0]
        nfan = 5
@@ -865,14 +894,75 @@ def test_noisy_expected_improvement(self):
        ):
            NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)

-        for dtype in (torch.float, torch.double):
+        # Same as the default Matern kernel
+        # botorch.models.utils.gpytorch_modules.get_matern_kernel_with_gamma_prior,
+        # except RBFKernel is used instead of MaternKernel.
+        # For some reason, RBF gives numerical problems with torch.float but
+        # Matern does not. Therefore, we'll skip the test for RBF when dtype is
+        # torch.float.
+        covar_module_2 = ScaleKernel(
+            base_kernel=RBFKernel(
+                ard_num_dims=1,
+                batch_shape=torch.Size(),
+                lengthscale_prior=GammaPrior(3.0, 6.0),
+            ),
+            batch_shape=torch.Size(),
+            outputscale_prior=GammaPrior(2.0, 0.15),
+        )
+        for dtype, use_octf, use_intf, bounds, covar_module in itertools.product(
+            (torch.float, torch.double),
+            (False, True),
+            (False, True),
+            (torch.tensor([[-3.4], [0.8]]), torch.tensor([[0.0], [1.0]])),
+            (None, covar_module_2),
+        ):
            with catch_warnings():
                simplefilter("ignore", category=NumericsWarning)
-                self._test_noisy_expected_imrpovement(dtype)
+                self._test_noisy_expected_improvement(
+                    dtype=dtype,
+                    use_octf=use_octf,
+                    use_intf=use_intf,
+                    bounds=bounds,
+                    covar_module=covar_module,
+                )
+
+    def _test_noisy_expected_improvement(
+        self,
+        dtype: torch.dtype,
+        use_octf: bool,
+        use_intf: bool,
+        bounds: torch.Tensor,
+        covar_module: Module,
+    ) -> None:
+        if covar_module is not None and dtype == torch.float:
+            # Skip this test because RBF runs into numerical problems with float
+            # precision
+            return
+        octf = (
+            ChainedOutcomeTransform(standardize=Standardize(m=1)) if use_octf else None
+        )
+        intf = (
+            Normalize(
+                d=1,
+                bounds=bounds.to(device=self.device, dtype=dtype),
+                transform_on_train=True,
+            )
+            if use_intf
+            else None
+        )
+        low_x = bounds[0].item() if use_intf else 0.0
+        hi_x = bounds[1].item() if use_intf else 1.0
+        model = self._get_model(
+            dtype=dtype,
+            outcome_transform=octf,
+            input_transform=intf,
+            low_x=low_x,
+            hi_x=hi_x,
+            covar_module=covar_module,
+        )
+        # Make sure to get the non-transformed training inputs.
+        X_observed = get_train_inputs(model, transformed=False)[0]

-    def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
-        model = self._get_model(dtype=dtype)
-        X_observed = model.train_inputs[0]
        nfan = 5
        nEI = NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
        LogNEI = LogNoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
@@ -881,6 +971,10 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
        self.assertTrue(hasattr(LogNEI, "best_f"))
        self.assertIsInstance(LogNEI.model, SingleTaskGP)
        self.assertIsInstance(LogNEI.model.likelihood, FixedNoiseGaussianLikelihood)
+        # Make sure _get_noiseless_fantasy_model gives them
+        # the same state_dict
+        self.assertEqual(LogNEI.model.state_dict(), model.state_dict())
+
        LogNEI.model = nEI.model  # let the two share their values and fantasies
        LogNEI.best_f = nEI.best_f

@@ -892,9 +986,10 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
        X_test_log = X_test.clone()
        X_test.requires_grad = True
        X_test_log.requires_grad = True
-        val = nEI(X_test)
+
+        val = nEI(X_test * (hi_x - low_x) + low_x)
        # testing logNEI yields the same result (also checks dtype)
-        log_val = LogNEI(X_test_log)
+        log_val = LogNEI(X_test_log * (hi_x - low_x) + low_x)
        exp_log_val = log_val.exp()
        # notably, val[1] is usually zero in this test, which is precisely what
        # gives rise to problems during optimization, and what logNEI avoids
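Aside (not part of the PR, values are made up): a tiny illustration of the numerical issue the comment above refers to. An acquisition value that underflows to zero in single precision turns into -inf under a subsequent log and its gradient information is lost, whereas computing directly in log space keeps a finite, differentiable quantity.

import torch

ei = torch.tensor(1e-60, dtype=torch.double)  # representable in double precision
ei_f32 = ei.to(torch.float32)                 # underflows to 0.0 in single precision
print(torch.log(ei_f32))                      # tensor(-inf): gradients through this die
print(torch.log(ei).to(torch.float32))        # tensor(-138.1551): log-space stays finite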
@@ -916,7 +1011,7 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
        # testing gradient through exp of log computation
        exp_log_val.sum().backward()
        # testing that first gradient element coincides. The second is in the
-        # regime where the naive implementation looses accuracy.
+        # regime where the naive implementation loses accuracy.
        atol = 2e-5 if dtype == torch.float32 else 1e-12
        rtol = atol
        self.assertAllClose(X_test.grad[0], X_test_log.grad[0], atol=atol, rtol=rtol)
@@ -945,9 +1040,27 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
            acqf = constructor(model, X_observed, num_fantasies=5)
            self.assertTrue(acqf.best_f.requires_grad)

+    def test_check_noisy_ei_model(self) -> None:
+        tkwargs = {"dtype": torch.double, "device": self.device}
+        # Multi-output model.
+        model = SingleTaskGP(
+            train_X=torch.rand(5, 2, **tkwargs),
+            train_Y=torch.rand(5, 2, **tkwargs),
+            train_Yvar=torch.rand(5, 2, **tkwargs),
+        )
+        with self.assertRaisesRegex(UnsupportedError, "Model has 2 outputs"):
+            _check_noisy_ei_model(model=model)
+        # Not SingleTaskGP.
+        with self.assertRaisesRegex(UnsupportedError, "Model is not"):
+            _check_noisy_ei_model(model=ModelListGP(model))
+        # Not fixed noise.
+        model.likelihood = GaussianLikelihood()
+        with self.assertRaisesRegex(UnsupportedError, "Model likelihood is not"):
+            _check_noisy_ei_model(model=model)
+

class TestScalarizedPosteriorMean(BotorchTestCase):
-    def test_scalarized_posterior_mean(self):
+    def test_scalarized_posterior_mean(self) -> None:
        for dtype in (torch.float, torch.double):
            mean = torch.tensor([[0.25], [0.5]], device=self.device, dtype=dtype)
            mm = MockModel(MockPosterior(mean=mean))
@@ -959,7 +1072,7 @@ def test_scalarized_posterior_mean(self):
                torch.allclose(pm, (mean.squeeze(-1) * module.weights).sum(dim=-1))
            )

-    def test_scalarized_posterior_mean_batch(self):
+    def test_scalarized_posterior_mean_batch(self) -> None:
        for dtype in (torch.float, torch.double):
            mean = torch.tensor(
                [[-0.5, 1.0], [0.0, 1.0], [0.5, 1.0]], device=self.device, dtype=dtype
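For orientation, a minimal sketch of the usage pattern these tests exercise: NoisyExpectedImprovement built on a fixed-noise SingleTaskGP that carries a Normalize input transform, with X_observed given in the original (untransformed) space. This is illustrative only; the data, bounds, and hyperparameters below are made up and not taken from the PR.

import math
import torch
from botorch.acquisition import NoisyExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize

# Toy 1-d data on [0, 2] with known observation noise (illustrative values).
train_X = torch.linspace(0.0, 2.0, 10, dtype=torch.double).unsqueeze(-1)
train_Y = torch.sin(math.pi * train_X)
train_Yvar = torch.full_like(train_Y, 0.25**2)  # fixed noise: required for NEI

model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    train_Yvar=train_Yvar,
    input_transform=Normalize(d=1, bounds=torch.tensor([[0.0], [2.0]]).double()),
)
# In practice the model would be fit first (e.g. with fit_gpytorch_mll).

nei = NoisyExpectedImprovement(model, X_observed=train_X, num_fantasies=5)
X_cand = 2.0 * torch.rand(3, 1, 1, dtype=torch.double)  # 3 candidates, q=1, d=1
values = nei(X_cand)  # shape torch.Size([3])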