Skip to content

Commit 06c916b

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Rename modeling layer components (facebook#3280)
Summary: Pull Request resolved: facebook#3280 Rename `Model` -> `Generator`, `ModelSpec` -> `GeneratorSpec`, `Modelbridge` -> `Adapter`. This also updates the decoders so that we can load objects stored with the previous names: e.g. decode `Models` as `Generators`. Reviewed By: saitcakmak Differential Revision: D68735059 fbshipit-source-id: bb23acca58cc2f30361c799916aa51fc51db4a33
1 parent 1a831f4 commit 06c916b

File tree

197 files changed

+4185
-4456
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

197 files changed

+4185
-4456
lines changed

ax/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
SumConstraint,
3333
Trial,
3434
)
35-
from ax.modelbridge import Models
35+
from ax.modelbridge import Generators
3636
from ax.service import OptimizationLoop, optimize
3737
from ax.storage import json_load, json_save
3838

@@ -52,7 +52,7 @@
5252
"FixedParameter",
5353
"GeneratorRun",
5454
"Metric",
55-
"Models",
55+
"Generators",
5656
"MultiObjective",
5757
"MultiObjectiveOptimizationConfig",
5858
"Objective",

ax/analysis/healthcheck/constraints_feasibility.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ax.core.generation_strategy_interface import GenerationStrategyInterface
2424
from ax.core.optimization_config import OptimizationConfig
2525
from ax.exceptions.core import UserInputError
26-
from ax.modelbridge.base import ModelBridge
26+
from ax.modelbridge.base import Adapter
2727
from ax.modelbridge.generation_strategy import GenerationStrategy
2828
from ax.modelbridge.transforms.derelativize import Derelativize
2929
from pyre_extensions import assert_is_instance, none_throws
@@ -155,7 +155,7 @@ def compute(
155155

156156
def constraints_feasibility(
157157
optimization_config: OptimizationConfig,
158-
model: ModelBridge,
158+
model: Adapter,
159159
prob_threshold: float = 0.99,
160160
) -> Tuple[bool, pd.DataFrame]:
161161
r"""

ax/analysis/healthcheck/regression_detection_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ax.core.observation import observations_from_data
1616

1717
from ax.exceptions.core import DataRequiredError, UserInputError
18-
from ax.modelbridge.discrete import DiscreteModelBridge
18+
from ax.modelbridge.discrete import DiscreteAdapter
1919
from ax.modelbridge.registry import rel_EB_ashr_trans
2020
from ax.models.discrete.eb_ashr import EBAshr
2121
from pyre_extensions import assert_is_instance
@@ -101,7 +101,7 @@ def compute_regression_probabilities_single_trial(
101101

102102
target_data = Data(df=data.df[data.df["metric_name"].isin(metric_names)])
103103

104-
modelbridge = DiscreteModelBridge(
104+
modelbridge = DiscreteAdapter(
105105
experiment=experiment,
106106
search_space=experiment.search_space,
107107
data=target_data,

ax/analysis/healthcheck/tests/test_constraints_feasibility.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
from ax.modelbridge.factory import get_sobol
2727
from ax.modelbridge.generation_node import GenerationNode
2828
from ax.modelbridge.generation_strategy import GenerationStrategy
29-
from ax.modelbridge.model_spec import ModelSpec
30-
from ax.modelbridge.registry import Models
29+
from ax.modelbridge.model_spec import GeneratorSpec
30+
from ax.modelbridge.registry import Generators
3131
from ax.utils.common.testutils import TestCase
3232
from ax.utils.testing.core_stubs import (
3333
get_branin_experiment,
@@ -96,8 +96,8 @@ def setUp(self) -> None:
9696
GenerationNode(
9797
node_name="gn",
9898
model_specs=[
99-
ModelSpec(
100-
model_enum=Models.BOTORCH_MODULAR,
99+
GeneratorSpec(
100+
model_enum=Generators.BOTORCH_MODULAR,
101101
)
102102
],
103103
)

ax/analysis/plotly/arm_effects/insample_effects.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from ax.core.generator_run import GeneratorRun
2323
from ax.core.outcome_constraint import OutcomeConstraint
2424
from ax.exceptions.core import DataRequiredError, UserInputError
25-
from ax.modelbridge.base import ModelBridge
25+
from ax.modelbridge.base import Adapter
2626
from ax.modelbridge.generation_strategy import GenerationStrategy
27-
from ax.modelbridge.registry import Models
27+
from ax.modelbridge.registry import Generators
2828
from ax.modelbridge.transforms.derelativize import Derelativize
2929
from ax.utils.common.logger import get_logger
3030
from pyre_extensions import none_throws
@@ -157,7 +157,7 @@ def _plot_type_string(self) -> str:
157157
return "Modeled" if self.use_modeled_effects else "Observed"
158158

159159

160-
def _get_max_observed_trial_index(model: ModelBridge) -> int | None:
160+
def _get_max_observed_trial_index(model: Adapter) -> int | None:
161161
"""Returns the max observed trial index to appease multitask models for prediction
162162
by giving fixed features. This is not necessarily accurate and should eventually
163163
come from the generation strategy.
@@ -178,7 +178,7 @@ def _get_model(
178178
use_modeled_effects: bool,
179179
trial_index: int,
180180
metric_name: str,
181-
) -> ModelBridge:
181+
) -> Adapter:
182182
"""Get a model for predictions.
183183
184184
Args:
@@ -213,14 +213,14 @@ def _get_model(
213213

214214
if model is None or not is_predictive(model=model):
215215
logger.info("Using empirical Bayes for predictions.")
216-
return Models.EMPIRICAL_BAYES_THOMPSON(
216+
return Generators.EMPIRICAL_BAYES_THOMPSON(
217217
experiment=experiment, data=trial_data
218218
)
219219

220220
return model
221221
else:
222222
# This model just predicts observed data
223-
return Models.THOMPSON(
223+
return Generators.THOMPSON(
224224
data=trial_data,
225225
search_space=experiment.search_space,
226226
experiment=experiment,
@@ -229,7 +229,7 @@ def _get_model(
229229

230230
def _prepare_data(
231231
experiment: Experiment,
232-
model: ModelBridge,
232+
model: Adapter,
233233
outcome_constraints: list[OutcomeConstraint],
234234
metric_name: str,
235235
trial_index: int,
@@ -249,7 +249,7 @@ def _prepare_data(
249249
250250
Args:
251251
experiment: Experiment to plot
252-
model: ModelBridge being used for prediction
252+
model: Adapter being used for prediction
253253
outcome_constraints: Derelatives outcome constraints used for
254254
assessing feasibility
255255
metric_name: Name of metric to plot

ax/analysis/plotly/arm_effects/predicted_effects.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ax.core.experiment import Experiment
2424
from ax.core.generation_strategy_interface import GenerationStrategyInterface
2525
from ax.exceptions.core import UserInputError
26-
from ax.modelbridge.base import ModelBridge
26+
from ax.modelbridge.base import Adapter
2727
from ax.modelbridge.generation_strategy import GenerationStrategy
2828
from ax.modelbridge.transforms.derelativize import Derelativize
2929
from pyre_extensions import assert_is_instance, none_throws
@@ -149,7 +149,7 @@ def compute(
149149

150150

151151
def _prepare_data(
152-
model: ModelBridge,
152+
model: Adapter,
153153
metric_name: str,
154154
candidate_trial: BaseTrial,
155155
outcome_constraints: list[OutcomeConstraint],
@@ -167,7 +167,7 @@ def _prepare_data(
167167
candidate trial.
168168
169169
Args:
170-
model: ModelBridge being used for prediction
170+
model: Adapter being used for prediction
171171
metric_name: Name of metric to plot
172172
candidate_trial: Trial to plot candidates for by generator run
173173
"""

