From 3847400c7dccbce9c04017b08e8cc5f078671dd5 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 4 Dec 2024 12:13:28 -0800 Subject: [PATCH] Update the default set of Ax transforms used in MBM (#3144) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3144 This diff adds a new set of transforms (to be used by default in MBM based models) that replaces `IntToFloat` with `LogIntToFloat` and `UnitX` with `Normalize`. The new set of transforms avoids using continuous relaxation for non log-scale discrete parameters, which consistently delivers improved optimization performance on mixed integer benchmark problems. This diff only updates single task model registry entries. I will follow it up with additional diffs to propagate the changes in multiple stages. Reviewed By: dme65 Differential Revision: D66724547 fbshipit-source-id: 90fe3fb97048588138eb946dbe42a38e6b7481de --- ax/modelbridge/registry.py | 32 +++++++++++++++++++-- ax/modelbridge/tests/test_dispatch_utils.py | 5 ++-- ax/service/tests/test_ax_client.py | 12 ++++++-- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index 16fe55cf64b..5d8f8288a4b 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -40,7 +40,7 @@ from ax.modelbridge.transforms.derelativize import Derelativize from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice -from ax.modelbridge.transforms.int_to_float import IntToFloat +from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat from ax.modelbridge.transforms.ivw import IVW from ax.modelbridge.transforms.log import Log from ax.modelbridge.transforms.logit import Logit @@ -76,6 +76,10 @@ logger: Logger = get_logger(__name__) +# This set of transforms uses continuous relaxation to handle discrete parameters. 
+# All candidate generation is done in the continuous space, and the generated +# candidates are rounded to fit the original search space. This can be +# suboptimal when there are discrete parameters with a small number of options. Cont_X_trans: list[type[Transform]] = [ FillMissingParameters, RemoveFixed, @@ -87,8 +91,30 @@ UnitX, ] +# This is a modification of Cont_X_trans that aims to avoid continuous relaxation +# where possible. It replaces IntToFloat with LogIntToFloat, which only transforms +# log-scale integer parameters, which still use continuous relaxation. Other discrete +# parameters will remain discrete. When used with MBM, a Normalize input transform +# will be added to replace the UnitX transform. This setup facilitates the use of +# optimize_acqf_mixed_alternating, which is a more efficient acquisition function +# optimizer for mixed discrete/continuous problems. +MBM_X_trans: list[type[Transform]] = [ + FillMissingParameters, + RemoveFixed, + OrderedChoiceToIntegerRange, + OneHot, + LogIntToFloat, + Log, + Logit, +] + + Discrete_X_trans: list[type[Transform]] = [IntRangeToChoice] +# This is a modification of Cont_X_trans that replaces OneHot and +# OrderedChoiceToIntegerRange with ChoiceToNumericChoice. This results in retaining +# all choice parameters as discrete, while using continuous relaxation for integer +# valued RangeParameters. 
Mixed_transforms: list[type[Transform]] = [ FillMissingParameters, RemoveFixed, @@ -155,7 +181,7 @@ class ModelSetup(NamedTuple): "BoTorch": ModelSetup( bridge_class=TorchModelBridge, model_class=ModularBoTorchModel, - transforms=Cont_X_trans + Y_trans, + transforms=MBM_X_trans + Y_trans, standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS, ), "Legacy_GPEI": ModelSetup( @@ -204,7 +230,7 @@ class ModelSetup(NamedTuple): "SAASBO": ModelSetup( bridge_class=TorchModelBridge, model_class=ModularBoTorchModel, - transforms=Cont_X_trans + Y_trans, + transforms=MBM_X_trans + Y_trans, default_model_kwargs={ "surrogate_spec": SurrogateSpec( botorch_model_class=SaasFullyBayesianSingleTaskGP diff --git a/ax/modelbridge/tests/test_dispatch_utils.py b/ax/modelbridge/tests/test_dispatch_utils.py index 69827dd1bcc..2e6f8226927 100644 --- a/ax/modelbridge/tests/test_dispatch_utils.py +++ b/ax/modelbridge/tests/test_dispatch_utils.py @@ -19,8 +19,7 @@ choose_generation_strategy, DEFAULT_BAYESIAN_PARALLELISM, ) -from ax.modelbridge.factory import Cont_X_trans, Y_trans -from ax.modelbridge.registry import Mixed_transforms, Models +from ax.modelbridge.registry import MBM_X_trans, Mixed_transforms, Models, Y_trans from ax.modelbridge.transforms.log_y import LogY from ax.modelbridge.transforms.winsorize import Winsorize from ax.models.winsorization_config import WinsorizationConfig @@ -44,7 +43,7 @@ class TestDispatchUtils(TestCase): @mock_botorch_optimize def test_choose_generation_strategy(self) -> None: - expected_transforms = [Winsorize] + Cont_X_trans + Y_trans + expected_transforms = [Winsorize] + MBM_X_trans + Y_trans expected_transform_configs = { "Winsorize": {"derelativize_with_raw_status_quo": False}, "Derelativize": {"use_raw_status_quo": False}, diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index efa67476ed7..4cf6825f326 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -57,9 +57,8 @@ ) from 
ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.random import RandomModelBridge -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Cont_X_trans, Models from ax.runners.synthetic import SyntheticRunner - from ax.service.ax_client import AxClient, ObjectiveProperties from ax.service.utils.best_point import ( get_best_parameters_from_model_predictions_with_trial_index, @@ -220,7 +219,14 @@ def get_client_with_simple_discrete_moo_problem( gs = GenerationStrategy( steps=[ GenerationStep(model=Models.SOBOL, num_trials=3), - GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-1), + GenerationStep( + model=Models.BOTORCH_MODULAR, + num_trials=-1, + model_kwargs={ + # To avoid search space exhausted errors. + "transforms": Cont_X_trans, + }, + ), ] )