pass gen_candidates callable in optimize_acqf
Summary: see title. This will support using stochastic optimization (e.g., via `gen_candidates_torch`).

Differential Revision: D41629164

fbshipit-source-id: 5abc48b9bdf5ecc9269792b18e3bcdea99150d4a
sdaulton authored and facebook-github-bot committed Feb 4, 2023
1 parent 71e8db5 commit 2729eba
Showing 2 changed files with 11 additions and 4 deletions.
botorch/generation/gen.py (4 additions, 0 deletions)

@@ -37,6 +37,8 @@
 
 logger = _get_logger()
 
+TGenCandidates = Callable[[Tensor, AcquisitionFunction, Any], Tuple[Tensor, Tensor]]
+
 
 def gen_candidates_scipy(
     initial_conditions: Tensor,
@@ -49,6 +51,7 @@ def gen_candidates_scipy(
     options: Optional[Dict[str, Any]] = None,
     fixed_features: Optional[Dict[int, Optional[float]]] = None,
     timeout_sec: Optional[float] = None,
+    **kwargs,
 ) -> Tuple[Tensor, Tensor]:
     r"""Generate a set of candidates using `scipy.optimize.minimize`.
@@ -281,6 +284,7 @@ def gen_candidates_torch(
     callback: Optional[Callable[[int, Tensor, Tensor], NoReturn]] = None,
     fixed_features: Optional[Dict[int, Optional[float]]] = None,
     timeout_sec: Optional[float] = None,
+    **kwargs,
 ) -> Tuple[Tensor, Tensor]:
     r"""Generate a set of candidates using a `torch.optim` optimizer.
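
For reference, a minimal sketch of a custom generator conforming to the new `TGenCandidates` alias. The name `gen_candidates_sgd_sketch`, the projected-SGD loop, and the `lr`/`num_steps` options are illustrative assumptions, not part of this commit; only the signature shape (initial conditions, acquisition function, optional bounds, and a `**kwargs` catch-all for generator-specific options) mirrors the code above.

from typing import Any, Optional, Tuple

import torch
from botorch.acquisition import AcquisitionFunction
from torch import Tensor


def gen_candidates_sgd_sketch(
    initial_conditions: Tensor,
    acquisition_function: AcquisitionFunction,
    lower_bounds: Optional[Tensor] = None,
    upper_bounds: Optional[Tensor] = None,
    **kwargs: Any,  # absorbs options intended for other generators
) -> Tuple[Tensor, Tensor]:
    # Illustrative TGenCandidates implementation: a few steps of
    # projected stochastic gradient ascent on the acquisition value.
    X = initial_conditions.clone().requires_grad_(True)
    optimizer = torch.optim.SGD([X], lr=kwargs.get("lr", 0.01))
    for _ in range(kwargs.get("num_steps", 50)):
        optimizer.zero_grad()
        # Maximize the acquisition function by minimizing its negation.
        (-acquisition_function(X).sum()).backward()
        optimizer.step()
        with torch.no_grad():  # project back into the bounds
            if lower_bounds is not None:
                X.clamp_(min=lower_bounds)
            if upper_bounds is not None:
                X.clamp_(max=upper_bounds)
    with torch.no_grad():
        return X.detach(), acquisition_function(X)

The trailing `**kwargs` is the same reason it is added to `gen_candidates_scipy` and `gen_candidates_torch` above: `optimize_acqf` builds one shared keyword dict and passes it to whichever generator is supplied, so each generator must tolerate options meant for the others.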
botorch/optim/optimize.py (7 additions, 4 deletions)

@@ -23,7 +23,7 @@
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.exceptions import InputDataError, UnsupportedError
 from botorch.exceptions.warnings import OptimizationWarning
-from botorch.generation.gen import gen_candidates_scipy
+from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
 from botorch.logging import logger
 from botorch.optim.initializers import (
     gen_batch_initial_conditions,
@@ -64,6 +64,7 @@ def optimize_acqf(
     post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
     batch_initial_conditions: Optional[Tensor] = None,
     return_best_only: bool = True,
+    gen_candidates: TGenCandidates = gen_candidates_scipy,
     sequential: bool = False,
     **kwargs: Any,
 ) -> Tuple[Tensor, Tensor]:
@@ -103,6 +104,8 @@
             this if you do not want to use default initialization strategy.
         return_best_only: If False, outputs the solutions corresponding to all
             random restart initializations of the optimization.
+        gen_candidates: A callable for generating candidates given initial
+            conditions. Default: `gen_candidates_scipy`.
         sequential: If False, uses joint optimization, otherwise uses sequential
             optimization.
         kwargs: Additional keyword arguments.
@@ -273,7 +276,7 @@ def _optimize_batch_candidates(
         if timeout_sec is not None:
             timeout_sec = (timeout_sec - start_time) / len(batched_ics)
 
-        scipy_kws = {
+        gen_kws = {
             "acquisition_function": acq_function,
             "lower_bounds": None if bounds[0].isinf().all() else bounds[0],
             "upper_bounds": None if bounds[1].isinf().all() else bounds[1],
@@ -289,8 +292,8 @@
             # optimize using random restart optimization
             with warnings.catch_warnings(record=True) as ws:
                 warnings.simplefilter("always", category=OptimizationWarning)
-                batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
-                    initial_conditions=batched_ics_, **scipy_kws
+                batch_candidates_curr, batch_acq_values_curr = gen_candidates(
+                    initial_conditions=batched_ics_, **gen_kws
                 )
             opt_warnings += ws
             batch_candidates_list.append(batch_candidates_curr)
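
An end-to-end usage sketch of the new argument, passing the torch-based generator to `optimize_acqf` so the acquisition function is optimized with a `torch.optim` optimizer instead of the scipy default. The toy GP model and data here are assumptions for illustration, not from this commit.

import torch
from botorch.acquisition import qExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.generation.gen import gen_candidates_torch
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy problem: maximize an acquisition over a GP fit to random data
# on the unit square.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

candidates, acq_value = optimize_acqf(
    acq_function=qExpectedImprovement(model=model, best_f=train_Y.max()),
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=2,
    num_restarts=4,
    raw_samples=32,
    # New in this commit: any TGenCandidates callable can be plugged in;
    # the default remains gen_candidates_scipy.
    gen_candidates=gen_candidates_torch,
)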