ax/analysis/plotly/arm_effects/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ax.core.outcome_constraint import OutcomeConstraint
2020
from ax.core.types import TParameterization
2121
from ax.exceptions.core import UserInputError
22-
from ax.modelbridge.base import ModelBridge
22+
from ax.modelbridge.base import Adapter
2323
from ax.modelbridge.prediction_utils import predict_at_point
2424
from plotly import express as px, graph_objects as go
2525
from pyre_extensions import none_throws
@@ -203,7 +203,7 @@ def _add_style_to_effects_by_arm_plot(
203203
)
204204

205205

206-
def _get_trial_index_for_predictions(model: ModelBridge) -> int | None:
206+
def _get_trial_index_for_predictions(model: Adapter) -> int | None:
207207
"""Returns status quo features index if defined on the model. Otherwise, returns
208208
the max observed trial index to appease multitask models for prediction
209209
by giving fixed features. The max index is not necessarily accurate and should
@@ -224,7 +224,7 @@ def _get_trial_index_for_predictions(model: ModelBridge) -> int | None:
224224

225225

226226
def get_predictions_by_arm(
227-
model: ModelBridge,
227+
model: Adapter,
228228
metric_name: str,
229229
outcome_constraints: list[OutcomeConstraint],
230230
gr: GeneratorRun | None = None,

ax/analysis/plotly/cross_validation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
folds: Number of subsamples to partition observations into. Use -1 for
5959
leave-one-out cross validation.
6060
untransform: Whether to untransform the model predictions before cross
61-
validating. Models are trained on transformed data, and candidate
61+
validating. Generators are trained on transformed data, and candidate
6262
generation is performed in the transformed space. Computing the model
6363
quality metric based on the cross-validation results in the
6464
untransformed space may not be representative of the model that

ax/analysis/plotly/interaction.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from ax.core.experiment import Experiment
2828
from ax.core.generation_strategy_interface import GenerationStrategyInterface
2929
from ax.exceptions.core import UserInputError
30-
from ax.modelbridge.registry import Models
31-
from ax.modelbridge.torch import TorchModelBridge
30+
from ax.modelbridge.registry import Generators
31+
from ax.modelbridge.torch import TorchAdapter
3232
from ax.models.torch.botorch_modular.surrogate import Surrogate
3333
from ax.utils.common.logger import get_logger
3434
from ax.utils.sensitivity.sobol_measures import ax_parameter_sens
@@ -261,9 +261,7 @@ def compute(
261261
fig=fig,
262262
)
263263

264-
def _get_oak_model(
265-
self, experiment: Experiment, metric_name: str
266-
) -> TorchModelBridge:
264+
def _get_oak_model(self, experiment: Experiment, metric_name: str) -> TorchAdapter:
267265
"""
268266
Retrieves the modelbridge used for the analysis. The model uses an OAK
269267
(Orthogonal Additive Kernel) with a sparsity-inducing prior,
@@ -275,7 +273,7 @@ def _get_oak_model(
275273
lengthscales being fit.
276274
"""
277275
data = experiment.lookup_data().filter(metric_names=[metric_name])
278-
model_bridge = Models.BOTORCH_MODULAR(
276+
model_bridge = Generators.BOTORCH_MODULAR(
279277
search_space=experiment.search_space,
280278
experiment=experiment,
281279
data=data,
@@ -304,12 +302,12 @@ def _get_oak_model(
304302
),
305303
)
306304

307-
return assert_is_instance(model_bridge, TorchModelBridge)
305+
return assert_is_instance(model_bridge, TorchAdapter)
308306

309307

310308
def _prepare_surface_plot(
311309
experiment: Experiment,
312-
model: TorchModelBridge,
310+
model: TorchAdapter,
313311
feature_name: str,
314312
metric_name: str,
315313
) -> go.Figure:

ax/analysis/plotly/surface/contour.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ax.core.generation_strategy_interface import GenerationStrategyInterface
2323
from ax.core.observation import ObservationFeatures
2424
from ax.exceptions.core import UserInputError
25-
from ax.modelbridge.base import ModelBridge
25+
from ax.modelbridge.base import Adapter
2626
from ax.modelbridge.generation_strategy import GenerationStrategy
2727
from plotly import graph_objects as go
2828
from pyre_extensions import none_throws
@@ -113,7 +113,7 @@ def compute(
113113

114114
def _prepare_data(
115115
experiment: Experiment,
116-
model: ModelBridge,
116+
model: Adapter,
117117
x_parameter_name: str,
118118
y_parameter_name: str,
119119
metric_name: str,

ax/analysis/plotly/surface/slice.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ax.core.generation_strategy_interface import GenerationStrategyInterface
2323
from ax.core.observation import ObservationFeatures
2424
from ax.exceptions.core import UserInputError
25-
from ax.modelbridge.base import ModelBridge
25+
from ax.modelbridge.base import Adapter
2626
from ax.modelbridge.generation_strategy import GenerationStrategy
2727
from plotly import express as px, graph_objects as go
2828
from pyre_extensions import none_throws
@@ -100,7 +100,7 @@ def compute(
100100

101101
def _prepare_data(
102102
experiment: Experiment,
103-
model: ModelBridge,
103+
model: Adapter,
104104
parameter_name: str,
105105
metric_name: str,
106106
) -> pd.DataFrame:

ax/analysis/plotly/tests/test_predicted_effects.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ax.exceptions.core import UserInputError
1818
from ax.modelbridge.dispatch_utils import choose_generation_strategy
1919
from ax.modelbridge.prediction_utils import predict_at_point
20-
from ax.modelbridge.registry import Models
20+
from ax.modelbridge.registry import Generators
2121
from ax.utils.common.testutils import TestCase
2222
from ax.utils.testing.core_stubs import (
2323
get_branin_experiment,
@@ -315,7 +315,7 @@ def test_it_works_for_non_batch_experiments(self) -> None:
315315
experiment=experiment,
316316
)
317317
# AND GIVEN we generate all Sobol trials and one GPEI trial
318-
sobol_key = Models.SOBOL.value
318+
sobol_key = Generators.SOBOL.value
319319
last_model_key = sobol_key
320320
while last_model_key == sobol_key:
321321
trial = experiment.new_trial(

ax/analysis/plotly/tests/test_scatter.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ax.analysis.analysis import AnalysisCardLevel
99
from ax.analysis.plotly.scatter import _prepare_data, ScatterPlot
1010
from ax.exceptions.core import DataRequiredError, UserInputError
11-
from ax.modelbridge.registry import Models
11+
from ax.modelbridge.registry import Generators
1212
from ax.utils.common.testutils import TestCase
1313
from ax.utils.testing.core_stubs import (
1414
get_branin_experiment_with_multi_objective,
@@ -85,7 +85,7 @@ def test_prepare_data(self) -> None:
8585
def test_it_only_has_observations_with_data_for_both_metrics(self) -> None:
8686
# GIVEN an experiment with multiple trials and metrics
8787
experiment = get_branin_experiment_with_multi_objective()
88-
sobol = Models.SOBOL(search_space=experiment.search_space)
88+
sobol = Generators.SOBOL(search_space=experiment.search_space)
8989

9090
t0 = experiment.new_batch_trial(generator_run=sobol.gen(3)).mark_completed(
9191
unsafe=True
@@ -125,7 +125,7 @@ def test_it_only_has_observations_with_data_for_both_metrics(self) -> None:
125125
def test_it_must_have_some_observations_with_data_for_both_metrics(self) -> None:
126126
# GIVEN an experiment with multiple trials and metrics
127127
experiment = get_branin_experiment_with_multi_objective()
128-
sobol = Models.SOBOL(search_space=experiment.search_space)
128+
sobol = Generators.SOBOL(search_space=experiment.search_space)
129129

130130
t0 = experiment.new_batch_trial(generator_run=sobol.gen(3)).mark_completed(
131131
unsafe=True

ax/analysis/plotly/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ax.core.objective import MultiObjective, ScalarizedObjective
1212
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
1313
from ax.exceptions.core import UnsupportedError, UserInputError
14-
from ax.modelbridge.base import ModelBridge
14+
from ax.modelbridge.base import Adapter
1515
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
1616
from numpy.typing import NDArray
1717

@@ -126,7 +126,7 @@ def format_constraint_violated_probabilities(
126126
return constraints_violated_str
127127

128128

129-
def is_predictive(model: ModelBridge) -> bool:
129+
def is_predictive(model: Adapter) -> bool:
130130
"""Check if a model is predictive. Basically, we're checking if
131131
predict() is implemented.
132132

ax/benchmark/benchmark_test_functions/surrogate.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
1313
from ax.core.observation import ObservationFeatures
1414
from ax.core.types import TParamValue
15-
from ax.modelbridge.torch import TorchModelBridge
15+
from ax.modelbridge.torch import TorchAdapter
1616
from ax.utils.common.base import Base
1717
from ax.utils.common.equality import equality_typechecker
1818
from pyre_extensions import none_throws
@@ -28,7 +28,7 @@ class SurrogateTestFunction(BenchmarkTestFunction):
2828
name: The name of the runner.
2929
outcome_names: Names of outcomes to return in `evaluate_true`, if the
3030
surrogate produces more outcomes than are needed.
31-
_surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
31+
_surrogate: Either `None`, or a `TorchAdapter` surrogate to use
3232
for generating observations. If `None`, `get_surrogate`
3333
must not be None and will be used to generate the surrogate when it
3434
is needed.
@@ -39,8 +39,8 @@ class SurrogateTestFunction(BenchmarkTestFunction):
3939

4040
name: str
4141
outcome_names: Sequence[str]
42-
_surrogate: TorchModelBridge | None = None
43-
get_surrogate: None | Callable[[], TorchModelBridge] = None
42+
_surrogate: TorchAdapter | None = None
43+
get_surrogate: None | Callable[[], TorchAdapter] = None
4444

4545
def __post_init__(self) -> None:
4646
if self.get_surrogate is None and self._surrogate is None:
@@ -50,7 +50,7 @@ def __post_init__(self) -> None:
5050
)
5151

5252
@property
53-
def surrogate(self) -> TorchModelBridge:
53+
def surrogate(self) -> TorchAdapter:
5454
if self._surrogate is None:
5555
self._surrogate = none_throws(self.get_surrogate)()
5656
return none_throws(self._surrogate)

0 commit comments

Comments
 (0)