Commit d2ee1de (parent 78c04e2)

Improve optimize_posterior_samples and get_optimal_samples to better
support info-theoretic acquisition functions.

4 files changed: +81 -26 lines
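
In summary, the commit threads a new `options` dict from get_optimal_samples through to the candidate-generation phase of optimize_posterior_samples, seeds that phase with high-valued observed points via prune_inferior_points, selects gradient-descent restart points by Boltzmann sampling instead of a deterministic top-k, and changes the defaults (raw_samples 1024 -> 2048, num_restarts 20 -> 4). A minimal usage sketch follows; the model setup and option values here are illustrative assumptions, not part of the diff:

import torch
from botorch.acquisition.utils import get_optimal_samples
from botorch.models import SingleTaskGP

# toy model; any fitted BoTorch model would do
train_X = torch.rand(16, 3, dtype=torch.float64)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
bounds = torch.tensor([[0.0] * 3, [1.0] * 3], dtype=torch.float64)

# `options` is new in this commit and is forwarded to the
# initial-candidate generation step of optimize_posterior_samples
optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model,
    bounds=bounds,
    num_optima=8,
    options={"frac_random": 0.9, "sample_around_best_sigma": 1e-2, "eta": 2.0},
)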

botorch/acquisition/utils.py

+11 -1
@@ -534,6 +534,7 @@ def get_optimal_samples(
     posterior_transform: ScalarizedPosteriorTransform | None = None,
     objective: MCAcquisitionObjective | None = None,
     return_transformed: bool = False,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     """Draws sample paths from the posterior and maximizes the samples using GD.

@@ -551,7 +552,8 @@ def get_optimal_samples(
         objective: An MCAcquisitionObjective, used to negate the objective or otherwise
             transform sample outputs. Cannot be combined with `posterior_transform`.
         return_transformed: If True, return the transformed samples.
-
+        options: Options for generation of initial candidates, passed to
+            gen_batch_initial_conditions.

     Returns:
         The optimal input locations and corresponding outputs, x* and f*.

@@ -576,12 +578,20 @@ def get_optimal_samples(
         sample_transform = None

     paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
+    suggested_points = prune_inferior_points(
+        model=model,
+        X=model.train_inputs[0],
+        posterior_transform=posterior_transform,
+        objective=objective,
+    )
     optimal_inputs, optimal_outputs = optimize_posterior_samples(
         paths=paths,
         bounds=bounds,
         raw_samples=raw_samples,
         num_restarts=num_restarts,
         sample_transform=sample_transform,
         return_transformed=return_transformed,
+        suggested_points=suggested_points,
+        options=options,
     )
     return optimal_inputs, optimal_outputs
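
The notable addition above is the seeding step: prune_inferior_points filters the training inputs down to those with a non-negligible probability of being the posterior maximizer, and optimize_posterior_samples then evaluates perturbations of these survivors densely during its random phase. A standalone sketch of just that step, with an assumed toy model:

import torch
from botorch.acquisition.utils import prune_inferior_points
from botorch.models import SingleTaskGP

train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = -((train_X - 0.5) ** 2).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)

# subset of the training inputs that plausibly contains the best point
suggested_points = prune_inferior_points(model=model, X=model.train_inputs[0])
print(suggested_points.shape)  # n' x 2, with n' <= 20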

botorch/utils/sampling.py

+35 -10
@@ -999,10 +999,12 @@ def sparse_to_dense_constraints(
 def optimize_posterior_samples(
     paths: GenericDeterministicModel,
     bounds: Tensor,
-    raw_samples: int = 1024,
-    num_restarts: int = 20,
+    raw_samples: int = 2048,
+    num_restarts: int = 4,
     sample_transform: Callable[[Tensor], Tensor] | None = None,
     return_transformed: bool = False,
+    suggested_points: Tensor | None = None,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Cheaply maximizes posterior samples by random querying followed by
     gradient-based optimization using SciPy's L-BFGS-B routine.

@@ -1011,19 +1013,27 @@ def optimize_posterior_samples(
         paths: Random Fourier Feature-based sample paths from the GP
         bounds: The bounds on the search space.
         raw_samples: The number of samples with which to query the samples initially.
+            Raw samples are cheap to evaluate, so this should ideally be set much higher
+            than num_restarts.
         num_restarts: The number of points selected for gradient-based optimization.
+            Should be set low relative to the number of raw samples for time-efficiency.
         sample_transform: A callable transform of the sample outputs (e.g.
             MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
             negate the objective or otherwise transform the output.
         return_transformed: A boolean indicating whether to return the transformed
             or non-transformed samples.
+        suggested_points: Tensor of suggested input locations that are high-valued.
+            These are more densely evaluated during the sampling phase of optimization.
+        options: Options for generation of initial candidates, passed to
+            gen_batch_initial_conditions.

     Returns:
         A two-element tuple containing:
         - X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
         - f_opt: A `num_optima x [batch_size] x m`-dim, optionally
             `num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
     """
+    options = {} if options is None else options

     def path_func(x) -> Tensor:
         res = paths(x)

@@ -1032,21 +1042,35 @@ def path_func(x) -> Tensor:

         return res.squeeze(-1)

-    candidate_set = unnormalize(
-        SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
-        bounds=bounds,
-    )
     # queries all samples on all candidates - output shape
     # raw_samples * num_optima * num_models
+    frac_random = 1 if suggested_points is None else options.get("frac_random", 0.9)
+    candidate_set = draw_sobol_samples(
+        bounds=bounds, n=round(raw_samples * frac_random), q=1
+    ).squeeze(-2)
+    if frac_random < 1:
+        perturbed_suggestions = sample_truncated_normal_perturbations(
+            X=suggested_points,
+            n_discrete_points=round(raw_samples * (1 - frac_random)),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        candidate_set = torch.cat((candidate_set, perturbed_suggestions))
+
     candidate_queries = path_func(candidate_set)
-    argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
-    X_top_k = candidate_set[argtop_k, :]
+    idx = boltzmann_sample(
+        function_values=candidate_queries.unsqueeze(-1),
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+        replacement=False,
+    )
+    ics = candidate_set[idx, :]

     # to avoid circular import, the import occurs here
     from botorch.generation.gen import gen_candidates_scipy

     X_top_k, f_top_k = gen_candidates_scipy(
-        X_top_k,
+        ics,
         path_func,
         lower_bounds=bounds[0],
         upper_bounds=bounds[1],

@@ -1101,8 +1125,9 @@ def boltzmann_sample(
         eta *= temp_decrease
     weights = torch.exp(eta * norm_weights)

+    # squeeze in case of m = 1 (mono-output provided as batch_size x N x 1)
     return batched_multinomial(
-        weights=weights, num_samples=num_samples, replacement=replacement
+        weights=weights.squeeze(-1), num_samples=num_samples, replacement=replacement
     )
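
Two behavioral changes above are worth spelling out. First, when suggested_points are given, only a `frac_random` fraction of the raw candidates comes from Sobol sampling; the rest are truncated-normal perturbations of the suggestions. Second, the restart points for L-BFGS-B are no longer the deterministic top-k candidates but Boltzmann samples, trading a little exploitation for diversity. A self-contained sketch of that selection step (the normalization here is an assumption; the library's boltzmann_sample also anneals eta, which is omitted):

import torch

torch.manual_seed(0)
num_restarts, eta = 3, 2.0
function_values = torch.randn(12, dtype=torch.float64)  # stand-in for path values

# old behavior: always pick the k best raw candidates
top_k = torch.topk(function_values, num_restarts, dim=-1).indices

# new behavior: exponentiate normalized values, then sample without
# replacement, so strong-but-not-best candidates can become restart points
norm_weights = (function_values - function_values.mean()) / function_values.std()
weights = torch.exp(eta * norm_weights)
restart_idx = torch.multinomial(weights, num_samples=num_restarts, replacement=False)

print(top_k.tolist(), restart_idx.tolist())  # the sampled set varies with the seed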

test/acquisition/test_utils.py

+22 -7
@@ -33,6 +33,7 @@
     UnsupportedError,
 )
 from botorch.models import SingleTaskGP
+from botorch.utils.test_helpers import get_fully_bayesian_model, get_model
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from gpytorch.distributions import MultivariateNormal

@@ -413,17 +414,14 @@ def test_project_to_sample_points(self):


 class TestGetOptimalSamples(BotorchTestCase):
-    def test_get_optimal_samples(self):
-        dims = 3
-        dtype = torch.float64
+    def _test_get_optimal_samples_base(self, model):
+        dims = model.train_inputs[0].shape[1]
+        dtype = model.train_targets.dtype
+        batch_shape = model.batch_shape
         for_testing_speed_kwargs = {"raw_samples": 20, "num_restarts": 2}
         num_optima = 7
-        batch_shape = (3,)

         bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
-        X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
-        Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
-        model = SingleTaskGP(train_X=X, train_Y=Y)
         posterior_transform = ScalarizedPosteriorTransform(
             weights=torch.ones(1, dtype=dtype)
         )

@@ -438,6 +436,7 @@ def test_get_optimal_samples(self):
             num_optima=num_optima,
             **for_testing_speed_kwargs,
         )
+
         correct_X_shape = (num_optima,) + batch_shape + (dims,)
         correct_f_shape = (num_optima,) + batch_shape + (1,)
         self.assertEqual(X_opt_def.shape, correct_X_shape)

@@ -519,6 +518,22 @@ def test_get_optimal_samples(self):
             **for_testing_speed_kwargs,
         )

+    def test_optimal_samples(self):
+        dims = 3
+        dtype = torch.float64
+        X = torch.rand(4, dims, dtype=dtype)
+        Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
+        model = get_model(train_X=X, train_Y=Y)
+        self._test_get_optimal_samples_base(model)
+        fully_bayesian_model = get_fully_bayesian_model(
+            train_X=X,
+            train_Y=Y,
+            num_models=4,
+            standardize_model=True,
+            infer_noise=True,
+        )
+        self._test_get_optimal_samples_base(fully_bayesian_model)
+

 class TestPreferenceUtils(BotorchTestCase):
     def test_repeat_to_match_aug_dim(self):
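
The refactor splits the old monolithic test into a model-agnostic base that is now also run against a fully Bayesian ensemble model, whose batch dimension must carry through to the returned optima. Sketched outside the test harness (helper names and arguments are taken from the diff; the expected shape follows the test's assertions):

import torch
from botorch.acquisition.utils import get_optimal_samples
from botorch.utils.test_helpers import get_fully_bayesian_model

dims = 3
X = torch.rand(4, dims, dtype=torch.float64)
Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True)
model = get_fully_bayesian_model(
    train_X=X, train_Y=Y, num_models=4, standardize_model=True, infer_noise=True
)
bounds = torch.tensor([[0, 1]] * dims, dtype=torch.float64).T

X_opt, f_opt = get_optimal_samples(
    model=model, bounds=bounds, num_optima=7, raw_samples=20, num_restarts=2
)
# per the test: X_opt has shape (num_optima, *model.batch_shape, dims) = (7, 4, 3)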

test/utils/test_sampling.py

+13 -8
@@ -578,9 +578,13 @@ def test_optimize_posterior_samples(self):
         dims = 2
         dtype = torch.float64
         eps = 1e-4
-        for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
-        nums_optima = (1, 7)
-        batch_shapes = ((), (2,), (3, 2))
+        for_testing_speed_kwargs = {
+            "raw_samples": 64,
+            "num_restarts": 2,
+            "options": {"eta": 10},
+        }
+        nums_optima = (1, 5)
+        batch_shapes = ((), (3,))
         posterior_transforms = (
             None,
             ScalarizedPosteriorTransform(weights=-torch.ones(1, dtype=dtype)),

@@ -589,16 +593,19 @@ def test_optimize_posterior_samples(self):
             nums_optima, batch_shapes, posterior_transforms
         ):
             bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
-            X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
+            X = torch.rand(*batch_shape, 3, dims, dtype=dtype)
             Y = torch.pow(X - 0.5, 2).sum(dim=-1, keepdim=True)

             # having a noiseless model all but guarantees that the found optima
             # will be better than the observations
-            model = SingleTaskGP(X, Y, torch.full_like(Y, eps))
+            model = SingleTaskGP(
+                train_X=X, train_Y=Y, train_Yvar=torch.full_like(Y, eps)
+            )
             model.covar_module.lengthscale = 0.5
             paths = get_matheron_path_model(
                 model=model, sample_shape=torch.Size([num_optima])
             )
+
             X_opt, f_opt = optimize_posterior_samples(
                 paths=paths,
                 bounds=bounds,

@@ -616,8 +623,6 @@ def test_optimize_posterior_samples(self):
             self.assertTrue(torch.all(X_opt >= bounds[0]))
             self.assertTrue(torch.all(X_opt <= bounds[1]))

-            # Check that the all found optima are larger than the observations
-            # This is not 100% deterministic, but just about.
             Y_queries = paths(X)
             # this is when we negate, so the values should be smaller
             if posterior_transform:

@@ -642,7 +647,7 @@ def test_optimize_posterior_samples_multi_objective(self):
         dims = 2
         dtype = torch.float64
         eps = 1e-4
-        for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
+        for_testing_speed_kwargs = {"raw_samples": 64, "num_restarts": 2}
         num_optima = 5
         batch_shape = (3,)
648653
