Skip to content

Commit bf6efd7

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Update optimize_acqf_homotopy and refactor mixed optimizer dispatch (facebook#4864)
Summary: X-link: meta-pytorch/botorch#3165 see title. This refactors dispatch to mixed optimizers into a shared utility for botorch and ax. This also cleans up optimize_acqf_homotopy by removing fixed_features_list and leveraging the new dispatch util. This adds support for using optimize_acqf_mixed_alternating. Reviewed By: bletham Differential Revision: D91913317
1 parent cf654b1 commit bf6efd7

1 file changed

Lines changed: 11 additions & 19 deletions

File tree

ax/generators/torch/botorch_modular/acquisition.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from __future__ import annotations
1010

11-
import math
1211
import operator
1312
from collections.abc import Callable
1413
from functools import partial, reduce
@@ -47,7 +46,12 @@
4746
optimize_acqf_discrete_local_search,
4847
optimize_acqf_mixed,
4948
)
50-
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating
49+
from botorch.optim.optimize_mixed import (
50+
MAX_CARDINALITY_FOR_LOCAL_SEARCH,
51+
MAX_CHOICES_ENUMERATE,
52+
optimize_acqf_mixed_alternating,
53+
should_use_mixed_alternating_optimizer,
54+
)
5155
from botorch.optim.parameter_constraints import evaluate_feasibility
5256
from botorch.utils.constraints import get_outcome_constraint_transforms
5357
from pyre_extensions import none_throws
@@ -63,13 +67,6 @@
6367
logger: Logger = get_logger(__name__)
6468

6569

66-
# For fully discrete search spaces.
67-
MAX_CHOICES_ENUMERATE = 10_000
68-
MAX_CARDINALITY_FOR_LOCAL_SEARCH = 100
69-
# For mixed search spaces.
70-
ALTERNATING_OPTIMIZER_THRESHOLD = 10
71-
72-
7370
def determine_optimizer(
7471
search_space_digest: SearchSpaceDigest,
7572
acqf: AcquisitionFunction | None = None,
@@ -119,17 +116,12 @@ def determine_optimizer(
119116
else:
120117
optimizer = "optimize_acqf_discrete"
121118
else:
122-
n_combos = math.prod([len(v) for v in discrete_choices.values()])
123-
# If there are less than `ALTERNATING_OPTIMIZER_THRESHOLD` combinations of
124-
# discrete choices, we will use `optimize_acqf_mixed`, which enumerates all
125-
# discrete combinations and optimizes the continuous features with discrete
126-
# features being fixed. Otherwise, we will use
127-
# `optimize_acqf_mixed_alternating`, which alternates between
128-
# continuous and discrete optimization steps.
129-
if n_combos <= ALTERNATING_OPTIMIZER_THRESHOLD:
130-
optimizer = "optimize_acqf_mixed"
131-
else:
119+
# For mixed (not fully discrete) search spaces, use the shared utility
120+
# from BoTorch to determine whether to use mixed alternating optimizer.
121+
if should_use_mixed_alternating_optimizer(discrete_dims=discrete_choices):
132122
optimizer = "optimize_acqf_mixed_alternating"
123+
else:
124+
optimizer = "optimize_acqf_mixed"
133125
return optimizer
134126

135127

0 commit comments

Comments
 (0)