Deprecate Models.GPEI registry entry (#3020)
Summary:
Pull Request resolved: #3020

This diff deprecates `Models.GPEI`, which was the default model used in many places in the past. I cleaned up all the usage I could find and updated `Models.GPEI` to point to `Models.BOTORCH_MODULAR` with a deprecation warning. At a later date, we can remove the alias entirely, paired with a storage-level change that preserves backwards compatibility so old experiments can still be loaded. (A minimal before/after sketch for callers follows the changed-files summary below.)

Reviewed By: Balandat

Differential Revision: D64987622

fbshipit-source-id: 84a63ee07cb148aaa706efdf832be4b5e04c1ad8
saitcakmak authored and facebook-github-bot committed Nov 5, 2024
1 parent f09a318 commit 93c236e
Showing 6 changed files with 24 additions and 18 deletions.
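
Not part of the commit, but for orientation: a minimal before/after sketch of what this deprecation means for callers. It mirrors the branin test helpers used in the test diffs below and assumes an Ax installation with this change applied; details beyond what those tests show are illustrative only.

```python
from ax.modelbridge.registry import Models
from ax.utils.testing.core_stubs import (
    get_branin_experiment,
    get_branin_optimization_config,
)

# Same setup as test_enum_sobol_legacy_GPEI below: seed the experiment with
# one Sobol-generated batch trial so there is data to fit a GP on.
exp = get_branin_experiment()
sobol = Models.SOBOL(search_space=exp.search_space)
exp.new_batch_trial().add_generator_run(sobol.gen(n=5)).run()
exp.optimization_config = get_branin_optimization_config()

# Before this commit (now deprecated -- emits a DeprecationWarning and
# resolves to Models.BOTORCH_MODULAR under the hood):
# model_bridge = Models.GPEI(experiment=exp, data=exp.fetch_data())

# Preferred going forward:
model_bridge = Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data())
generator_run = model_bridge.gen(n=1)
```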
4 changes: 2 additions & 2 deletions ax/analysis/plotly/tests/test_insample_effects.py
@@ -39,12 +39,12 @@ def setUp(self) -> None:
                     transition_criteria=[
                         MaxTrials(
                             threshold=1,
-                            transition_to="GPEI",
+                            transition_to="MBM",
                         )
                     ],
                 ),
                 GenerationNode(
-                    node_name="GPEI",
+                    node_name="MBM",
                     model_specs=[
                         ModelSpec(
                             model_enum=Models.BOTORCH_MODULAR,
4 changes: 2 additions & 2 deletions ax/analysis/plotly/tests/test_predicted_effects.py
@@ -44,12 +44,12 @@ def setUp(self) -> None:
                     transition_criteria=[
                         MaxTrials(
                             threshold=1,
-                            transition_to="GPEI",
+                            transition_to="MBM",
                         )
                     ],
                 ),
                 GenerationNode(
-                    node_name="GPEI",
+                    node_name="MBM",
                     model_specs=[
                         ModelSpec(
                             model_enum=Models.BOTORCH_MODULAR,
12 changes: 9 additions & 3 deletions ax/modelbridge/registry.py
@@ -159,7 +159,7 @@ class ModelSetup(NamedTuple):
         transforms=Cont_X_trans + Y_trans,
         standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
     ),
-    "GPEI": ModelSetup(
+    "Legacy_GPEI": ModelSetup(
         bridge_class=TorchModelBridge,
         model_class=BotorchModel,
         transforms=Cont_X_trans + Y_trans,
@@ -427,12 +427,11 @@ class Models(ModelRegistryBase):
     """

     SOBOL = "Sobol"
-    GPEI = "GPEI"
     FACTORIAL = "Factorial"
     SAASBO = "SAASBO"
     SAAS_MTGP = "SAAS_MTGP"
     THOMPSON = "Thompson"
-    LEGACY_BOTORCH = "GPEI"
+    LEGACY_BOTORCH = "Legacy_GPEI"
     BOTORCH_MODULAR = "BoTorch"
     EMPIRICAL_BAYES_THOMPSON = "EB"
     UNIFORM = "Uniform"
@@ -443,6 +442,13 @@ class Models(ModelRegistryBase):
     ST_MTGP_NEHVI = "ST_MTGP_NEHVI"
     CONTEXT_SACBO = "Contextual_SACBO"

+    @classmethod
+    @property
+    def GPEI(cls) -> Models:
+        return _deprecated_model_with_warning(
+            old_model_str="GPEI", new_model=cls.BOTORCH_MODULAR
+        )
+
     @classmethod
     @property
     def FULLYBAYESIAN(cls) -> Models:
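
The registry hunk above calls `_deprecated_model_with_warning` but does not show its body. A rough sketch of what such a helper plausibly does, for readers following along; the implementation below is an assumption, not the committed code.

```python
import warnings


def _deprecated_model_with_warning(old_model_str: str, new_model: "Models") -> "Models":
    # Assumed behavior: warn that the old registry entry is deprecated and
    # hand back the replacement entry so existing call sites keep working.
    warnings.warn(
        f"`Models.{old_model_str}` is deprecated. Use `Models.{new_model.name}` "
        "instead; returning it in place of the deprecated entry.",
        DeprecationWarning,
        stacklevel=3,
    )
    return new_model
```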
15 changes: 8 additions & 7 deletions ax/modelbridge/tests/test_registry.py
@@ -103,8 +103,8 @@ def test_SAASBO(self) -> None:
         )

     @mock_botorch_optimize
-    def test_enum_sobol_GPEI(self) -> None:
-        """Tests Sobol and GPEI instantiation through the Models enum."""
+    def test_enum_sobol_legacy_GPEI(self) -> None:
+        """Tests Sobol and Legacy GPEI instantiation through the Models enum."""
         exp = get_branin_experiment()
         # Check that factory generates a valid sobol modelbridge.
         sobol = Models.SOBOL(search_space=exp.search_space)
@@ -115,9 +115,9 @@ def test_enum_sobol_GPEI(self) -> None:
         exp.new_batch_trial().add_generator_run(sobol_run).run()
         # Check that factory generates a valid GP+EI modelbridge.
         exp.optimization_config = get_branin_optimization_config()
-        gpei = Models.GPEI(experiment=exp, data=exp.fetch_data())
+        gpei = Models.LEGACY_BOTORCH(experiment=exp, data=exp.fetch_data())
         self.assertIsInstance(gpei, TorchModelBridge)
-        self.assertEqual(gpei._model_key, "GPEI")
+        self.assertEqual(gpei._model_key, "Legacy_GPEI")
         botorch_defaults = "ax.models.torch.botorch_defaults"
         # Check that the callable kwargs and the torch kwargs were recorded.
         self.assertEqual(
@@ -168,7 +168,7 @@ def test_enum_sobol_GPEI(self) -> None:
             },
         )
         prior_kwargs = {"lengthscale_prior": GammaPrior(6.0, 6.0)}
-        gpei = Models.GPEI(
+        gpei = Models.LEGACY_BOTORCH(
             experiment=exp,
             data=exp.fetch_data(),
             search_space=exp.search_space,
@@ -316,10 +316,10 @@ def test_get_model_from_generator_run(self) -> None:
         self.assertEqual(initial_sobol.gen(n=1).arms, sobol_after_gen.gen(n=1).arms)
         exp.new_trial(generator_run=gr)
         # Check restoration of GPEI, to ensure proper restoration of callable kwargs
-        gpei = Models.GPEI(experiment=exp, data=get_branin_data())
+        gpei = Models.LEGACY_BOTORCH(experiment=exp, data=get_branin_data())
         # Punch GPEI model + bridge kwargs into the Sobol generator run, to avoid
         # a slow call to `gpei.gen`, and remove Sobol's model state.
-        gr._model_key = "GPEI"
+        gr._model_key = "Legacy_GPEI"
         gr._model_kwargs = gpei._model_kwargs
         gr._bridge_kwargs = gpei._bridge_kwargs
         gr._model_state_after_gen = {}
@@ -496,6 +496,7 @@ def test_deprecated_models(self) -> None:
         same check in a couple different ways.
         """
         for old_model_str, new_model in [
+            ("GPEI", Models.BOTORCH_MODULAR),
             ("FULLYBAYESIAN", Models.SAASBO),
             ("FULLYBAYESIANMOO", Models.SAASBO),
             ("FULLYBAYESIAN_MTGP", Models.SAAS_MTGP),
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_robust_modelbridge.py
@@ -129,7 +129,7 @@ def test_mars(self) -> None:
     def test_unsupported_model(self) -> None:
         exp = get_robust_branin_experiment()
         with self.assertRaisesRegex(UnsupportedError, "support robust"):
-            Models.GPEI(
+            Models.LEGACY_BOTORCH(
                 experiment=exp,
                 data=exp.fetch_data(),
             ).gen(n=1)
5 changes: 2 additions & 3 deletions ax/service/ax_client.py
@@ -157,7 +157,7 @@ class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase):
         torch_device: An optional `torch.device` object, used to choose the device
             used for generating new points for trials. Works only for torch-based
-            models, such as GPEI. Ignored if a `generation_strategy` is passed in
+            models, such as MBM. Ignored if a `generation_strategy` is passed in
             manually. To specify the device for a custom `generation_strategy`,
             pass in `torch_device` as part of `model_kwargs`. See
             https://ax.dev/tutorials/generation_strategy.html for a tutorial on
@@ -1119,8 +1119,7 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig:
             logger.info(
                 f"Model {self.generation_strategy.model} does not implement "
                 "`feature_importances`, so it cannot be used to generate "
-                "this plot. Only certain models, specifically GPEI, implement "
-                "feature importances."
+                "this plot. Only certain models implement feature importances."
             )

             raise ValueError(
