
Commit 25506ab

facebook-github-bot authored and committed
Make (Log)NoisyExpectedImprovement create a correct fantasy model with non-default SingleTaskGP (#2414)
Summary:
## Motivation

In `botorch/acquisition/analytic.py`, `LogNoisyExpectedImprovement` and `NoisyExpectedImprovement` use the function `_get_noiseless_fantasy_model` to repeatedly sample from a fantasy model. But `_get_noiseless_fantasy_model` only works for the default GP (i.e. with the default Matern kernel) and with no input or outcome transforms. I think it would make sense for this code to work with any kind of `SingleTaskGP`, not just the default one with no input and outcome transforms.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2414

Test Plan:
Since the code is now meant to work even with input or outcome transforms or a different covar_module or mean_module, I updated the test code to try all of these, as well as different input bounds, to make sure the input transform is working correctly. However, the tests now fail, specifically when either the input data range is not [0, 1] or when the kernel is the RBF kernel (not Matern). I believe the tests fail with RBF simply because certain constants used in the code are only valid for particular GP settings. However, when the code fails because the input range is not [0, 1], this might be a slight problem with the code -- or might not -- I'm not completely sure. I also added a line that makes sure the state_dict() is the same between the original model and the fantasy model.

## Related PRs

I raised this in issue #2412 and was told that it is OK if not all the tests pass.

Reviewed By: saitcakmak, esantorella

Differential Revision: D59692772

Pulled By: SebastianAment

fbshipit-source-id: 25de86f7c06ea924ad578cf319a58803fc907bdb
1 parent b24b3f1 commit 25506ab
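To illustrate what the change enables, here is a minimal, hypothetical sketch (not taken from the PR; the data, kernel choice, and transform settings below are made up for illustration): a non-default `SingleTaskGP` with an RBF kernel and input/outcome transforms can now be handed to `LogNoisyExpectedImprovement`, and the fantasy model it builds internally will match it.

```python
import torch
from botorch.acquisition.analytic import LogNoisyExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from gpytorch.kernels import RBFKernel, ScaleKernel

# Synthetic toy data on a non-unit cube; observation noise is assumed known.
train_X = 4.0 * torch.rand(10, 2, dtype=torch.double) - 2.0  # inputs in [-2, 2]^2
train_Y = (train_X**2).sum(dim=-1, keepdim=True)
train_Yvar = torch.full_like(train_Y, 0.01)

# A "non-default" SingleTaskGP: RBF kernel plus input and outcome transforms.
model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    train_Yvar=train_Yvar,
    covar_module=ScaleKernel(RBFKernel(ard_num_dims=2)),
    input_transform=Normalize(d=2),
    outcome_transform=Standardize(m=1),
)
model.eval()  # hyperparameter fitting omitted for brevity

# The fantasy model built internally now inherits the kernel, mean, and
# transforms of `model` instead of silently assuming the defaults.
log_nei = LogNoisyExpectedImprovement(
    model=model, X_observed=train_X, num_fantasies=16
)
value = log_nei(train_X.unsqueeze(-2))  # candidates of shape `10 x 1 x 2`
```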

File tree

botorch/acquisition/analytic.py
test/acquisition/test_analytic.py

2 files changed: +184 -36 lines

botorch/acquisition/analytic.py

Lines changed: 52 additions & 17 deletions
@@ -14,10 +14,8 @@
 import math
 
 from abc import ABC
-
 from contextlib import nullcontext
 from copy import deepcopy
-
 from typing import Dict, Optional, Tuple, Union
 
 import torch
@@ -613,15 +611,17 @@ def __init__(
         r"""Single-outcome Noisy Log Expected Improvement (via fantasies).
 
         Args:
-            model: A fitted single-outcome model.
+            model: A fitted single-outcome model. Only `SingleTaskGP` models with
+                known observation noise are currently supported.
             X_observed: A `n x d` Tensor of observed points that are likely to
                 be the best observed points so far.
             num_fantasies: The number of fantasies to generate. The higher this
                 number the more accurate the model (at the expense of model
                 complexity and performance).
            maximize: If True, consider the problem a maximization problem.
         """
-        # sample fantasies
+        _check_noisy_ei_model(model=model)
+        # Sample fantasies.
         from botorch.sampling.normal import SobolQMCNormalSampler
 
         # Drop gradients from model.posterior if X_observed does not require gradients
@@ -695,16 +695,18 @@ def __init__(
         r"""Single-outcome Noisy Expected Improvement (via fantasies).
 
         Args:
-            model: A fitted single-outcome model.
+            model: A fitted single-outcome model. Only `SingleTaskGP` models with
+                known observation noise are currently supported.
             X_observed: A `n x d` Tensor of observed points that are likely to
                 be the best observed points so far.
             num_fantasies: The number of fantasies to generate. The higher this
                 number the more accurate the model (at the expense of model
                 complexity and performance).
             maximize: If True, consider the problem a maximization problem.
         """
+        _check_noisy_ei_model(model=model)
         legacy_ei_numerics_warning(legacy_name=type(self).__name__)
-        # sample fantasies
+        # Sample fantasies.
         from botorch.sampling.normal import SobolQMCNormalSampler
 
         # Drop gradients from model.posterior if X_observed does not require gradients
@@ -1051,6 +1053,21 @@ def logerfcx(x: Tensor) -> Tensor:
     return torch.log(torch.special.erfcx(a * u) * u.abs()) + b
 
 
+def _check_noisy_ei_model(model: GPyTorchModel) -> None:
+    message = (
+        "Only single-output `SingleTaskGP` models with known observation noise "
+        "are currently supported for fantasy-based NEI & LogNEI."
+    )
+    if not isinstance(model, SingleTaskGP):
+        raise UnsupportedError(f"{message} Model is not a `SingleTaskGP`.")
+    if not isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
+        raise UnsupportedError(
+            f"{message} Model likelihood is not a `FixedNoiseGaussianLikelihood`."
+        )
+    if model.num_outputs != 1:
+        raise UnsupportedError(f"{message} Model has {model.num_outputs} outputs.")
+
+
 def _get_noiseless_fantasy_model(
     model: SingleTaskGP, batch_X_observed: Tensor, Y_fantasized: Tensor
 ) -> SingleTaskGP:
@@ -1069,31 +1086,49 @@ def _get_noiseless_fantasy_model(
     Returns:
         The fantasy model.
     """
-    if not isinstance(model, SingleTaskGP) or not isinstance(
-        model.likelihood, FixedNoiseGaussianLikelihood
-    ):
-        raise UnsupportedError(
-            "Only SingleTaskGP models with known observation noise "
-            "are currently supported for fantasy-based NEI & LogNEI."
-        )
     # initialize a copy of SingleTaskGP on the original training inputs
     # this makes SingleTaskGP a non-batch GP, so that the same hyperparameters
     # are used across all batches (by default, a GP with batched training data
     # uses independent hyperparameters for each batch).
+
+    # Don't apply `outcome_transform` and `input_transform` here,
+    # since the data being passed has already been transformed.
+    # So we will instead set them afterwards.
     fantasy_model = SingleTaskGP(
         train_X=model.train_inputs[0],
         train_Y=model.train_targets.unsqueeze(-1),
         train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
+        covar_module=deepcopy(model.covar_module),
+        mean_module=deepcopy(model.mean_module),
     )
+
+    Yvar = torch.full_like(Y_fantasized, 1e-7)
+
+    # Set the outcome and input transforms of the fantasy model.
+    # The transforms should already be in eval mode but just set them to be sure
+    outcome_transform = getattr(model, "outcome_transform", None)
+    if outcome_transform is not None:
+        outcome_transform = deepcopy(outcome_transform).eval()
+        fantasy_model.outcome_transform = outcome_transform
+        # Need to transform the outcome just as in the SingleTaskGP constructor.
+        # Need to unsqueeze for BoTorch and then squeeze again for GPyTorch.
+        # Not transforming Yvar because 1e-7 is already close to 0 and it is a
+        # relative, not absolute, value.
+        Y_fantasized, _ = outcome_transform(
+            Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1)
+        )
+        Y_fantasized = Y_fantasized.squeeze(-1)
+    input_transform = getattr(model, "input_transform", None)
+    if input_transform is not None:
+        fantasy_model.input_transform = deepcopy(input_transform).eval()
+
     # update training inputs/targets to be batch mode fantasies
     fantasy_model.set_train_data(
         inputs=batch_X_observed, targets=Y_fantasized, strict=False
     )
     # use noiseless fantasies
-    fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)
-    # load hyperparameters from original model
-    state_dict = deepcopy(model.state_dict())
-    fantasy_model.load_state_dict(state_dict)
+    fantasy_model.likelihood.noise_covar.noise = Yvar
+
     return fantasy_model
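For orientation, the snippet below is a rough, hand-written sketch (an approximation, not the verbatim source) of how the `(Log)NoisyExpectedImprovement` constructors use this helper: fantasy outcomes are drawn at `X_observed` from the model's posterior, and `_get_noiseless_fantasy_model` turns them into a batched, noiseless `SingleTaskGP` that keeps the original kernel, mean, hyperparameters, and transforms. The function name `build_fantasy_model` is made up for this sketch.

```python
# Simplified sketch of the call pattern inside (Log)NoisyExpectedImprovement.__init__;
# this is an approximation of the constructor logic, not a copy of it.
import torch
from botorch.acquisition.analytic import (
    _check_noisy_ei_model,
    _get_noiseless_fantasy_model,
)
from botorch.sampling.normal import SobolQMCNormalSampler


def build_fantasy_model(model, X_observed, num_fantasies):
    _check_noisy_ei_model(model=model)  # reject unsupported models up front
    with torch.no_grad():
        posterior = model.posterior(X=X_observed)
        sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
        # `num_fantasies x n` fantasy outcomes at the observed points
        Y_fantasized = sampler(posterior).squeeze(-1)
    # Replicate the observed inputs so the fantasy GP becomes a batched model.
    batch_X_observed = X_observed.expand(num_fantasies, *X_observed.shape)
    return _get_noiseless_fantasy_model(
        model=model, batch_X_observed=batch_X_observed, Y_fantasized=Y_fantasized
    )
```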

test/acquisition/test_analytic.py

Lines changed: 132 additions & 19 deletions
@@ -4,12 +4,14 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 import math
 from warnings import catch_warnings, simplefilter
 
 import torch
 from botorch.acquisition import qAnalyticProbabilityOfImprovement
 from botorch.acquisition.analytic import (
+    _check_noisy_ei_model,
     _compute_log_prob_feas,
     _ei_helper,
     _log_ei_helper,
@@ -33,11 +35,19 @@
 )
 from botorch.exceptions import UnsupportedError
 from botorch.exceptions.warnings import NumericsWarning
-from botorch.models import SingleTaskGP
+from botorch.models import ModelListGP, SingleTaskGP
+from botorch.models.transforms import ChainedOutcomeTransform, Normalize, Standardize
 from botorch.posteriors import GPyTorchPosterior
+from botorch.sampling.pathwise.utils import get_train_inputs
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
-from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
+from gpytorch.kernels import RBFKernel, ScaleKernel
+from gpytorch.likelihoods.gaussian_likelihood import (
+    FixedNoiseGaussianLikelihood,
+    GaussianLikelihood,
+)
+from gpytorch.module import Module
+from gpytorch.priors.torch_priors import GammaPrior
 
 
 NEI_NOISE = [
@@ -831,7 +841,15 @@ def _test_constrained_expected_improvement_batch(self, dtype: torch.dtype) -> No
 
 
 class TestNoisyExpectedImprovement(BotorchTestCase):
-    def _get_model(self, dtype=torch.float):
+    def _get_model(
+        self,
+        dtype=torch.float,
+        outcome_transform=None,
+        input_transform=None,
+        low_x=0.0,
+        hi_x=1.0,
+        covar_module=None,
+    ) -> SingleTaskGP:
         state_dict = {
             "mean_module.raw_constant": torch.tensor([-0.0066]),
             "covar_module.raw_outputscale": torch.tensor(1.0143),
@@ -843,20 +861,31 @@ def _get_model(self, dtype=torch.float):
             "covar_module.outputscale_prior.concentration": torch.tensor(2.0),
             "covar_module.outputscale_prior.rate": torch.tensor(0.1500),
         }
-        train_x = torch.linspace(0, 1, 10, device=self.device, dtype=dtype).unsqueeze(
-            -1
-        )
+        train_x = torch.linspace(
+            0.0, 1.0, 10, device=self.device, dtype=dtype
+        ).unsqueeze(-1)
+        # Taking the sin of the *transformed* input to make the test equivalent
+        # to when there are no input transforms
         train_y = torch.sin(train_x * (2 * math.pi))
+        # Now transform the input to be passed into SingleTaskGP constructor
+        train_x = train_x * (hi_x - low_x) + low_x
         noise = torch.tensor(NEI_NOISE, device=self.device, dtype=dtype)
         train_y += noise
         train_yvar = torch.full_like(train_y, 0.25**2)
-        model = SingleTaskGP(train_X=train_x, train_Y=train_y, train_Yvar=train_yvar)
-        model.load_state_dict(state_dict)
+        model = SingleTaskGP(
+            train_X=train_x,
+            train_Y=train_y,
+            train_Yvar=train_yvar,
+            outcome_transform=outcome_transform,
+            input_transform=input_transform,
+            covar_module=covar_module,
+        )
+        model.load_state_dict(state_dict, strict=False)
         model.to(train_x)
         model.eval()
         return model
 
-    def test_noisy_expected_improvement(self):
+    def test_noisy_expected_improvement(self) -> None:
         model = self._get_model(dtype=torch.float64)
         X_observed = model.train_inputs[0]
         nfan = 5
@@ -865,14 +894,75 @@ def test_noisy_expected_improvement(self):
         ):
             NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
 
-        for dtype in (torch.float, torch.double):
+        # Same as the default Matern kernel
+        # botorch.models.utils.gpytorch_modules.get_matern_kernel_with_gamma_prior,
+        # except RBFKernel is used instead of MaternKernel.
+        # For some reason, RBF gives numerical problems with torch.float but
+        # Matern does not. Therefore, we'll skip the test for RBF when dtype is
+        # torch.float.
+        covar_module_2 = ScaleKernel(
+            base_kernel=RBFKernel(
+                ard_num_dims=1,
+                batch_shape=torch.Size(),
+                lengthscale_prior=GammaPrior(3.0, 6.0),
+            ),
+            batch_shape=torch.Size(),
+            outputscale_prior=GammaPrior(2.0, 0.15),
+        )
+        for dtype, use_octf, use_intf, bounds, covar_module in itertools.product(
+            (torch.float, torch.double),
+            (False, True),
+            (False, True),
+            (torch.tensor([[-3.4], [0.8]]), torch.tensor([[0.0], [1.0]])),
+            (None, covar_module_2),
+        ):
             with catch_warnings():
                 simplefilter("ignore", category=NumericsWarning)
-                self._test_noisy_expected_imrpovement(dtype)
+                self._test_noisy_expected_improvement(
+                    dtype=dtype,
+                    use_octf=use_octf,
+                    use_intf=use_intf,
+                    bounds=bounds,
+                    covar_module=covar_module,
+                )
+
+    def _test_noisy_expected_improvement(
+        self,
+        dtype: torch.dtype,
+        use_octf: bool,
+        use_intf: bool,
+        bounds: torch.Tensor,
+        covar_module: Module,
+    ) -> None:
+        if covar_module is not None and dtype == torch.float:
+            # Skip this test because RBF runs into numerical problems with float
+            # precision
+            return
+        octf = (
+            ChainedOutcomeTransform(standardize=Standardize(m=1)) if use_octf else None
+        )
+        intf = (
+            Normalize(
+                d=1,
+                bounds=bounds.to(device=self.device, dtype=dtype),
+                transform_on_train=True,
+            )
+            if use_intf
+            else None
+        )
+        low_x = bounds[0].item() if use_intf else 0.0
+        hi_x = bounds[1].item() if use_intf else 1.0
+        model = self._get_model(
+            dtype=dtype,
+            outcome_transform=octf,
+            input_transform=intf,
+            low_x=low_x,
+            hi_x=hi_x,
+            covar_module=covar_module,
+        )
+        # Make sure to get the non-transformed training inputs.
+        X_observed = get_train_inputs(model, transformed=False)[0]
 
-    def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
-        model = self._get_model(dtype=dtype)
-        X_observed = model.train_inputs[0]
         nfan = 5
         nEI = NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
         LogNEI = LogNoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
@@ -881,6 +971,10 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
         self.assertTrue(hasattr(LogNEI, "best_f"))
         self.assertIsInstance(LogNEI.model, SingleTaskGP)
         self.assertIsInstance(LogNEI.model.likelihood, FixedNoiseGaussianLikelihood)
+        # Make sure _get_noiseless_fantasy_model gives them
+        # the same state_dict
+        self.assertEqual(LogNEI.model.state_dict(), model.state_dict())
+
         LogNEI.model = nEI.model  # let the two share their values and fantasies
         LogNEI.best_f = nEI.best_f
 
@@ -892,9 +986,10 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
         X_test_log = X_test.clone()
         X_test.requires_grad = True
         X_test_log.requires_grad = True
-        val = nEI(X_test)
+
+        val = nEI(X_test * (hi_x - low_x) + low_x)
         # testing logNEI yields the same result (also checks dtype)
-        log_val = LogNEI(X_test_log)
+        log_val = LogNEI(X_test_log * (hi_x - low_x) + low_x)
         exp_log_val = log_val.exp()
         # notably, val[1] is usually zero in this test, which is precisely what
         # gives rise to problems during optimization, and what logNEI avoids
@@ -916,7 +1011,7 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
         # testing gradient through exp of log computation
         exp_log_val.sum().backward()
         # testing that first gradient element coincides. The second is in the
-        # regime where the naive implementation looses accuracy.
+        # regime where the naive implementation loses accuracy.
         atol = 2e-5 if dtype == torch.float32 else 1e-12
         rtol = atol
         self.assertAllClose(X_test.grad[0], X_test_log.grad[0], atol=atol, rtol=rtol)
@@ -945,9 +1040,27 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
             acqf = constructor(model, X_observed, num_fantasies=5)
             self.assertTrue(acqf.best_f.requires_grad)
 
+    def test_check_noisy_ei_model(self) -> None:
+        tkwargs = {"dtype": torch.double, "device": self.device}
+        # Multi-output model.
+        model = SingleTaskGP(
+            train_X=torch.rand(5, 2, **tkwargs),
+            train_Y=torch.rand(5, 2, **tkwargs),
+            train_Yvar=torch.rand(5, 2, **tkwargs),
+        )
+        with self.assertRaisesRegex(UnsupportedError, "Model has 2 outputs"):
+            _check_noisy_ei_model(model=model)
+        # Not SingleTaskGP.
+        with self.assertRaisesRegex(UnsupportedError, "Model is not"):
+            _check_noisy_ei_model(model=ModelListGP(model))
+        # Not fixed noise.
+        model.likelihood = GaussianLikelihood()
+        with self.assertRaisesRegex(UnsupportedError, "Model likelihood is not"):
+            _check_noisy_ei_model(model=model)
+
 
 class TestScalarizedPosteriorMean(BotorchTestCase):
-    def test_scalarized_posterior_mean(self):
+    def test_scalarized_posterior_mean(self) -> None:
         for dtype in (torch.float, torch.double):
             mean = torch.tensor([[0.25], [0.5]], device=self.device, dtype=dtype)
             mm = MockModel(MockPosterior(mean=mean))
@@ -959,7 +1072,7 @@ def test_scalarized_posterior_mean(self):
             torch.allclose(pm, (mean.squeeze(-1) * module.weights).sum(dim=-1))
         )
 
-    def test_scalarized_posterior_mean_batch(self):
+    def test_scalarized_posterior_mean_batch(self) -> None:
         for dtype in (torch.float, torch.double):
             mean = torch.tensor(
                 [[-0.5, 1.0], [0.0, 1.0], [0.5, 1.0]], device=self.device, dtype=dtype
