Performance & runtime improvements to info-theoretic acquisition functions (2/N) - AcqOpt initializer #2751
@@ -468,6 +468,128 @@ def gen_batch_initial_conditions(
    return batch_initial_conditions


def gen_optimal_input_initial_conditions(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    num_restarts: int,
    raw_samples: int,
    fixed_features: dict[int, float] | None = None,
    options: dict[str, bool | float | int] | None = None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
):
    r"""Generate a batch of initial conditions for random-restart optimization of
    information-theoretic acquisition functions (PES & JES), where sampled
    optimizers of the posterior constitute good initial guesses for further
    optimization. A fraction of the initial samples (by default: 100%) is drawn
    as perturbations around `acq.optimal_inputs`. On average, this drastically
    decreases the runtime of acquisition function optimization and yields
    candidates with higher acquisition function values. See
    https://github.com/pytorch/botorch/pull/2751 for more info.

    Args:
        acq_function: The acquisition function to be optimized.
        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
        q: The number of candidates to consider.
        num_restarts: The number of starting points for multistart acquisition
            function optimization.
        raw_samples: The number of raw samples to consider in the initialization
            heuristic. Note: if `sample_around_best` is True (the default is
            False), then `2 * raw_samples` samples are used.
        fixed_features: A map `{feature_index: value}` for features that
            should be fixed to a particular value during generation.
        options: Options for initial condition generation. These contain all
            settings for the standard heuristic initialization from
            `gen_batch_initial_conditions`. In addition, they contain
            `frac_random` (the fraction of points drawn fully at random, as
            opposed to around the drawn optimizers from the posterior) and
            `sample_around_best_sigma`, which dictates the standard deviation
            both of the samples drawn around posterior maximizers and of the
            samples around the previous best (if enabled).
        inequality_constraints: A list of tuples (indices, coefficients, rhs),
            with each tuple encoding an inequality constraint of the form
            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
        equality_constraints: A list of tuples (indices, coefficients, rhs),
            with each tuple encoding an equality constraint of the form
            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.

    Returns:
        A `num_restarts x q x d` tensor of initial conditions.
    """
    options = options or {}
    device = bounds.device
    if not hasattr(acq_function, "optimal_inputs"):
        raise AttributeError(
            "gen_optimal_input_initial_conditions can only be used with "
            "an AcquisitionFunction that has an optimal_inputs attribute."
        )
    frac_random: float = options.get("frac_random", 0.0)
    if not 0 <= frac_random <= 1:
        raise ValueError(
            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
        )

    batch_limit = options.get("batch_limit")
    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
    num_random = round(raw_samples * frac_random)
    if num_random > 0:
        X_rnd = sample_q_batches_from_polytope(
            n=num_random,
            q=q,
            bounds=bounds,
            n_burnin=options.get("n_burnin", 10000),
            n_thinning=options.get("n_thinning", 32),
            equality_constraints=equality_constraints,
            inequality_constraints=inequality_constraints,
        )
        X = torch.cat((X, X_rnd))

    if num_random < raw_samples:
        X_perturbed = sample_points_around_best(
            acq_function=acq_function,
            n_discrete_points=q * (raw_samples - num_random),
            sigma=options.get("sample_around_best_sigma", 1e-2),
            bounds=bounds,
            best_X=suggestions,
        )
        X_perturbed = X_perturbed.view(
            raw_samples - num_random, q, bounds.shape[-1]
        ).cpu()
        X = torch.cat((X, X_perturbed))

Reviewer comment (on the `sample_points_around_best` call above): It's a bit nonintuitive that we do this even when …

Author reply: Possibly! My thought was, since it is not actually sampling around the incumbent but around the sampled optima, I could keep it and re-use its logic. I tried to mimic the KG logic for it, and that uses …

    if options.get("sample_around_best", False):
        X_best = sample_points_around_best(
            acq_function=acq_function,
            n_discrete_points=q * raw_samples,
            sigma=options.get("sample_around_best_sigma", 1e-2),
            bounds=bounds,
        )
        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
        X = torch.cat((X, X_best))

    X = fix_features(X, fixed_features=fixed_features).cpu()
    with torch.no_grad():
        if batch_limit is None:
            batch_limit = X.shape[0]
        # Evaluate the acquisition function on `X` using `batch_limit`
        # sized chunks.
        acq_vals = torch.cat(
            [
                acq_function(x_.to(device=device)).cpu()
                for x_ in X.split(split_size=batch_limit, dim=0)
            ],
            dim=0,
        )
    idx = boltzmann_sample(
        function_values=acq_vals,
        num_samples=num_restarts,
        eta=options.get("eta", 2.0),
    )
    return X[idx]
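
The new initializer is meant to be supplied to `optimize_acqf` through its `ic_generator` argument. Below is a minimal usage sketch, assuming botorch's existing `get_optimal_samples` helper to draw posterior maximizers for JES; the model, data, and option values are illustrative, not taken from this PR:

```python
import torch

from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.utils import get_optimal_samples
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.optim.initializers import gen_optimal_input_initial_conditions
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data and model (illustrative).
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
# Sampled posterior maximizers; these become `acq_function.optimal_inputs`.
optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model, bounds=bounds, num_optima=8
)
acqf = qJointEntropySearch(
    model=model, optimal_inputs=optimal_inputs, optimal_outputs=optimal_outputs
)
candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=128,
    ic_generator=gen_optimal_input_initial_conditions,
    # frac_random=0.25: a quarter of the raw samples fully at random,
    # the rest perturbed around the sampled optimizers.
    options={"frac_random": 0.25, "sample_around_best_sigma": 1e-2},
)
```

With `frac_random` left at its default of 0.0, all raw samples are drawn as perturbations around the sampled optimizers.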

def gen_one_shot_kg_initial_conditions(
    acq_function: qKnowledgeGradient,
    bounds: Tensor,
@@ -1136,6 +1258,7 @@ def sample_points_around_best(
    best_pct: float = 5.0,
    subset_sigma: float = 1e-1,
    prob_perturb: float | None = None,
    best_X: Tensor | None = None,
) -> Tensor | None:
    r"""Find best points and sample nearby points.

@@ -1154,60 +1277,62 @@ def sample_points_around_best(
        An optional `n_discrete_points x d`-dim tensor containing the
        sampled points. This is None if no baseline points are found.
    """
    if best_X is None:
        X = get_X_baseline(acq_function=acq_function)
        if X is None:
            return
        with torch.no_grad():
            try:
                posterior = acq_function.model.posterior(X)
            except AttributeError:
                warnings.warn(
                    "Failed to sample around previous best points.",
                    BotorchWarning,
                    stacklevel=3,
                )
                return
            mean = posterior.mean
            while mean.ndim > 2:
                # take average over batch dims
                mean = mean.mean(dim=0)
            try:
                f_pred = acq_function.objective(mean)
            # Some acquisition functions do not have an objective
            # and for some acquisition functions the objective is None
            except (AttributeError, TypeError):
                f_pred = mean
            if hasattr(acq_function, "maximize"):
                # make sure that the optimization direction is set properly
                if not acq_function.maximize:
                    f_pred = -f_pred
            try:
                # handle constraints for EHVI-based acquisition functions
                constraints = acq_function.constraints
                if constraints is not None:
                    neg_violation = -torch.stack(
                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
                    ).sum(dim=-1)
                    feas = neg_violation == 0
                    if feas.any():
                        f_pred[~feas] = float("-inf")
                    else:
                        # set objective equal to negative violation
                        f_pred = neg_violation
            except AttributeError:
                pass
            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
                # multi-objective
                # find pareto set
                is_pareto = is_non_dominated(f_pred)
                best_X = X[is_pareto]
            else:
                if f_pred.shape[-1] == 1:
                    f_pred = f_pred.squeeze(-1)
                n_best = max(1, round(X.shape[0] * best_pct / 100))
                # the view() is to ensure that best_idcs is not a scalar tensor
                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
                best_X = X[best_idcs]

    use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
    n_trunc_normal_points = (
        n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
    )
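
For intuition, the perturbation step that `sample_points_around_best` applies around `best_X` can be reproduced with botorch's existing `sample_truncated_normal_perturbations` helper. A minimal sketch, assuming the default `sample_around_best_sigma` of 1e-2 and illustrative points:

```python
import torch

from botorch.utils.sampling import sample_truncated_normal_perturbations

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
best_X = torch.tensor([[0.2, 0.8], [0.6, 0.4]], dtype=torch.double)
perturbed = sample_truncated_normal_perturbations(
    X=best_X,
    n_discrete_points=16,  # total points drawn across the rows of best_X
    sigma=1e-2,            # matches the `sample_around_best_sigma` default
    bounds=bounds,         # noise is truncated to stay within the bounds
)
print(perturbed.shape)  # torch.Size([16, 2])
```

Per `n_trunc_normal_points` above, when the dimension is 20 or larger (or `prob_perturb` is set), only half of the discrete points come from this truncated-normal step; the rest come from coordinate-masked perturbation sampling.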

Reviewer comment: This doesn't work if the model is a ModelListGP, where `train_inputs` is a list of tuples of tensors rather than a tuple of tensors: https://github.com/cornellius-gp/gpytorch/blob/b017b9c3fe4de526f7a2243ce12ce2305862c90b/gpytorch/models/model_list.py#L83-L86. Any thoughts on what we should do there?