Skip to content

Commit 520aad7

Browse files
committed
Added ES optimization initializer
1 parent ae56adf commit 520aad7

File tree

3 files changed

+286
-106
lines changed

3 files changed

+286
-106
lines changed

botorch/optim/initializers.py

+194-106
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from botorch.acquisition import analytic, monte_carlo, multi_objective
2525
from botorch.acquisition.acquisition import AcquisitionFunction
2626
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
27+
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
2728
from botorch.acquisition.knowledge_gradient import (
2829
_get_value_function,
2930
qKnowledgeGradient,
@@ -471,6 +472,90 @@ def gen_batch_initial_conditions(
471472
return batch_initial_conditions
472473

473474

475+
def gen_optimal_input_initial_conditions(
476+
acq_function: AcquisitionFunction,
477+
bounds: Tensor,
478+
q: int,
479+
num_restarts: int,
480+
raw_samples: int,
481+
fixed_features: dict[int, float] | None = None,
482+
options: dict[str, bool | float | int] | None = None,
483+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
484+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
485+
):
486+
device = bounds.device
487+
if not hasattr(acq_function, "optimal_inputs"):
488+
raise AttributeError(
489+
"gen_optimal_input_initial_conditions can only be used with "
490+
"an AcquisitionFunction that has an optimal_inputs attribute."
491+
)
492+
frac_random: float = options.get("frac_random", 0.0)
493+
if not 0 <= frac_random <= 1:
494+
raise ValueError(
495+
f"frac_random must take on values in (0,1). Value: {frac_random}"
496+
)
497+
498+
batch_limit = options.get("batch_limit")
499+
num_optima = acq_function.optimal_inputs.shape[:-1].numel()
500+
suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
501+
X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
502+
num_random = round(raw_samples * frac_random)
503+
if num_random > 0:
504+
X_rnd = sample_q_batches_from_polytope(
505+
n=num_random,
506+
q=q,
507+
bounds=bounds,
508+
n_burnin=options.get("n_burnin", 10000),
509+
n_thinning=options.get("n_thinning", 32),
510+
equality_constraints=equality_constraints,
511+
inequality_constraints=inequality_constraints,
512+
)
513+
X = torch.cat((X, X_rnd))
514+
515+
if num_random < raw_samples:
516+
X_perturbed = sample_points_around_best(
517+
acq_function=acq_function,
518+
n_discrete_points=q * (raw_samples - num_random),
519+
sigma=options.get("sample_around_best_sigma", 1e-2),
520+
bounds=bounds,
521+
best_X=suggestions,
522+
)
523+
X_perturbed = X_perturbed.view(
524+
raw_samples - num_random, q, bounds.shape[-1]
525+
).cpu()
526+
X = torch.cat((X, X_perturbed))
527+
528+
if options.get("sample_around_best", False):
529+
X_best = sample_points_around_best(
530+
acq_function=acq_function,
531+
n_discrete_points=q * raw_samples,
532+
sigma=options.get("sample_around_best_sigma", 1e-3),
533+
bounds=bounds,
534+
)
535+
X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
536+
X = torch.cat((X, X_best))
537+
538+
with torch.no_grad():
539+
if batch_limit is None:
540+
batch_limit = X.shape[0]
541+
# Evaluate the acquisition function on `X_rnd` using `batch_limit`
542+
# sized chunks.
543+
acq_vals = torch.cat(
544+
[
545+
acq_function(x_.to(device=device)).cpu()
546+
for x_ in X.split(split_size=batch_limit, dim=0)
547+
],
548+
dim=0,
549+
)
550+
idx = boltzmann_sample(
551+
function_values=acq_vals,
552+
num_samples=num_restarts,
553+
eta=options.get("eta", 2.0),
554+
)
555+
# set the respective initial conditions to the sampled optimizers
556+
return X[idx]
557+
558+
474559
def gen_one_shot_kg_initial_conditions(
475560
acq_function: qKnowledgeGradient,
476561
bounds: Tensor,
@@ -605,59 +690,59 @@ def gen_one_shot_hvkg_initial_conditions(
605690
) -> Tensor | None:
606691
r"""Generate a batch of smart initializations for qHypervolumeKnowledgeGradient.
607692
608-
This function generates initial conditions for optimizing one-shot HVKG using
609-
the hypervolume maximizing set (of fixed size) under the posterior mean.
610-
Intutively, the hypervolume maximizing set of the fantasized posterior mean
611-
will often be close to a hypervolume maximizing set under the current posterior
612-
mean. This function uses that fact to generate the initial conditions
613-
for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
614-
options) of the restarts are generated by learning the hypervolume maximizing sets
615-
under the current posterior mean, where each hypervolume maximizing set is
616-
obtained from maximizing the hypervolume from a different starting point. Given
617-
a hypervolume maximizing set, the `q` candidate points are selected using to the
618-
standard initialization strategy in `gen_batch_initial_conditions`, with the fixed
619-
hypervolume maximizing set. The remaining `frac_random` restarts fantasy points
620-
as well as all `q` candidate points are chosen according to the standard
621-
initialization strategy in `gen_batch_initial_conditions`.
622-
623-
Args:
624-
acq_function: The qKnowledgeGradient instance to be optimized.
625-
bounds: A `2 x d` tensor of lower and upper bounds for each column of
626-
task features.
627-
q: The number of candidates to consider.
628-
num_restarts: The number of starting points for multistart acquisition
629-
function optimization.
630-
raw_samples: The number of raw samples to consider in the initialization
631-
heuristic.
632-
fixed_features: A map `{feature_index: value}` for features that
633-
should be fixed to a particular value during generation.
634-
options: Options for initial condition generation. These contain all
635-
settings for the standard heuristic initialization from
636-
`gen_batch_initial_conditions`. In addition, they contain
637-
`frac_random` (the fraction of fully random fantasy points),
638-
`num_inner_restarts` and `raw_inner_samples` (the number of random
639-
restarts and raw samples for solving the posterior objective
640-
maximization problem, respectively) and `eta` (temperature parameter
641-
for sampling heuristic from posterior objective maximizers).
642-
inequality constraints: A list of tuples (indices, coefficients, rhs),
643-
with each tuple encoding an inequality constraint of the form
644-
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
645-
equality constraints: A list of tuples (indices, coefficients, rhs),
646-
with each tuple encoding an inequality constraint of the form
647-
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
648-
649-
Returns:
650-
A `num_restarts x q' x d` tensor that can be used as initial conditions
651-
for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
652-
of points (candidate points plus fantasy points).
653-
654-
Example:
655-
>>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
656-
>>> bounds = torch.tensor([[0., 0.], [1., 1.]])
657-
>>> Xinit = gen_one_shot_hvkg_initial_conditions(
658-
>>> qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
659-
>>> options={"frac_random": 0.25},
660-
>>> )
693+
This function generates initial conditions for optimizing one-shot HVKG using
694+
the hypervolume maximizing set (of fixed size) under the posterior mean.
695+
Intutively, the hypervolume maximizing set of the fantasized posterior mean
696+
will often be close to a hypervolume maximizing set under the current posterior
697+
mean. This function uses that fact to generate the initial conditions
698+
for the fantasy points. Specifically, a fraction of `1 - frac_random` (see
699+
options) of the restarts are generated by learning the hypervolume maximizing sets
700+
under the current posterior mean, where each hypervolume maximizing set is
701+
obtained from maximizing the hypervolume from a different starting point. Given
702+
a hypervolume maximizing set, the `q` candidate points are selected using to the
703+
standard initialization strategy in `gen_batch_initial_conditions`, with the fixed
704+
hypervolume maximizing set. The remaining `frac_random` restarts fantasy points
705+
as well as all `q` candidate points are chosen according to the standard
706+
initialization strategy in `gen_batch_initial_conditions`.
707+
708+
Args:
709+
acq_function: The qKnowledgeGradient instance to be optimized.
710+
bounds: A `2 x d` tensor of lower and upper bounds for each column of
711+
task features.
712+
q: The number of candidates to consider.
713+
num_restarts: The number of starting points for multistart acquisition
714+
function optimization.
715+
raw_samples: The number of raw samples to consider in the initialization
716+
heuristic.
717+
fixed_features: A map `{feature_index: value}` for features that
718+
should be fixed to a particular value during generation.
719+
options: Options for initial condition generation. These contain all
720+
settings for the standard heuristic initialization from
721+
`gen_batch_initial_conditions`. In addition, they contain
722+
`frac_random` (the fraction of fully random fantasy points),
723+
`num_inner_restarts` and `raw_inner_samples` (the number of random
724+
restarts and raw samples for solving the posterior objective
725+
maximization problem, respectively) and `eta` (temperature parameter
726+
for sampling heuristic from posterior objective maximizers).
727+
inequality constraints: A list of tuples (indices, coefficients, rhs),
728+
with each tuple encoding an inequality constraint of the form
729+
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
730+
equality constraints: A list of tuples (indices, coefficients, rhs),
731+
with each tuple encoding an inequality constraint of the form
732+
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
733+
734+
Returns:
735+
A `num_restarts x q' x d` tensor that can be used as initial conditions
736+
for `optimize_acqf()`. Here `q' = q + num_fantasies` is the total number
737+
of points (candidate points plus fantasy points).
738+
739+
gen_batch_initial_conditions Example:
740+
>>> qHVKG = qHypervolumeKnowledgeGradient(model, ref_point)
741+
>>> bounds = torch.tensor([[0., 0.], [1., 1.]])
742+
>>> Xinit = gen_one_shot_hvkg_initial_conditions(
743+
>>> qHVKG, bounds, q=3, num_restarts=10, raw_samples=512,
744+
>>> options={"frac_random": 0.25},
745+
>>> )
661746
"""
662747
from botorch.optim.optimize import optimize_acqf
663748

@@ -1139,6 +1224,7 @@ def sample_points_around_best(
11391224
best_pct: float = 5.0,
11401225
subset_sigma: float = 1e-1,
11411226
prob_perturb: float | None = None,
1227+
best_X: Tensor | None = None,
11421228
) -> Tensor | None:
11431229
r"""Find best points and sample nearby points.
11441230
@@ -1157,60 +1243,62 @@ def sample_points_around_best(
11571243
An optional `n_discrete_points x d`-dim tensor containing the
11581244
sampled points. This is None if no baseline points are found.
11591245
"""
1160-
X = get_X_baseline(acq_function=acq_function)
1161-
if X is None:
1162-
return
1163-
with torch.no_grad():
1164-
try:
1165-
posterior = acq_function.model.posterior(X)
1166-
except AttributeError:
1167-
warnings.warn(
1168-
"Failed to sample around previous best points.",
1169-
BotorchWarning,
1170-
stacklevel=3,
1171-
)
1246+
if best_X is None:
1247+
X = get_X_baseline(acq_function=acq_function)
1248+
if X is None:
11721249
return
1173-
mean = posterior.mean
1174-
while mean.ndim > 2:
1175-
# take average over batch dims
1176-
mean = mean.mean(dim=0)
1177-
try:
1178-
f_pred = acq_function.objective(mean)
1179-
# Some acquisition functions do not have an objective
1180-
# and for some acquisition functions the objective is None
1181-
except (AttributeError, TypeError):
1182-
f_pred = mean
1183-
if hasattr(acq_function, "maximize"):
1184-
# make sure that the optimiztaion direction is set properly
1185-
if not acq_function.maximize:
1186-
f_pred = -f_pred
1187-
try:
1188-
# handle constraints for EHVI-based acquisition functions
1189-
constraints = acq_function.constraints
1190-
if constraints is not None:
1191-
neg_violation = -torch.stack(
1192-
[c(mean).clamp_min(0.0) for c in constraints], dim=-1
1193-
).sum(dim=-1)
1194-
feas = neg_violation == 0
1195-
if feas.any():
1196-
f_pred[~feas] = float("-inf")
1197-
else:
1198-
# set objective equal to negative violation
1199-
f_pred = neg_violation
1200-
except AttributeError:
1201-
pass
1202-
if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
1203-
# multi-objective
1204-
# find pareto set
1205-
is_pareto = is_non_dominated(f_pred)
1206-
best_X = X[is_pareto]
1207-
else:
1208-
if f_pred.shape[-1] == 1:
1209-
f_pred = f_pred.squeeze(-1)
1210-
n_best = max(1, round(X.shape[0] * best_pct / 100))
1211-
# the view() is to ensure that best_idcs is not a scalar tensor
1212-
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
1213-
best_X = X[best_idcs]
1250+
with torch.no_grad():
1251+
try:
1252+
posterior = acq_function.model.posterior(X)
1253+
except AttributeError:
1254+
warnings.warn(
1255+
"Failed to sample around previous best points.",
1256+
BotorchWarning,
1257+
stacklevel=3,
1258+
)
1259+
return
1260+
mean = posterior.mean
1261+
while mean.ndim > 2:
1262+
# take average over batch dims
1263+
mean = mean.mean(dim=0)
1264+
try:
1265+
f_pred = acq_function.objective(mean)
1266+
# Some acquisition functions do not have an objective
1267+
# and for some acquisition functions the objective is None
1268+
except (AttributeError, TypeError):
1269+
f_pred = mean
1270+
if hasattr(acq_function, "maximize"):
1271+
# make sure that the optimiztaion direction is set properly
1272+
if not acq_function.maximize:
1273+
f_pred = -f_pred
1274+
try:
1275+
# handle constraints for EHVI-based acquisition functions
1276+
constraints = acq_function.constraints
1277+
if constraints is not None:
1278+
neg_violation = -torch.stack(
1279+
[c(mean).clamp_min(0.0) for c in constraints], dim=-1
1280+
).sum(dim=-1)
1281+
feas = neg_violation == 0
1282+
if feas.any():
1283+
f_pred[~feas] = float("-inf")
1284+
else:
1285+
# set objective equal to negative violation
1286+
f_pred = neg_violation
1287+
except AttributeError:
1288+
pass
1289+
if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
1290+
# multi-objective
1291+
# find pareto set
1292+
is_pareto = is_non_dominated(f_pred)
1293+
best_X = X[is_pareto]
1294+
else:
1295+
if f_pred.shape[-1] == 1:
1296+
f_pred = f_pred.squeeze(-1)
1297+
n_best = max(1, round(X.shape[0] * best_pct / 100))
1298+
# the view() is to ensure that best_idcs is not a scalar tensor
1299+
best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
1300+
best_X = X[best_idcs]
1301+
12141302
use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
12151303
n_trunc_normal_points = (
12161304
n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points

botorch/optim/optimize.py

+7
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
AcquisitionFunction,
2121
OneShotAcquisitionFunction,
2222
)
23+
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
2324
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
2425
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
2526
qHypervolumeKnowledgeGradient,
2627
)
28+
from botorch.acquisition.predictive_entropy_search import qPredictiveEntropySearch
2729
from botorch.exceptions import InputDataError, UnsupportedError
2830
from botorch.exceptions.errors import CandidateGenerationError
2931
from botorch.exceptions.warnings import OptimizationWarning
@@ -33,6 +35,7 @@
3335
gen_batch_initial_conditions,
3436
gen_one_shot_hvkg_initial_conditions,
3537
gen_one_shot_kg_initial_conditions,
38+
gen_optimal_input_initial_conditions,
3639
TGenInitialConditions,
3740
)
3841
from botorch.optim.stopping import ExpMAStoppingCriterion
@@ -174,6 +177,10 @@ def get_ic_generator(self) -> TGenInitialConditions:
174177
return gen_one_shot_kg_initial_conditions
175178
elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
176179
return gen_one_shot_hvkg_initial_conditions
180+
elif isinstance(
181+
self.acq_function, (qJointEntropySearch, qPredictiveEntropySearch)
182+
):
183+
return gen_optimal_input_initial_conditions
177184
return gen_batch_initial_conditions
178185

179186

0 commit comments

Comments
 (0)