|
8 | 8 |
|
9 | 9 | from __future__ import annotations |
10 | 10 |
|
11 | | -import math |
12 | 11 | import operator |
13 | 12 | from collections.abc import Callable |
14 | 13 | from functools import partial, reduce |
|
47 | 46 | optimize_acqf_discrete_local_search, |
48 | 47 | optimize_acqf_mixed, |
49 | 48 | ) |
50 | | -from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating |
| 49 | +from botorch.optim.optimize_mixed import ( |
| 50 | + MAX_CARDINALITY_FOR_LOCAL_SEARCH, |
| 51 | + MAX_CHOICES_ENUMERATE, |
| 52 | + optimize_acqf_mixed_alternating, |
| 53 | + should_use_mixed_alternating_optimizer, |
| 54 | +) |
51 | 55 | from botorch.optim.parameter_constraints import evaluate_feasibility |
52 | 56 | from botorch.utils.constraints import get_outcome_constraint_transforms |
53 | 57 | from pyre_extensions import none_throws |
|
63 | 67 | logger: Logger = get_logger(__name__) |
64 | 68 |
|
65 | 69 |
|
66 | | -# For fully discrete search spaces. |
67 | | -MAX_CHOICES_ENUMERATE = 10_000 |
68 | | -MAX_CARDINALITY_FOR_LOCAL_SEARCH = 100 |
69 | | -# For mixed search spaces. |
70 | | -ALTERNATING_OPTIMIZER_THRESHOLD = 10 |
71 | | - |
72 | | - |
73 | 70 | def determine_optimizer( |
74 | 71 | search_space_digest: SearchSpaceDigest, |
75 | 72 | acqf: AcquisitionFunction | None = None, |
@@ -119,17 +116,12 @@ def determine_optimizer( |
119 | 116 | else: |
120 | 117 | optimizer = "optimize_acqf_discrete" |
121 | 118 | else: |
122 | | - n_combos = math.prod([len(v) for v in discrete_choices.values()]) |
123 | | - # If there are less than `ALTERNATING_OPTIMIZER_THRESHOLD` combinations of |
124 | | - # discrete choices, we will use `optimize_acqf_mixed`, which enumerates all |
125 | | - # discrete combinations and optimizes the continuous features with discrete |
126 | | - # features being fixed. Otherwise, we will use |
127 | | - # `optimize_acqf_mixed_alternating`, which alternates between |
128 | | - # continuous and discrete optimization steps. |
129 | | - if n_combos <= ALTERNATING_OPTIMIZER_THRESHOLD: |
130 | | - optimizer = "optimize_acqf_mixed" |
131 | | - else: |
| 119 | + # For mixed (not fully discrete) search spaces, use the shared utility |
| 120 | + # from BoTorch to determine whether to use mixed alternating optimizer. |
| 121 | + if should_use_mixed_alternating_optimizer(discrete_dims=discrete_choices): |
132 | 122 | optimizer = "optimize_acqf_mixed_alternating" |
| 123 | + else: |
| 124 | + optimizer = "optimize_acqf_mixed" |
133 | 125 | return optimizer |
134 | 126 |
|
135 | 127 |
|
|
0 commit comments