Skip to content

Commit 93c236e

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Deprecate Models.GPEI registry entry (#3020)
Summary: Pull Request resolved: #3020 This diff deprecates `Models.GPEI`, which has been the default model used in many places in the past. I cleaned up all usage I could find, and updated `Models.GPEI` to point to `Models.BOTORCH_MODULAR` with a deprecation warning. At a later date, we can clean that up with a storage level change to support backwards compatibility to ensure we can continue loading old experiments. Reviewed By: Balandat Differential Revision: D64987622 fbshipit-source-id: 84a63ee07cb148aaa706efdf832be4b5e04c1ad8
1 parent f09a318 commit 93c236e

File tree

6 files changed

+24
-18
lines changed

6 files changed

+24
-18
lines changed

ax/analysis/plotly/tests/test_insample_effects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ def setUp(self) -> None:
3939
transition_criteria=[
4040
MaxTrials(
4141
threshold=1,
42-
transition_to="GPEI",
42+
transition_to="MBM",
4343
)
4444
],
4545
),
4646
GenerationNode(
47-
node_name="GPEI",
47+
node_name="MBM",
4848
model_specs=[
4949
ModelSpec(
5050
model_enum=Models.BOTORCH_MODULAR,

ax/analysis/plotly/tests/test_predicted_effects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ def setUp(self) -> None:
4444
transition_criteria=[
4545
MaxTrials(
4646
threshold=1,
47-
transition_to="GPEI",
47+
transition_to="MBM",
4848
)
4949
],
5050
),
5151
GenerationNode(
52-
node_name="GPEI",
52+
node_name="MBM",
5353
model_specs=[
5454
ModelSpec(
5555
model_enum=Models.BOTORCH_MODULAR,

ax/modelbridge/registry.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class ModelSetup(NamedTuple):
159159
transforms=Cont_X_trans + Y_trans,
160160
standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
161161
),
162-
"GPEI": ModelSetup(
162+
"Legacy_GPEI": ModelSetup(
163163
bridge_class=TorchModelBridge,
164164
model_class=BotorchModel,
165165
transforms=Cont_X_trans + Y_trans,
@@ -427,12 +427,11 @@ class Models(ModelRegistryBase):
427427
"""
428428

429429
SOBOL = "Sobol"
430-
GPEI = "GPEI"
431430
FACTORIAL = "Factorial"
432431
SAASBO = "SAASBO"
433432
SAAS_MTGP = "SAAS_MTGP"
434433
THOMPSON = "Thompson"
435-
LEGACY_BOTORCH = "GPEI"
434+
LEGACY_BOTORCH = "Legacy_GPEI"
436435
BOTORCH_MODULAR = "BoTorch"
437436
EMPIRICAL_BAYES_THOMPSON = "EB"
438437
UNIFORM = "Uniform"
@@ -443,6 +442,13 @@ class Models(ModelRegistryBase):
443442
ST_MTGP_NEHVI = "ST_MTGP_NEHVI"
444443
CONTEXT_SACBO = "Contextual_SACBO"
445444

445+
@classmethod
446+
@property
447+
def GPEI(cls) -> Models:
448+
return _deprecated_model_with_warning(
449+
old_model_str="GPEI", new_model=cls.BOTORCH_MODULAR
450+
)
451+
446452
@classmethod
447453
@property
448454
def FULLYBAYESIAN(cls) -> Models:

ax/modelbridge/tests/test_registry.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def test_SAASBO(self) -> None:
103103
)
104104

105105
@mock_botorch_optimize
106-
def test_enum_sobol_GPEI(self) -> None:
107-
"""Tests Sobol and GPEI instantiation through the Models enum."""
106+
def test_enum_sobol_legacy_GPEI(self) -> None:
107+
"""Tests Sobol and Legacy GPEI instantiation through the Models enum."""
108108
exp = get_branin_experiment()
109109
# Check that factory generates a valid sobol modelbridge.
110110
sobol = Models.SOBOL(search_space=exp.search_space)
@@ -115,9 +115,9 @@ def test_enum_sobol_GPEI(self) -> None:
115115
exp.new_batch_trial().add_generator_run(sobol_run).run()
116116
# Check that factory generates a valid GP+EI modelbridge.
117117
exp.optimization_config = get_branin_optimization_config()
118-
gpei = Models.GPEI(experiment=exp, data=exp.fetch_data())
118+
gpei = Models.LEGACY_BOTORCH(experiment=exp, data=exp.fetch_data())
119119
self.assertIsInstance(gpei, TorchModelBridge)
120-
self.assertEqual(gpei._model_key, "GPEI")
120+
self.assertEqual(gpei._model_key, "Legacy_GPEI")
121121
botorch_defaults = "ax.models.torch.botorch_defaults"
122122
# Check that the callable kwargs and the torch kwargs were recorded.
123123
self.assertEqual(
@@ -168,7 +168,7 @@ def test_enum_sobol_GPEI(self) -> None:
168168
},
169169
)
170170
prior_kwargs = {"lengthscale_prior": GammaPrior(6.0, 6.0)}
171-
gpei = Models.GPEI(
171+
gpei = Models.LEGACY_BOTORCH(
172172
experiment=exp,
173173
data=exp.fetch_data(),
174174
search_space=exp.search_space,
@@ -316,10 +316,10 @@ def test_get_model_from_generator_run(self) -> None:
316316
self.assertEqual(initial_sobol.gen(n=1).arms, sobol_after_gen.gen(n=1).arms)
317317
exp.new_trial(generator_run=gr)
318318
# Check restoration of GPEI, to ensure proper restoration of callable kwargs
319-
gpei = Models.GPEI(experiment=exp, data=get_branin_data())
319+
gpei = Models.LEGACY_BOTORCH(experiment=exp, data=get_branin_data())
320320
# Punch GPEI model + bridge kwargs into the Sobol generator run, to avoid
321321
# a slow call to `gpei.gen`, and remove Sobol's model state.
322-
gr._model_key = "GPEI"
322+
gr._model_key = "Legacy_GPEI"
323323
gr._model_kwargs = gpei._model_kwargs
324324
gr._bridge_kwargs = gpei._bridge_kwargs
325325
gr._model_state_after_gen = {}
@@ -496,6 +496,7 @@ def test_deprecated_models(self) -> None:
496496
same check in a couple different ways.
497497
"""
498498
for old_model_str, new_model in [
499+
("GPEI", Models.BOTORCH_MODULAR),
499500
("FULLYBAYESIAN", Models.SAASBO),
500501
("FULLYBAYESIANMOO", Models.SAASBO),
501502
("FULLYBAYESIAN_MTGP", Models.SAAS_MTGP),

ax/modelbridge/tests/test_robust_modelbridge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_mars(self) -> None:
129129
def test_unsupported_model(self) -> None:
130130
exp = get_robust_branin_experiment()
131131
with self.assertRaisesRegex(UnsupportedError, "support robust"):
132-
Models.GPEI(
132+
Models.LEGACY_BOTORCH(
133133
experiment=exp,
134134
data=exp.fetch_data(),
135135
).gen(n=1)

ax/service/ax_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase):
157157
158158
torch_device: An optional `torch.device` object, used to choose the device
159159
used for generating new points for trials. Works only for torch-based
160-
models, such as GPEI. Ignored if a `generation_strategy` is passed in
160+
models, such as MBM. Ignored if a `generation_strategy` is passed in
161161
manually. To specify the device for a custom `generation_strategy`,
162162
pass in `torch_device` as part of `model_kwargs`. See
163163
https://ax.dev/tutorials/generation_strategy.html for a tutorial on
@@ -1119,8 +1119,7 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig:
11191119
logger.info(
11201120
f"Model {self.generation_strategy.model} does not implement "
11211121
"`feature_importances`, so it cannot be used to generate "
1122-
"this plot. Only certain models, specifically GPEI, implement "
1123-
"feature importances."
1122+
"this plot. Only certain models, implement feature importances."
11241123
)
11251124

11261125
raise ValueError(

0 commit comments

Comments
 (0)