Skip to content

Commit 888ce3c

Browse files
mpolson64facebook-github-bot
authored andcommitted
Move new dispatch utils out of preview
Summary: As titled. Also refactored slightly such that we wont be importing from ax.api anywhere in the codebase. To keep our module structure easy to reason about it is very important to keep the ax.api module at the root of our dep tree. Differential Revision: D70647193
1 parent 590e546 commit 888ce3c

File tree

7 files changed

+388
-392
lines changed

7 files changed

+388
-392
lines changed

ax/generation_strategy/dispatch_utils.py

+203
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import logging
1010
import warnings
11+
from enum import Enum
1112
from math import ceil
1213
from typing import Any, cast
1314

@@ -16,10 +17,15 @@
1617
from ax.core.optimization_config import OptimizationConfig
1718
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
1819
from ax.core.search_space import SearchSpace
20+
from ax.core.trial_status import TrialStatus
21+
from ax.exceptions.core import UnsupportedError
1922
from ax.generation_strategy.generation_strategy import (
23+
GenerationNode,
2024
GenerationStep,
2125
GenerationStrategy,
2226
)
27+
from ax.generation_strategy.model_spec import GeneratorSpec
28+
from ax.generation_strategy.transition_criterion import MinTrials
2329
from ax.modelbridge.registry import (
2430
Generators,
2531
MODEL_KEY_TO_MODEL_SETUP,
@@ -30,10 +36,13 @@
3036
from ax.models.torch.botorch_modular.model import (
3137
BoTorchGenerator as ModularBoTorchGenerator,
3238
)
39+
from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
3340
from ax.models.types import TConfig
3441
from ax.models.winsorization_config import WinsorizationConfig
3542
from ax.utils.common.deprecation import _validate_force_random_search
3643
from ax.utils.common.logger import get_logger
44+
from botorch.models.transforms.input import Normalize, Warp
45+
from gpytorch.kernels.linear_kernel import LinearKernel
3746
from pyre_extensions import none_throws
3847

3948

@@ -54,6 +63,200 @@
5463
)
5564

5665

