Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate constraints in optimize_acqf #1231

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 142 additions & 6 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,29 @@

from __future__ import annotations

import warnings

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from botorch.acquisition.acquisition import (
AcquisitionFunction,
OneShotAcquisitionFunction,
)
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.exceptions import InputDataError, UnsupportedError
from botorch.exceptions import InputDataError, OptimizationWarning, UnsupportedError
from botorch.generation.gen import gen_candidates_scipy
from botorch.logging import logger
from botorch.optim.initializers import (
gen_batch_initial_conditions,
gen_one_shot_kg_initial_conditions,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
from scipy.optimize import linprog
from torch import Tensor


INIT_OPTION_KEYS = {
# set of options for initialization that we should
# not pass to scipy.optimize.minimize to avoid
Expand Down Expand Up @@ -62,6 +67,7 @@ def optimize_acqf(
batch_initial_conditions: Optional[Tensor] = None,
return_best_only: bool = True,
sequential: bool = False,
validate_constraints: bool = True,
**kwargs: Any,
) -> Tuple[Tensor, Tensor]:
r"""Generate a set of candidates via multi-start optimization.
Expand All @@ -75,10 +81,10 @@ def optimize_acqf(
raw_samples: The number of samples for initialization. This is required
if `batch_initial_conditions` is not specified.
options: Options for candidate generation.
inequality constraints: A list of tuples (indices, coefficients, rhs),
inequality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`
equality constraints: A list of tuples (indices, coefficients, rhs),
equality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an equality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`
nonlinear_inequality_constraints: A list of callables with that represent
Expand All @@ -100,6 +106,8 @@ def optimize_acqf(
random restart initializations of the optimization.
sequential: If False, uses joint optimization, otherwise uses sequential
optimization.
validate_constraints: If True, validate that the constraint set is
non-empty and bounded by solving a Linear Program.
kwargs: Additional keyword arguments.

Returns:
Expand All @@ -125,9 +133,11 @@ def optimize_acqf(
>>> qEI, bounds, 3, 15, 256, sequential=True
>>> )
"""
if not (bounds.ndim == 2 and bounds.shape[0] == 2):
raise ValueError(
f"bounds should be a `2 x d` tensor, current shape: {list(bounds.shape)}."
if validate_constraints:
_validate_constraints(
bounds=bounds,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)

if sequential and q > 1:
Expand Down Expand Up @@ -158,6 +168,7 @@ def optimize_acqf(
batch_initial_conditions=None,
return_best_only=True,
sequential=False,
validate_constraints=False,
)
candidate_list.append(candidate)
acq_value_list.append(acq_value)
Expand Down Expand Up @@ -267,6 +278,7 @@ def optimize_acqf_cyclic(
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
batch_initial_conditions: Optional[Tensor] = None,
cyclic_options: Optional[Dict[str, Union[bool, float, int, str]]] = None,
validate_constraints: bool = True,
) -> Tuple[Tensor, Tensor]:
r"""Generate a set of `q` candidates via cyclic optimization.

Expand Down Expand Up @@ -294,6 +306,8 @@ def optimize_acqf_cyclic(
If no initial conditions are provided, the default initialization will
be used.
cyclic_options: Options for stopping criterion for outer cyclic optimization.
validate_constraints: If True, validate that the constraint set is
non-empty and bounded by solving a Linear Program.

Returns:
A two-element tuple containing
Expand Down Expand Up @@ -328,6 +342,7 @@ def optimize_acqf_cyclic(
batch_initial_conditions=batch_initial_conditions,
return_best_only=True,
sequential=True,
validate_constraints=validate_constraints,
)
if q > 1:
cyclic_options = cyclic_options or {}
Expand Down Expand Up @@ -358,6 +373,7 @@ def optimize_acqf_cyclic(
batch_initial_conditions=candidates[i].unsqueeze(0),
return_best_only=True,
sequential=True,
validate_constraints=False,
)
candidates[i] = candidate_i
acq_vals[i] = acq_val_i
Expand All @@ -377,6 +393,7 @@ def optimize_acqf_list(
equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
fixed_features: Optional[Dict[int, float]] = None,
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
validate_constraints: bool = True,
) -> Tuple[Tensor, Tensor]:
r"""Generate a list of candidates from a list of acquisition functions.

Expand All @@ -402,6 +419,8 @@ def optimize_acqf_list(
post_processing_func: A function that post-processes an optimization
result appropriately (i.e., according to `round-trip`
transformations).
validate_constraints: If True, validate that the constraint set is
non-empty and bounded by solving a Linear Program.

Returns:
A two-element tuple containing
Expand All @@ -413,6 +432,13 @@ def optimize_acqf_list(
"""
if not acq_function_list:
raise ValueError("acq_function_list must be non-empty.")
if validate_constraints:
_validate_constraints(
bounds=bounds,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)

candidate_list, acq_value_list = [], []
candidates = torch.tensor([], device=bounds.device, dtype=bounds.dtype)
base_X_pending = acq_function_list[0].X_pending
Expand All @@ -436,6 +462,7 @@ def optimize_acqf_list(
post_processing_func=post_processing_func,
return_best_only=True,
sequential=False,
validate_constraints=False,
)
candidate_list.append(candidate)
acq_value_list.append(acq_value)
Expand All @@ -455,6 +482,7 @@ def optimize_acqf_mixed(
equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
batch_initial_conditions: Optional[Tensor] = None,
validate_constraints: bool = True,
**kwargs: Any,
) -> Tuple[Tensor, Tensor]:
r"""Optimize over a list of fixed_features and returns the best solution.
Expand Down Expand Up @@ -485,6 +513,8 @@ def optimize_acqf_mixed(
transformations).
batch_initial_conditions: A tensor to specify the initial conditions. Set
this if you do not want to use default initialization strategy.
validate_constraints: If True, validate that the constraint set is
non-empty and bounded by solving a Linear Program.

Returns:
A two-element tuple containing
Expand All @@ -502,6 +532,12 @@ def optimize_acqf_mixed(
"are currently not supported when `q > 1`. This is needed to "
"compute the joint acquisition value."
)
if validate_constraints:
_validate_constraints(
bounds=bounds,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)

if q == 1:
ff_candidate_list, ff_acq_value_list = [], []
Expand All @@ -519,6 +555,7 @@ def optimize_acqf_mixed(
post_processing_func=post_processing_func,
batch_initial_conditions=batch_initial_conditions,
return_best_only=True,
validate_constraints=False,
)
ff_candidate_list.append(candidate)
ff_acq_value_list.append(acq_value)
Expand Down Expand Up @@ -707,6 +744,105 @@ def _gen_batch_initial_conditions_local_search(
raise RuntimeError(f"Failed to generate at least {min_points} initial conditions")


def _validate_constraints(
bounds: Tensor,
inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
) -> None:
r"""Validate constraints for acquisition function optimization.

Checks that the constraints define a bounded, non-empty polytope.

Args:
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
If there are no box constraints, bounds should be an empty `0 x d`-dim
tensor.
inequality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`
equality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`
"""
# We solve the following Linear Program to ensure that he constraint set
# is non-empty and bounded:
#
# max_x |x|_1
# s.t. bounds(x)
# inequality_constraints(x)
# equality_constraints(x)
#
# To do this we can introduce auxiliary variables s and solve the
# following standard formulation:
#
# min_(x, s) - sum_i(s_i)
# s.t. -x <= s <= x
Comment on lines +778 to +779
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is correct - we're trying to solve

max \sum_i |x_i|

and the correct transformation would be

max \sum_i s_i
s_i <= |x_i|

The constraint |x_i| >= s_i is equal to (x_i >= s_i) OR (x_i <= -s_i) and is not convex and so won't be representable as an LP.

The constraint given here, -x _i<= s_i <= x_i is equivalent to |s_i| <= x, which is convex but is not correct for the problem we're trying to solve. In particular it will only provide the correct answer for positive values of x_i; if the inequality and equality constraints are unbounded towards -Inf but has volume in the positive orthant, this LP will not detect the unboundedness.

# bounds(x)
# inequality_constraints(x)
# equality_constraints(x)
#
if bounds.numel() == 0:
if inequality_constraints is None:
raise UnsupportedError(
"Must provide either `bounds` or `inequality_constraints` (or both)."
)
elif not (bounds.ndim == 2 and bounds.shape[0] == 2):
raise ValueError(
f"bounds should be a `2 x d` tensor, current shape: {tuple(bounds.shape)}."
)
d = bounds.shape[-1]
bounds_lp, A_ub, b_ub, A_eq, b_eq = None, None, None, None, None
# The first `d` variables are `x`, the last `d` are the auxiliary `s`
if bounds.numel() > 0:
# `s` is unbounded
bounds_lp = [tuple(b_i) for b_i in bounds.t()] + [(None, None)] * d
# Encode the constraint `-x <= s <= x`
A_ub = np.zeros((2 * d, 2 * d))
b_ub = np.zeros(2 * d)
A_ub[:d, :d] = -1.0
A_ub[:d, d : 2 * d] = -1.0
A_ub[d : 2 * d, :d] = -1.0
A_ub[d : 2 * d, d : 2 * d] = 1.0
# Convet and add additional inequality constraints if present
if inequality_constraints is not None:
A_ineq = np.zeros((len(inequality_constraints), 2 * d))
b_ineq = np.zeros(len(inequality_constraints))
for i, (indices, coefficients, rhs) in enumerate(inequality_constraints):
A_ineq[i, indices] = -coefficients
b_ineq[i] = -rhs
A_ub = np.concatenate((A_ub, A_ineq))
b_ub = np.concatenate((b_ub, b_ineq))
# Convert equality constraints if present
if equality_constraints is not None:
A_eq = np.zeros((len(equality_constraints), 2 * d))
b_eq = np.zeros(len(equality_constraints))
for i, (indices, coefficients, rhs) in enumerate(equality_constraints):
A_eq[i, indices] = coefficients
b_eq[i] = rhs
# Objective is `- sum_i s_i` (note: the `s_i` are guaranteed to be positive)
c = np.concatenate((np.zeros(d), -np.ones(d)))
# Solve the problem
result = linprog(
c=c,
bounds=bounds_lp,
A_ub=A_ub,
b_ub=b_ub,
A_eq=A_eq,
b_eq=b_eq,
)
# Check what's going on if unsuccessful
if not result.success:
if result.status == 2:
raise ValueError("Feasible set non-empty. Check your constraints.")
if result.status == 3:
raise ValueError("Feasible set unbounded.")
warnings.warn(
"Ran into issues when checking for boundedness of feasible set. "
f"Optimizer message: {result.message}.",
OptimizationWarning,
)


def optimize_acqf_discrete_local_search(
acq_function: AcquisitionFunction,
discrete_choices: List[Tensor],
Expand Down
Loading