diff --git a/ax/analysis/plotly/tests/test_insample_effects.py b/ax/analysis/plotly/tests/test_insample_effects.py
index 7fa06004beb..ac4683de41b 100644
--- a/ax/analysis/plotly/tests/test_insample_effects.py
+++ b/ax/analysis/plotly/tests/test_insample_effects.py
@@ -39,12 +39,12 @@ def setUp(self) -> None:
                     transition_criteria=[
                         MaxTrials(
                             threshold=1,
-                            transition_to="GPEI",
+                            transition_to="MBM",
                         )
                     ],
                 ),
                 GenerationNode(
-                    node_name="GPEI",
+                    node_name="MBM",
                     model_specs=[
                         ModelSpec(
                             model_enum=Models.BOTORCH_MODULAR,
diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py
index d5bb7edb4d4..8ff1c4ed850 100644
--- a/ax/analysis/plotly/tests/test_predicted_effects.py
+++ b/ax/analysis/plotly/tests/test_predicted_effects.py
@@ -44,12 +44,12 @@ def setUp(self) -> None:
                     transition_criteria=[
                         MaxTrials(
                             threshold=1,
-                            transition_to="GPEI",
+                            transition_to="MBM",
                         )
                     ],
                 ),
                 GenerationNode(
-                    node_name="GPEI",
+                    node_name="MBM",
                     model_specs=[
                         ModelSpec(
                             model_enum=Models.BOTORCH_MODULAR,
diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py
index 6682297ca5d..9d903ad8c9e 100644
--- a/ax/modelbridge/registry.py
+++ b/ax/modelbridge/registry.py
@@ -159,7 +159,7 @@ class ModelSetup(NamedTuple):
         transforms=Cont_X_trans + Y_trans,
         standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
     ),
-    "GPEI": ModelSetup(
+    "Legacy_GPEI": ModelSetup(
         bridge_class=TorchModelBridge,
         model_class=BotorchModel,
         transforms=Cont_X_trans + Y_trans,
@@ -427,12 +427,11 @@ class Models(ModelRegistryBase):
     """

     SOBOL = "Sobol"
-    GPEI = "GPEI"
     FACTORIAL = "Factorial"
     SAASBO = "SAASBO"
     SAAS_MTGP = "SAAS_MTGP"
     THOMPSON = "Thompson"
-    LEGACY_BOTORCH = "GPEI"
+    LEGACY_BOTORCH = "Legacy_GPEI"
     BOTORCH_MODULAR = "BoTorch"
     EMPIRICAL_BAYES_THOMPSON = "EB"
     UNIFORM = "Uniform"
@@ -443,6 +442,13 @@ class Models(ModelRegistryBase):
     ST_MTGP_NEHVI = "ST_MTGP_NEHVI"
     CONTEXT_SACBO = "Contextual_SACBO"

+    @classmethod
+    @property
+    def GPEI(cls) -> Models:
+        return _deprecated_model_with_warning(
+            old_model_str="GPEI", new_model=cls.BOTORCH_MODULAR
+        )
+
     @classmethod
     @property
     def FULLYBAYESIAN(cls) -> Models:
diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py
index 688813ac32c..418c4bcf889 100644
--- a/ax/modelbridge/tests/test_registry.py
+++ b/ax/modelbridge/tests/test_registry.py
@@ -103,8 +103,8 @@ def test_SAASBO(self) -> None:
         )

     @mock_botorch_optimize
-    def test_enum_sobol_GPEI(self) -> None:
-        """Tests Sobol and GPEI instantiation through the Models enum."""
+    def test_enum_sobol_legacy_GPEI(self) -> None:
+        """Tests Sobol and Legacy GPEI instantiation through the Models enum."""
         exp = get_branin_experiment()
         # Check that factory generates a valid sobol modelbridge.
         sobol = Models.SOBOL(search_space=exp.search_space)
@@ -115,9 +115,9 @@ def test_enum_sobol_GPEI(self) -> None:
         exp.new_batch_trial().add_generator_run(sobol_run).run()
         # Check that factory generates a valid GP+EI modelbridge.
         exp.optimization_config = get_branin_optimization_config()
-        gpei = Models.GPEI(experiment=exp, data=exp.fetch_data())
+        gpei = Models.LEGACY_BOTORCH(experiment=exp, data=exp.fetch_data())
         self.assertIsInstance(gpei, TorchModelBridge)
-        self.assertEqual(gpei._model_key, "GPEI")
+        self.assertEqual(gpei._model_key, "Legacy_GPEI")
         botorch_defaults = "ax.models.torch.botorch_defaults"
         # Check that the callable kwargs and the torch kwargs were recorded.
         self.assertEqual(
@@ -168,7 +168,7 @@ def test_enum_sobol_GPEI(self) -> None:
             },
         )
         prior_kwargs = {"lengthscale_prior": GammaPrior(6.0, 6.0)}
-        gpei = Models.GPEI(
+        gpei = Models.LEGACY_BOTORCH(
             experiment=exp,
             data=exp.fetch_data(),
             search_space=exp.search_space,
@@ -316,10 +316,10 @@ def test_get_model_from_generator_run(self) -> None:
         self.assertEqual(initial_sobol.gen(n=1).arms, sobol_after_gen.gen(n=1).arms)
         exp.new_trial(generator_run=gr)
         # Check restoration of GPEI, to ensure proper restoration of callable kwargs
-        gpei = Models.GPEI(experiment=exp, data=get_branin_data())
+        gpei = Models.LEGACY_BOTORCH(experiment=exp, data=get_branin_data())
         # Punch GPEI model + bridge kwargs into the Sobol generator run, to avoid
         # a slow call to `gpei.gen`, and remove Sobol's model state.
-        gr._model_key = "GPEI"
+        gr._model_key = "Legacy_GPEI"
         gr._model_kwargs = gpei._model_kwargs
         gr._bridge_kwargs = gpei._bridge_kwargs
         gr._model_state_after_gen = {}
@@ -496,6 +496,7 @@ def test_deprecated_models(self) -> None:
         same check in a couple different ways.
         """
         for old_model_str, new_model in [
+            ("GPEI", Models.BOTORCH_MODULAR),
             ("FULLYBAYESIAN", Models.SAASBO),
             ("FULLYBAYESIANMOO", Models.SAASBO),
             ("FULLYBAYESIAN_MTGP", Models.SAAS_MTGP),
diff --git a/ax/modelbridge/tests/test_robust_modelbridge.py b/ax/modelbridge/tests/test_robust_modelbridge.py
index d98fc5c4b44..925da78821b 100644
--- a/ax/modelbridge/tests/test_robust_modelbridge.py
+++ b/ax/modelbridge/tests/test_robust_modelbridge.py
@@ -129,7 +129,7 @@ def test_mars(self) -> None:
     def test_unsupported_model(self) -> None:
         exp = get_robust_branin_experiment()
         with self.assertRaisesRegex(UnsupportedError, "support robust"):
-            Models.GPEI(
+            Models.LEGACY_BOTORCH(
                 experiment=exp,
                 data=exp.fetch_data(),
             ).gen(n=1)
diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py
index 905f6258454..dd63d35349f 100644
--- a/ax/service/ax_client.py
+++ b/ax/service/ax_client.py
@@ -157,7 +157,7 @@ class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase):
         torch_device: An optional `torch.device` object, used to choose the device
             used for generating new points for trials. Works only for torch-based
-            models, such as GPEI. Ignored if a `generation_strategy` is passed in
+            models, such as MBM. Ignored if a `generation_strategy` is passed in
             manually. To specify the device for a custom `generation_strategy`,
             pass in `torch_device` as part of `model_kwargs`. See
             https://ax.dev/tutorials/generation_strategy.html for a tutorial on
             generation strategies.
@@ -1119,8 +1119,7 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig:
            logger.info(
                f"Model {self.generation_strategy.model} does not implement "
                "`feature_importances`, so it cannot be used to generate "
-                "this plot. Only certain models, specifically GPEI, implement "
-                "feature importances."
+                "this plot. Only certain models implement feature importances."
            )
            raise ValueError(
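
Reviewer note (illustrative, not part of the patch): a minimal sketch of the intended behavior of the new deprecation shim, assuming `_deprecated_model_with_warning` emits a deprecation warning and returns the replacement enum member, as the `("GPEI", Models.BOTORCH_MODULAR)` case added to `test_deprecated_models` implies.

    # Hypothetical usage; mirrors the new test_deprecated_models case. Accessing
    # the removed GPEI member now goes through the classmethod property and is
    # expected to resolve to the modular model after warning.
    from ax.modelbridge.registry import Models

    model = Models.GPEI  # expected to emit a deprecation warning
    assert model is Models.BOTORCH_MODULAR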