[Bug] `batch_initial_conditions` shouldn't have to satisfy `nonlinear_inequality_constraints` #2624
Description
🐛 Bug
When using `nonlinear_inequality_constraints` in `optimize_acqf`, you need to set `batch_initial_conditions`, and these initial conditions (ICs) must satisfy the constraints. This seems unnecessary: SLSQP is capable of starting from an infeasible IC. If the only reason `batch_initial_conditions` must be set is to force the user to provide a feasible IC, then that requirement can be relaxed too.
I imagine the issue could also be worked around by using `DeterministicModel` with an outcome constraint, but that does not work with analytic acquisition functions.
Relevant check: `botorch/botorch/optim/parameter_constraints.py`, lines 593 to 596 in 92d73e4.
To reproduce
**Code snippet to reproduce**
import torch
from botorch.acquisition import UpperConfidenceBound
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood


def objective(x):
    return (x[..., 0] - 0.5) ** 2 + x[..., 0]


def constraint(x):
    # constraint(x) <= 0 exactly when 0.3 <= x <= 0.7
    return (x[..., 0] - 0.5) ** 2 * 50 - 2


n_train = 64
device = torch.device("cpu")
dtype = torch.float64  # use one dtype everywhere (model, bounds, ICs)
train_x = torch.rand(n_train, 1, dtype=dtype, device=device)
train_y = objective(train_x)
con_y = constraint(train_x)
bounds = torch.vstack([
    torch.zeros(1, 1, dtype=dtype, device=device),
    torch.ones(1, 1, dtype=dtype, device=device),
])
model = SingleTaskGP(
    train_x,
    train_y[:, None],
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_mll(mll)
acqf = UpperConfidenceBound(model, beta=4)

# 0.33 lies inside the feasible interval [0.3, 0.7]; an infeasible value
# such as 0.1 triggers the ValueError below.
initial_condition = 0.33
candidates, value = optimize_acqf(
    acqf,
    bounds,
    q=1,
    num_restarts=1,
    raw_samples=1,
    nonlinear_inequality_constraints=[
        # BoTorch treats a constraint as satisfied when the callable is >= 0
        (lambda x: -constraint(x), True),
    ],
    batch_initial_conditions=torch.tensor(
        [[[initial_condition]]], dtype=dtype, device=device
    ),
)
**Stack trace/error message**
ValueError: `batch_initial_conditions` must satisfy the non-linear inequality constraints.
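For reference, a minimal pure-Python sketch of the kind of check that produces this error, assuming it amounts to requiring every constraint callable to be non-negative at every IC. The helper `check_initial_conditions` and its `strict` flag are hypothetical, not BoTorch's implementation; `strict=False` shows one possible relaxation (warn instead of raise).

```python
import warnings
from typing import Callable, Sequence

Point = Sequence[float]


def constraint(x: Point) -> float:
    # Same constraint as the reproduction script, for a 1-D point x = [x0].
    return (x[0] - 0.5) ** 2 * 50 - 2


def check_initial_conditions(
    ics: Sequence[Point],
    constraints: Sequence[Callable[[Point], float]],
    strict: bool = True,
) -> bool:
    """Return True if every IC satisfies every constraint c(x) >= 0.

    strict=True raises, like the current behaviour; strict=False is a
    hypothetical relaxation that only warns and lets the optimizer proceed.
    """
    feasible = all(c(x) >= 0 for x in ics for c in constraints)
    if not feasible:
        msg = ("`batch_initial_conditions` must satisfy the non-linear "
               "inequality constraints.")
        if strict:
            raise ValueError(msg)
        warnings.warn(msg)
    return feasible


cons = [lambda x: -constraint(x)]
check_initial_conditions([[0.33]], cons)               # passes: 0.33 is in [0.3, 0.7]
check_initial_conditions([[0.1]], cons, strict=False)  # warns and returns False
```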
Expected Behavior
If the exception is commented out, the same candidate is found whether `initial_condition` is feasible or infeasible. In this case, the exception only blocks use cases where a feasible region is hard to find and you want the optimiser to find it for you.
System information
- BoTorch Version 0.12.0
- GPyTorch Version 1.13
- PyTorch Version 2.5.1+cu124
- Computer OS: Linux
Additional context
NA