66+
class GenerationMethod(Enum):
67+
"""An enum to specify the desired candidate generation method for the experiment.
68+
This is used in ``GenerationStrategyConfig``, along with the properties of the
69+
experiment, to determine the generation strategy to use for candidate generation.
70+
71+
NOTE: New options should be rarely added to this enum. This is not intended to be
72+
a list of generation strategies for the user to choose from. Instead, this enum
73+
should only provide high level guidance to the underlying generation strategy
74+
dispatch logic, which is responsible for determinining the exact details.
75+
76+
Available options are:
77+
BALANCED: A balanced generation method that may utilize (per-metric) model
78+
selection to achieve a good model accuracy. This method excludes expensive
79+
methods, such as the fully Bayesian SAASBO model. Used by default.
80+
FAST: A faster generation method that uses the built-in defaults from the
81+
Modular BoTorch Model without any model selection.
82+
RANDOM_SEARCH: Primarily intended for pure exploration experiments, this
83+
method utilizes quasi-random Sobol sequences for candidate generation.
84+
"""
85+
86+
BALANCED = "balanced"
87+
FAST = "fast"
88+
RANDOM_SEARCH = "random_search"
89+
90+
91+
def _get_sobol_node(
92+
initialization_budget: int | None = None,
93+
initialization_random_seed: int | None = None,
94+
use_existing_trials_for_initialization: bool = True,
95+
min_observed_initialization_trials: int | None = None,
96+
allow_exceeding_initialization_budget: bool = False,
97+
) -> GenerationNode:
98+
"""Constructs a Sobol node based on inputs from ``gs_config``.
99+
The Sobol generator utilizes `initialization_random_seed` if specified.
100+
101+
This node always transitions to "MBM", using the following transition criteria:
102+
- MinTrials enforcing the initialization budget.
103+
- If the initialization budget is not specified, it defaults to 5.
104+
- The TC will not block generation if `allow_exceeding_initialization_budget`
105+
is set to True.
106+
- The TC is currently not restricted to any trial statuses and will
107+
count all trials.
108+
- `use_existing_trials_for_initialization` controls whether trials previously
109+
attached to the experiment are counted as part of the initialization budget.
110+
- MinTrials enforcing the minimum number of observed initialization trials.
111+
- If `min_observed_initialization_trials` is not specified, it defaults
112+
to `max(1, initialization_budget // 2)`.
113+
- The TC currently only counts trials in status COMPLETED (with data attached)
114+
as observed trials.
115+
- `use_existing_trials_for_initialization` controls whether trials previously
116+
attached to the experiment are counted as part of the required number of
117+
observed initialization trials.
118+
"""
119+
# Set the default options.
120+
if initialization_budget is None:
121+
initialization_budget = 5
122+
if min_observed_initialization_trials is None:
123+
min_observed_initialization_trials = max(1, initialization_budget // 2)
124+
# Construct the transition criteria.
125+
transition_criteria = [
126+
MinTrials( # This represents the initialization budget.
127+
threshold=initialization_budget,
128+
transition_to="MBM",
129+
block_gen_if_met=(not allow_exceeding_initialization_budget),
130+
block_transition_if_unmet=True,
131+
use_all_trials_in_exp=use_existing_trials_for_initialization,
132+
),
133+
MinTrials( # This represents minimum observed trials requirement.
134+
threshold=min_observed_initialization_trials,
135+
transition_to="MBM",
136+
block_gen_if_met=False,
137+
block_transition_if_unmet=True,
138+
use_all_trials_in_exp=use_existing_trials_for_initialization,
139+
only_in_statuses=[TrialStatus.COMPLETED],
140+
count_only_trials_with_data=True,
141+
),
142+
]
143+
return GenerationNode(
144+
node_name="Sobol",
145+
model_specs=[
146+
GeneratorSpec(
147+
model_enum=Generators.SOBOL,
148+
model_kwargs={"seed": initialization_random_seed},
149+
)
150+
],
151+
transition_criteria=transition_criteria,
152+
should_deduplicate=True,
153+
)
154+
155+
156+
def _get_mbm_node(
157+
method: GenerationMethod = GenerationMethod.FAST,
158+
torch_device: str | None = None,
159+
) -> GenerationNode:
160+
"""Constructs an MBM node based on the method specified in ``gs_config``.
161+
162+
The ``SurrogateSpec`` takes the following form for the given method:
163+
- BALANCED: Two model configs: one with MBM defaults, the other with
164+
linear kernel with input warping.
165+
- FAST: An empty model config that utilizes MBM defaults.
166+
"""
167+
# Construct the surrogate spec.
168+
if method == GenerationMethod.FAST:
169+
model_configs = [ModelConfig(name="MBM defaults")]
170+
elif method == GenerationMethod.BALANCED:
171+
model_configs = [
172+
ModelConfig(name="MBM defaults"),
173+
ModelConfig(
174+
covar_module_class=LinearKernel,
175+
input_transform_classes=[Warp, Normalize],
176+
input_transform_options={"Normalize": {"center": 0.0}},
177+
name="LinearKernel with Warp",
178+
),
179+
]
180+
else:
181+
raise UnsupportedError(f"Unsupported generation method: {method}.")
182+
183+
return GenerationNode(
184+
node_name="MBM",
185+
model_specs=[
186+
GeneratorSpec(
187+
model_enum=Generators.BOTORCH_MODULAR,
188+
model_kwargs={
189+
"surrogate_spec": SurrogateSpec(model_configs=model_configs),
190+
"torch_device": None
191+
if torch_device is None
192+
else torch.device(torch_device),
193+
},
194+
)
195+
],
196+
should_deduplicate=True,
197+
)
198+
199+
200+
def choose_generation_strategy(
201+
method: GenerationMethod = GenerationMethod.FAST,
202+
# Initialization options
203+
initialization_budget: int | None = None,
204+
initialization_random_seed: int | None = None,
205+
use_existing_trials_for_initialization: bool = True,
206+
min_observed_initialization_trials: int | None = None,
207+
allow_exceeding_initialization_budget: bool = False,
208+
# Misc options
209+
torch_device: str | None = None,
210+
) -> GenerationStrategy:
211+
"""Choose a generation strategy based on the properties of the experiment
212+
and the inputs provided in ``gs_config``.
213+
214+
NOTE: The behavior of this function is subject to change. It will be updated to
215+
produce best general purpose generation strategies based on benchmarking results.
216+
217+
Args:
218+
gs_config: A ``GenerationStrategyConfig`` object that informs
219+
the choice of generation strategy.
220+
221+
Returns:
222+
A generation strategy.
223+
"""
224+
# Handle the random search case.
225+
if method == GenerationMethod.RANDOM_SEARCH:
226+
return GenerationStrategy(
227+
name="QuasiRandomSearch",
228+
nodes=[
229+
GenerationNode(
230+
node_name="Sobol",
231+
model_specs=[
232+
GeneratorSpec(
233+
model_enum=Generators.SOBOL,
234+
model_kwargs={"seed": initialization_random_seed},
235+
)
236+
],
237+
)
238+
],
239+
)
240+
# Construct the nodes.
241+
sobol_node = _get_sobol_node(
242+
initialization_budget=initialization_budget,
243+
initialization_random_seed=initialization_random_seed,
244+
use_existing_trials_for_initialization=use_existing_trials_for_initialization,
245+
min_observed_initialization_trials=min_observed_initialization_trials,
246+
allow_exceeding_initialization_budget=allow_exceeding_initialization_budget,
247+
)
248+
# Construct the MBM node.
249+
mbm_node = _get_mbm_node(
250+
method=method,
251+
torch_device=torch_device,
252+
)
253+
254+
return GenerationStrategy(
255+
name=f"Sobol+MBM:{method.value}",
256+
nodes=[sobol_node, mbm_node],
257+
)
258+
259+
57260
def _make_sobol_step(
58261
num_trials: int = -1,
59262
min_trials_observed: int | None = None,

0 commit comments

Comments
 (0)