
[Bug] batch_initial_conditions shouldn't have to satisfy nonlinear_inequality_constraints #2624

Open
@slishak-PX

Description

🐛 Bug

When using nonlinear_inequality_constraints in optimize_acqf, you must set batch_initial_conditions, and these initial conditions (ICs) must satisfy the constraints. This seems unnecessary: SLSQP is capable of starting from an infeasible IC. If the only reason batch_initial_conditions must be set is to force the user to provide a feasible IC, then that requirement can be relaxed too.
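To illustrate that SLSQP can start from an infeasible point, here is a minimal sketch that uses SciPy directly, not BoTorch. It optimises the toy objective from the reproduction script itself (not the UCB acquisition function) under the same constraint. It assumes SciPy's `minimize(method="SLSQP")`, whose `"ineq"` constraints use the same `fun(x) >= 0` sign convention as BoTorch. Both an infeasible start (0.2) and a feasible one (0.33) are expected to end near the constrained maximiser x = 0.7.

```python
import numpy as np
from scipy.optimize import minimize


def objective(x):
    # Same toy objective as in the reproduction script.
    return (x[..., 0] - 0.5) ** 2 + x[..., 0]


def constraint(x):
    # Same toy constraint; feasible where -constraint(x) >= 0, i.e. 0.3 <= x <= 0.7.
    return (x[..., 0] - 0.5) ** 2 * 50 - 2


def solve_from(x0):
    """Maximise the objective with SLSQP, starting from the scalar x0."""
    return minimize(
        lambda x: -objective(x),  # SLSQP minimises, so negate to maximise
        x0=np.array([x0]),
        method="SLSQP",
        bounds=[(0.0, 1.0)],
        # SciPy "ineq" constraints are satisfied where fun(x) >= 0
        constraints=[{"type": "ineq", "fun": lambda x: -constraint(x)}],
    )


for x0 in (0.2, 0.33):  # 0.2 is infeasible, 0.33 is feasible
    res = solve_from(x0)
    print(f"x0={x0}: x*={res.x[0]:.4f}, success={res.success}")
```

Because the objective is increasing on [0, 1], the constrained maximiser is the upper edge of the feasible interval; SLSQP's linearised constraints pull the infeasible start back into that interval.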

I imagine the issue can also be solved by using DeterministicModel with an outcome constraint, but this does not work with analytic acquisition functions.

raise ValueError(
"`batch_initial_conditions` must satisfy the non-linear inequality "
"constraints."
)
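The check quoted above rejects initial conditions that violate a constraint, i.e. where a constraint callable returns a negative value under the `callable(x) >= 0` convention. The sketch below mimics that check and adds the relaxed behaviour this issue asks for (warn and continue instead of raising). The function `check_initial_conditions` and its `strict` flag are hypothetical, not BoTorch's implementation; it uses NumPy arrays of shape b x q x d, handles intra-point constraints only, and applies no tolerance (BoTorch's real check may differ).

```python
import warnings

import numpy as np


def check_initial_conditions(batch_initial_conditions, constraints, strict=True):
    """Hypothetical feasibility check for initial conditions.

    batch_initial_conditions: array of shape (b, q, d).
    constraints: callables, each feasible where callable(x) >= 0 for a
        single d-dimensional point x (intra-point constraints only).
    Returns a boolean array of shape (b,) marking fully feasible restarts.
    """
    b, q, _ = batch_initial_conditions.shape
    feasible = np.ones(b, dtype=bool)
    for con in constraints:
        for i in range(b):
            for j in range(q):
                if con(batch_initial_conditions[i, j]) < 0:
                    feasible[i] = False
    if not feasible.all():
        msg = (
            "`batch_initial_conditions` must satisfy the non-linear inequality "
            "constraints."
        )
        if strict:
            # Behaviour reported in this issue: a hard error.
            raise ValueError(msg)
        # Relaxed behaviour requested in this issue: warn and let SLSQP proceed.
        warnings.warn(msg + " Proceeding from infeasible starting points.")
    return feasible


# Usage: the constraint from the reproduction script, one infeasible (0.2)
# and one feasible (0.33) restart.
con = lambda x: -((x[..., 0] - 0.5) ** 2 * 50 - 2)
ics = np.array([[[0.2]], [[0.33]]])
print(check_initial_conditions(ics, [con], strict=False))  # warns; [False  True]
```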

To reproduce

Code snippet to reproduce

import torch
from botorch.acquisition import UpperConfidenceBound
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

def objective(x):
    return (x[..., 0] - 0.5) ** 2 + x[..., 0]


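def constraint(x):
    # Feasible where -constraint(x) >= 0, i.e. 0.3 <= x <= 0.7
    return (x[..., 0] - 0.5) ** 2 * 50 - 2


n_train = 64
device = torch.device("cpu")
# Use one dtype/device for training data, bounds and initial conditions
tkwargs = {"dtype": torch.float64, "device": device}

train_x = torch.rand(n_train, 1, **tkwargs)
train_y = objective(train_x)
con_y = constraint(train_x)

# Bounds of shape 2 x d: row 0 is the lower bound, row 1 the upper bound
bounds = torch.vstack([torch.zeros(1, 1, **tkwargs), torch.ones(1, 1, **tkwargs)])

model = SingleTaskGP(
    train_x,
    train_y[:, None],
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_mll(mll)

acqf = UpperConfidenceBound(model, beta=4)

# Infeasible starting point (outside 0.3 <= x <= 0.7), which triggers the
# ValueError. A feasible value such as 0.33 runs without error.
initial_condition = 0.2
candidates, value = optimize_acqf(
    acqf,
    bounds,
    q=1,
    num_restarts=1,
    raw_samples=1,
    nonlinear_inequality_constraints=[
        # (callable, is_intrapoint); feasible where callable(x) >= 0
        (lambda x: -constraint(x), True),
    ],
    # Shape b x q x d = 1 x 1 x 1
    batch_initial_conditions=torch.tensor([[[initial_condition]]], **tkwargs),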
)
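For reference, BoTorch treats a nonlinear inequality constraint as satisfied where the callable is non-negative. Writing c for the `constraint` function, the constraint passed above, -c(x) >= 0, therefore defines this feasible interval:

```latex
-c(x) = 2 - 50\,(x - 0.5)^2 \ge 0
\iff (x - 0.5)^2 \le 0.04
\iff |x - 0.5| \le 0.2
\iff 0.3 \le x \le 0.7
```

For example, x = 0.33 gives -c = 2 - 50 * 0.17^2 = 0.555 >= 0 (feasible), while x = 0.2 gives -c = 2 - 50 * 0.3^2 = -2.5 < 0 (infeasible), so only the latter is rejected by the check.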

Stack trace/error message

ValueError: `batch_initial_conditions` must satisfy the non-linear inequality constraints.

Expected Behavior

If the exception is commented out, the same candidate is found whether initial_condition is feasible or infeasible. This shows that the exception is preventing legitimate use cases: problems where a feasible region is hard to find and you want the optimiser to find it for you.

System information


  • BoTorch Version 0.12.0
  • GPyTorch Version 1.13
  • PyTorch Version 2.5.1+cu124
  • Computer OS: Linux

Additional context

NA

Metadata

  • Assignees: No one assigned
  • Labels: bug (Something isn't working)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests
