From 905e2f2b313e6fa42ad10e77d66965071708e4a1 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 5 Mar 2025 20:50:27 -0800 Subject: [PATCH 1/3] Rename choose_generation_strategy to choose_generation_strategy_legacy Summary: We have a new choose_generation_strategy in ax.preview.modelbridge that will be migrated out of preview in the next diff. Rename existing choose_generation_strategy to avoid name conflict Differential Revision: D70647194 --- .../plotly/tests/test_predicted_effects.py | 8 +- ax/generation_strategy/dispatch_utils.py | 2 +- .../tests/test_dispatch_utils.py | 132 ++++++++++-------- ax/runners/tests/test_torchx.py | 8 +- ax/service/ax_client.py | 4 +- ax/service/managed_loop.py | 4 +- ax/service/tests/scheduler_test_utils.py | 5 +- ax/service/tests/test_best_point_utils.py | 8 +- ax/service/tests/test_scheduler.py | 5 +- ax/storage/sqa_store/tests/test_sqa_store.py | 12 +- ax/utils/testing/modeling_stubs.py | 4 +- 11 files changed, 103 insertions(+), 89 deletions(-) diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py index 83cb20ffbe8..0dc9c9df964 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -15,7 +15,7 @@ from ax.core.observation import ObservationFeatures from ax.core.trial import Trial from ax.exceptions.core import UserInputError -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.modelbridge.prediction_utils import predict_at_point from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase @@ -50,7 +50,7 @@ def test_compute_for_requires_a_gs(self) -> None: def test_compute_for_requires_trials(self) -> None: analysis = PredictedEffectsPlot(metric_name="branin") experiment = get_branin_experiment() - generation_strategy = choose_generation_strategy( + generation_strategy = choose_generation_strategy_legacy( search_space=experiment.search_space, experiment=experiment, ) @@ -62,7 +62,7 @@ def test_compute_for_requires_trials(self) -> None: def test_compute_for_requires_a_model_that_predicts(self) -> None: analysis = PredictedEffectsPlot(metric_name="branin") experiment = get_branin_experiment(with_batch=True, with_completed_batch=True) - generation_strategy = choose_generation_strategy( + generation_strategy = choose_generation_strategy_legacy( search_space=experiment.search_space, experiment=experiment, ) @@ -311,7 +311,7 @@ def test_it_does_not_plot_abandoned_trials(self) -> None: def test_it_works_for_non_batch_experiments(self) -> None: # GIVEN an experiment with the default generation strategy experiment = get_branin_experiment(with_batch=False) - generation_strategy = choose_generation_strategy( + generation_strategy = choose_generation_strategy_legacy( search_space=experiment.search_space, experiment=experiment, ) diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index 9ce102785de..402c4d6f82d 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -294,7 +294,7 @@ def calculate_num_initialization_trials( return max(ret, 5) -def choose_generation_strategy( +def choose_generation_strategy_legacy( search_space: SearchSpace, *, use_batch_trials: bool = False, diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index 9e3bd2bc2af..38e274a189f 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -16,7 +16,7 @@ from ax.generation_strategy.dispatch_utils import ( _make_botorch_step, calculate_num_initialization_trials, - choose_generation_strategy, + choose_generation_strategy_legacy, DEFAULT_BAYESIAN_PARALLELISM, ) from ax.modelbridge.registry import Generators, MBM_X_trans, Mixed_transforms, Y_trans @@ -43,14 +43,14 @@ class TestDispatchUtils(TestCase): """Tests that dispatching utilities correctly select generation strategies.""" @mock_botorch_optimize - def test_choose_generation_strategy(self) -> None: + def test_choose_generation_strategy_legacy(self) -> None: expected_transforms = [Winsorize] + MBM_X_trans + Y_trans expected_transform_configs = { "Winsorize": {"derelativize_with_raw_status_quo": False}, "Derelativize": {"use_raw_status_quo": False}, } with self.subTest("GPEI"): - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space() ) self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) @@ -64,7 +64,7 @@ def test_choose_generation_strategy(self) -> None: } self.assertEqual(sobol_gpei._steps[1].model_kwargs, expected_model_kwargs) device = torch.device("cpu") - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), verbose=True, torch_device=device, @@ -75,7 +75,7 @@ def test_choose_generation_strategy(self) -> None: generation_strategy=sobol_gpei ) with self.subTest("max initialization trials"): - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), max_initialization_trials=2, ) @@ -83,7 +83,7 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(sobol_gpei._steps[0].num_trials, 2) self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("min sobol trials"): - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), min_sobol_trials_observed=1, ) @@ -91,7 +91,7 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(sobol_gpei._steps[0].min_trials_observed, 1) self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("num_initialization_trials > max_initialization_trials"): - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), max_initialization_trials=2, num_initialization_trials=3, @@ -100,7 +100,7 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(sobol_gpei._steps[0].num_trials, 3) self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("num_initialization_trials > max_initialization_trials"): - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), max_initialization_trials=2, num_initialization_trials=3, @@ -112,7 +112,7 @@ def test_choose_generation_strategy(self) -> None: optimization_config = MultiObjectiveOptimizationConfig( objective=MultiObjective(objectives=[]) ) - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), optimization_config=optimization_config, ) @@ -131,22 +131,22 @@ def test_choose_generation_strategy(self) -> None: ) self.assertGreater(len(model_kwargs["transforms"]), 0) with self.subTest("Sobol (we can try every option)"): - sobol = choose_generation_strategy( + sobol = choose_generation_strategy_legacy( search_space=get_factorial_search_space(), num_trials=1000 ) self.assertEqual(sobol._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol._steps), 1) with self.subTest("Sobol (because of too many categories)"): - sobol_large = choose_generation_strategy( + sobol_large = choose_generation_strategy_legacy( search_space=get_large_factorial_search_space(), verbose=True ) self.assertEqual(sobol_large._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol_large._steps), 1) with self.subTest("Sobol (because of too many categories) with saasbo"): with self.assertLogs( - choose_generation_strategy.__module__, logging.WARNING + choose_generation_strategy_legacy.__module__, logging.WARNING ) as logger: - sobol_large = choose_generation_strategy( + sobol_large = choose_generation_strategy_legacy( search_space=get_large_factorial_search_space(), verbose=True, use_saasbo=True, @@ -162,7 +162,7 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(len(sobol_large._steps), 1) with self.subTest("SOBOL due to too many unordered choices"): # Search space with more unordered choices than ordered parameters. - sobol = choose_generation_strategy( + sobol = choose_generation_strategy_legacy( search_space=get_search_space_with_choice_parameters( num_ordered_parameters=5, num_unordered_choices=100, @@ -172,7 +172,7 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(len(sobol._steps), 1) with self.subTest("GPEI with more unordered choices than ordered parameters"): # Search space with more unordered choices than ordered parameters. - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_search_space_with_choice_parameters( num_ordered_parameters=5, num_unordered_choices=10, @@ -180,7 +180,7 @@ def test_choose_generation_strategy(self) -> None: ) self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("GPEI despite many unordered 2-value parameters"): - gs = choose_generation_strategy( + gs = choose_generation_strategy_legacy( search_space=get_large_factorial_search_space( num_levels=2, num_parameters=10 ), @@ -188,13 +188,13 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(gs._steps[0].model, Generators.SOBOL) self.assertEqual(gs._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("GPEI-Batched"): - sobol_gpei_batched = choose_generation_strategy( + sobol_gpei_batched = choose_generation_strategy_legacy( search_space=get_branin_search_space(), use_batch_trials=True, ) self.assertEqual(sobol_gpei_batched._steps[0].num_trials, 1) with self.subTest("BO_MIXED (purely categorical)"): - bo_mixed = choose_generation_strategy( + bo_mixed = choose_generation_strategy_legacy( search_space=get_factorial_search_space() ) self.assertEqual(bo_mixed._steps[0].model, Generators.SOBOL) @@ -211,7 +211,7 @@ def test_choose_generation_strategy(self) -> None: ss = get_branin_search_space(with_choice_parameter=True) # pyre-fixme[16]: `Parameter` has no attribute `_is_ordered`. ss.parameters["x2"]._is_ordered = False - bo_mixed_2 = choose_generation_strategy(search_space=ss) + bo_mixed_2 = choose_generation_strategy_legacy(search_space=ss) self.assertEqual(bo_mixed_2._steps[0].model, Generators.SOBOL) self.assertEqual(bo_mixed_2._steps[0].num_trials, 5) self.assertEqual(bo_mixed_2._steps[1].model, Generators.BO_MIXED) @@ -228,7 +228,7 @@ def test_choose_generation_strategy(self) -> None: optimization_config = MultiObjectiveOptimizationConfig( objective=MultiObjective(objectives=[]) ) - moo_mixed = choose_generation_strategy( + moo_mixed = choose_generation_strategy_legacy( search_space=search_space, optimization_config=optimization_config ) self.assertEqual(moo_mixed._steps[0].model, Generators.SOBOL) @@ -246,7 +246,7 @@ def test_choose_generation_strategy(self) -> None: ) self.assertGreater(len(model_kwargs["transforms"]), 0) with self.subTest("SAASBO"): - sobol_fullybayesian = choose_generation_strategy( + sobol_fullybayesian = choose_generation_strategy_legacy( search_space=get_branin_search_space(), use_batch_trials=True, num_initialization_trials=3, @@ -256,7 +256,7 @@ def test_choose_generation_strategy(self) -> None: self.assertEqual(sobol_fullybayesian._steps[0].num_trials, 3) self.assertEqual(sobol_fullybayesian._steps[1].model, Generators.SAASBO) with self.subTest("SAASBO MOO"): - sobol_fullybayesianmoo = choose_generation_strategy( + sobol_fullybayesianmoo = choose_generation_strategy_legacy( search_space=get_branin_search_space(), use_batch_trials=True, num_initialization_trials=3, @@ -272,7 +272,7 @@ def test_choose_generation_strategy(self) -> None: Generators.SAASBO, ) with self.subTest("SAASBO"): - sobol_fullybayesian_large = choose_generation_strategy( + sobol_fullybayesian_large = choose_generation_strategy_legacy( search_space=get_large_ordinal_search_space( n_ordinal_choice_parameters=5, n_continuous_range_parameters=10 ), @@ -291,7 +291,7 @@ def test_choose_generation_strategy(self) -> None: for _, param in ss.parameters.items(): param._is_ordered = True # 2 * len(ss.parameters) init trials are performed if num_trials is large - gs_12_init_trials = choose_generation_strategy( + gs_12_init_trials = choose_generation_strategy_legacy( search_space=ss, num_trials=100 ) self.assertEqual(gs_12_init_trials._steps[0].model, Generators.SOBOL) @@ -300,7 +300,9 @@ def test_choose_generation_strategy(self) -> None: gs_12_init_trials._steps[1].model, Generators.BOTORCH_MODULAR ) # at least 5 initialization trials are performed - gs_5_init_trials = choose_generation_strategy(search_space=ss, num_trials=0) + gs_5_init_trials = choose_generation_strategy_legacy( + search_space=ss, num_trials=0 + ) self.assertEqual(gs_5_init_trials._steps[0].model, Generators.SOBOL) self.assertEqual(gs_5_init_trials._steps[0].num_trials, 5) self.assertEqual( @@ -308,7 +310,7 @@ def test_choose_generation_strategy(self) -> None: ) # avoid spending >20% of budget on initialization trials if there are # more than 5 initialization trials - gs_6_init_trials = choose_generation_strategy( + gs_6_init_trials = choose_generation_strategy_legacy( search_space=ss, num_trials=30 ) self.assertEqual(gs_6_init_trials._steps[0].model, Generators.SOBOL) @@ -317,11 +319,11 @@ def test_choose_generation_strategy(self) -> None: gs_6_init_trials._steps[1].model, Generators.BOTORCH_MODULAR ) with self.subTest("suggested_model_override"): - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space() ) self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) - sobol_saasbo = choose_generation_strategy( + sobol_saasbo = choose_generation_strategy_legacy( search_space=get_branin_search_space(), suggested_model_override=Generators.SAASBO, ) @@ -351,7 +353,7 @@ def test_make_botorch_step_extra(self) -> None: def test_disable_progbar(self) -> None: for disable_progbar in (True, False): with self.subTest(str(disable_progbar)): - sobol_saasbo = choose_generation_strategy( + sobol_saasbo = choose_generation_strategy_legacy( search_space=get_branin_search_space(), disable_progbar=disable_progbar, use_saasbo=True, @@ -381,7 +383,7 @@ def test_disable_progbar(self) -> None: def test_disable_progbar_for_non_saasbo_discards_the_model_kwarg(self) -> None: for disable_progbar in (True, False): with self.subTest(str(disable_progbar)): - gp_saasbo = choose_generation_strategy( + gp_saasbo = choose_generation_strategy_legacy( search_space=get_branin_search_space(), disable_progbar=disable_progbar, use_saasbo=False, @@ -402,7 +404,7 @@ def test_disable_progbar_for_non_saasbo_discards_the_model_kwarg(self) -> None: ) def test_setting_random_seed(self) -> None: - sobol = choose_generation_strategy( + sobol = choose_generation_strategy_legacy( search_space=get_factorial_search_space(), random_seed=9 ) sobol.gen(experiment=get_experiment(), n=1) @@ -413,9 +415,9 @@ def test_setting_random_seed(self) -> None: with self.subTest("warns if use_saasbo is true"): with self.assertLogs( - choose_generation_strategy.__module__, logging.WARNING + choose_generation_strategy_legacy.__module__, logging.WARNING ) as logger: - sobol = choose_generation_strategy( + sobol = choose_generation_strategy_legacy( search_space=get_factorial_search_space(), random_seed=9, use_saasbo=True, @@ -430,14 +432,14 @@ def test_setting_random_seed(self) -> None: def test_enforce_sequential_optimization(self) -> None: with self.subTest("True"): - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space() ) self.assertEqual(sobol_gpei._steps[0].num_trials, 5) self.assertTrue(sobol_gpei._steps[0].enforce_num_trials) self.assertIsNotNone(sobol_gpei._steps[1].max_parallelism) with self.subTest("False"): - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, ) @@ -446,9 +448,9 @@ def test_enforce_sequential_optimization(self) -> None: self.assertIsNone(sobol_gpei._steps[1].max_parallelism) with self.subTest("False and max_parallelism_override"): with self.assertLogs( - choose_generation_strategy.__module__, logging.INFO + choose_generation_strategy_legacy.__module__, logging.INFO ) as logger: - choose_generation_strategy( + choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, max_parallelism_override=5, @@ -462,9 +464,9 @@ def test_enforce_sequential_optimization(self) -> None: ) with self.subTest("False and max_parallelism_cap"): with self.assertLogs( - choose_generation_strategy.__module__, logging.INFO + choose_generation_strategy_legacy.__module__, logging.INFO ) as logger: - choose_generation_strategy( + choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, max_parallelism_cap=5, @@ -484,7 +486,7 @@ def test_enforce_sequential_optimization(self) -> None: "`max_parallelism_cap`." ), ): - choose_generation_strategy( + choose_generation_strategy_legacy( search_space=get_branin_search_space(), enforce_sequential_optimization=False, max_parallelism_override=5, @@ -492,13 +494,13 @@ def test_enforce_sequential_optimization(self) -> None: ) def test_max_parallelism_override(self) -> None: - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), max_parallelism_override=10 ) self.assertTrue(all(s.max_parallelism == 10 for s in sobol_gpei._steps)) def test_winsorization(self) -> None: - winsorized = choose_generation_strategy( + winsorized = choose_generation_strategy_legacy( search_space=get_branin_search_space(), winsorization_config=WinsorizationConfig(upper_quantile_margin=2), ) @@ -518,7 +520,7 @@ def test_winsorization(self) -> None: self.assertIn("Derelativize", tc) self.assertDictEqual(tc["Derelativize"], {"use_raw_status_quo": False}) - winsorized = choose_generation_strategy( + winsorized = choose_generation_strategy_legacy( search_space=get_branin_search_space(), derelativize_with_raw_status_quo=True, ) @@ -539,7 +541,7 @@ def test_winsorization(self) -> None: def test_no_winzorization_wins(self) -> None: with warnings.catch_warnings(record=True) as w: - unwinsorized = choose_generation_strategy( + unwinsorized = choose_generation_strategy_legacy( search_space=get_branin_search_space(), winsorization_config=WinsorizationConfig(upper_quantile_margin=2), no_winsorization=True, @@ -557,18 +559,20 @@ def test_num_trials(self) -> None: with self.subTest( "with budget that is lower than exhaustive, BayesOpt is used" ): - sobol_gpei = choose_generation_strategy(search_space=ss, num_trials=23) + sobol_gpei = choose_generation_strategy_legacy( + search_space=ss, num_trials=23 + ) self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_gpei._steps[1].model, Generators.BO_MIXED) with self.subTest("with budget that is exhaustive, Sobol is used"): - sobol = choose_generation_strategy(search_space=ss, num_trials=36) + sobol = choose_generation_strategy_legacy(search_space=ss, num_trials=36) self.assertEqual(sobol._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol._steps), 1) with self.subTest("with budget that is exhaustive and use_saasbo, it warns"): with self.assertLogs( - choose_generation_strategy.__module__, logging.WARNING + choose_generation_strategy_legacy.__module__, logging.WARNING ) as logger: - sobol = choose_generation_strategy( + sobol = choose_generation_strategy_legacy( search_space=ss, num_trials=36, use_saasbo=True, @@ -584,13 +588,13 @@ def test_num_trials(self) -> None: self.assertEqual(len(sobol._steps), 1) def test_use_batch_trials(self) -> None: - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), use_batch_trials=True ) self.assertEqual(sobol_gpei._steps[0].num_trials, 1) def test_fixed_num_initialization_trials(self) -> None: - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), use_batch_trials=True, num_initialization_trials=3, @@ -599,13 +603,15 @@ def test_fixed_num_initialization_trials(self) -> None: def test_max_parallelism_adjustments(self) -> None: # No adjustment. - sobol_gpei = choose_generation_strategy(search_space=get_branin_search_space()) + sobol_gpei = choose_generation_strategy_legacy( + search_space=get_branin_search_space() + ) self.assertIsNone(sobol_gpei._steps[0].max_parallelism) self.assertEqual( sobol_gpei._steps[1].max_parallelism, DEFAULT_BAYESIAN_PARALLELISM ) # Impose a cap of 1 on max parallelism for all steps. - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), max_parallelism_cap=1 ) self.assertEqual( @@ -617,20 +623,20 @@ def test_max_parallelism_adjustments(self) -> None: 1, ) # Disable enforcing max parallelism for all steps. - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), max_parallelism_override=-1 ) self.assertIsNone(sobol_gpei._steps[0].max_parallelism) self.assertIsNone(sobol_gpei._steps[1].max_parallelism) # Override max parallelism for all steps. - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), max_parallelism_override=10 ) self.assertEqual(sobol_gpei._steps[0].max_parallelism, 10) self.assertEqual(sobol_gpei._steps[1].max_parallelism, 10) def test_set_should_deduplicate(self) -> None: - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), use_batch_trials=True, num_initialization_trials=3, @@ -638,7 +644,7 @@ def test_set_should_deduplicate(self) -> None: self.assertListEqual( [s.should_deduplicate for s in sobol_gpei._steps], [False] * 2 ) - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), use_batch_trials=True, num_initialization_trials=3, @@ -650,19 +656,23 @@ def test_set_should_deduplicate(self) -> None: def test_setting_experiment_attribute(self) -> None: exp = get_experiment() - gs = choose_generation_strategy(search_space=exp.search_space, experiment=exp) + gs = choose_generation_strategy_legacy( + search_space=exp.search_space, experiment=exp + ) self.assertEqual(gs._experiment, exp) def test_setting_num_completed_initialization_trials(self) -> None: default_initialization_num_trials = 5 - sobol_gpei = choose_generation_strategy(search_space=get_branin_search_space()) + sobol_gpei = choose_generation_strategy_legacy( + search_space=get_branin_search_space() + ) self.assertEqual( sobol_gpei._steps[0].num_trials, default_initialization_num_trials ) num_completed_initialization_trials = 2 - sobol_gpei = choose_generation_strategy( + sobol_gpei = choose_generation_strategy_legacy( search_space=get_branin_search_space(), num_completed_initialization_trials=num_completed_initialization_trials, ) @@ -737,7 +747,7 @@ def test_calculate_num_initialization_trials(self) -> None: def test_jit_compile(self) -> None: for jit_compile in (True, False): with self.subTest(str(jit_compile)): - sobol_saasbo = choose_generation_strategy( + sobol_saasbo = choose_generation_strategy_legacy( search_space=get_branin_search_space(), jit_compile=jit_compile, use_saasbo=True, @@ -767,7 +777,7 @@ def test_jit_compile(self) -> None: def test_jit_compile_for_non_saasbo_discards_the_model_kwarg(self) -> None: for jit_compile in (True, False): with self.subTest(str(jit_compile)): - gp_saasbo = choose_generation_strategy( + gp_saasbo = choose_generation_strategy_legacy( search_space=get_branin_search_space(), jit_compile=jit_compile, use_saasbo=False, diff --git a/ax/runners/tests/test_torchx.py b/ax/runners/tests/test_torchx.py index 00e8642e67b..d1c35ddc9ef 100644 --- a/ax/runners/tests/test_torchx.py +++ b/ax/runners/tests/test_torchx.py @@ -21,7 +21,7 @@ SearchSpace, ) -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.metrics.torchx import TorchXMetric from ax.runners.torchx import TorchXRunner from ax.service.scheduler import FailureRateExceededError, Scheduler, SchedulerOptions @@ -87,7 +87,7 @@ def test_run_experiment_locally(self) -> None: scheduler = Scheduler( experiment=experiment, generation_strategy=( - choose_generation_strategy( + choose_generation_strategy_legacy( search_space=experiment.search_space, ) ), @@ -117,7 +117,7 @@ def test_stop_trials(self) -> None: scheduler = Scheduler( experiment=experiment, generation_strategy=( - choose_generation_strategy( + choose_generation_strategy_legacy( search_space=experiment.search_space, ) ), @@ -155,7 +155,7 @@ def test_run_experiment_locally_in_batches(self) -> None: scheduler = Scheduler( experiment=experiment, generation_strategy=( - choose_generation_strategy( + choose_generation_strategy_legacy( search_space=experiment.search_space, max_parallelism_cap=parallelism, ) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index c8aa8d13715..a39b6aa21c7 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -55,7 +55,7 @@ UserInputError, ) from ax.exceptions.generation_strategy import MaxParallelismReachedException -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy from ax.global_stopping.strategies.improvement import constraint_satisfaction @@ -1768,7 +1768,7 @@ def _set_generation_strategy( "enforce_sequential_optimization", self._enforce_sequential_optimization ) if self._generation_strategy is None: - self._generation_strategy = choose_generation_strategy( + self._generation_strategy = choose_generation_strategy_legacy( search_space=self.experiment.search_space, optimization_config=self.experiment.optimization_config, enforce_sequential_optimization=enforce_sequential_optimization, diff --git a/ax/service/managed_loop.py b/ax/service/managed_loop.py index c5b127e3405..94c6c152fa5 100644 --- a/ax/service/managed_loop.py +++ b/ax/service/managed_loop.py @@ -27,7 +27,7 @@ from ax.core.utils import get_pending_observation_features from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION from ax.exceptions.core import SearchSpaceExhausted, UserInputError -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter from ax.service.utils.best_point import ( @@ -75,7 +75,7 @@ def __init__( self.experiment = experiment if generation_strategy is None: # pyre-fixme[4]: Attribute must be annotated. - self.generation_strategy = choose_generation_strategy( + self.generation_strategy = choose_generation_strategy_legacy( search_space=experiment.search_space, use_batch_trials=self.arms_per_trial > 1, random_seed=self.random_seed, diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 6e76820b92e..e4c243f555b 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the @@ -43,7 +44,7 @@ UserInputError, ) from ax.exceptions.generation_strategy import AxGenerationException -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_strategy import ( GenerationStep, GenerationStrategy, @@ -352,7 +353,7 @@ def setUp(self) -> None: ), name="branin_experiment_no_impl_runner_or_metrics", ) - self.sobol_MBM_GS = choose_generation_strategy( + self.sobol_MBM_GS = choose_generation_strategy_legacy( search_space=get_branin_search_space() ) self.two_sobol_steps_GS = GenerationStrategy( # Contrived GS to ensure diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index fe21340ce8b..6bfd43a08e5 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -24,7 +24,7 @@ from ax.core.outcome_constraint import OutcomeConstraint from ax.core.types import ComparisonOp from ax.exceptions.core import UserInputError -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.modelbridge.cross_validation import AssessModelFitResult from ax.modelbridge.registry import Generators from ax.modelbridge.torch import TorchAdapter @@ -65,7 +65,7 @@ class TestBestPointUtils(TestCase): @mock_botorch_optimize def test_best_from_model_prediction(self) -> None: exp = get_branin_experiment() - gs = choose_generation_strategy( + gs = choose_generation_strategy_legacy( search_space=exp.search_space, num_initialization_trials=3, suggested_model_override=Generators.BOTORCH_MODULAR, @@ -255,7 +255,7 @@ def test_best_raw_objective_point_unsatisfiable_relative(self) -> None: def test_best_raw_objective_point_scalarized(self) -> None: exp = get_branin_experiment() - gs = choose_generation_strategy(search_space=exp.search_space) + gs = choose_generation_strategy_legacy(search_space=exp.search_space) exp.optimization_config = OptimizationConfig( ScalarizedObjective(metrics=[get_branin_metric()], minimize=True) ) @@ -277,7 +277,7 @@ def test_best_raw_objective_point_scalarized(self) -> None: def test_best_raw_objective_point_scalarized_multi(self) -> None: exp = get_branin_experiment() - gs = choose_generation_strategy(search_space=exp.search_space) + gs = choose_generation_strategy_legacy(search_space=exp.search_space) exp.optimization_config = OptimizationConfig( ScalarizedObjective( metrics=[get_branin_metric(), get_branin_metric(lower_is_better=False)], diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index c7d3941889c..35ae951b433 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. + # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. @@ -9,7 +10,7 @@ from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_strategy import ( GenerationStep, GenerationStrategy, @@ -103,7 +104,7 @@ def setUp(self) -> None: default_runner=None, name="branin_experiment_no_impl_runner_or_metrics", ) - self.sobol_MBM_GS = choose_generation_strategy( + self.sobol_MBM_GS = choose_generation_strategy_legacy( search_space=get_branin_search_space() ) self.two_sobol_steps_GS = GenerationStrategy( # Contrived GS to ensure diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index e6d5fed32a9..1f9e38bfa44 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -34,7 +34,7 @@ from ax.core.types import ComparisonOp from ax.exceptions.core import ObjectNotFoundError from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.metrics.branin import BraninMetric from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec @@ -1867,7 +1867,7 @@ def test_GeneratorRunGenMetadata(self) -> None: def test_UpdateGenerationStrategyIncrementally(self) -> None: experiment = get_branin_experiment() - generation_strategy = choose_generation_strategy(experiment.search_space) + generation_strategy = choose_generation_strategy_legacy(experiment.search_space) save_experiment(experiment=experiment) save_generation_strategy(generation_strategy=generation_strategy) @@ -2258,7 +2258,7 @@ def test_AnalysisCard(self) -> None: def test_delete_generation_strategy(self) -> None: # GIVEN an experiment with a generation strategy experiment = get_branin_experiment() - generation_strategy = choose_generation_strategy(experiment.search_space) + generation_strategy = choose_generation_strategy_legacy(experiment.search_space) generation_strategy.experiment = experiment save_experiment(experiment) save_generation_strategy(generation_strategy=generation_strategy) @@ -2266,7 +2266,9 @@ def test_delete_generation_strategy(self) -> None: # AND GIVEN another experiment with a generation strategy experiment2 = get_branin_experiment() experiment2.name = "experiment2" - generation_strategy2 = choose_generation_strategy(experiment2.search_space) + generation_strategy2 = choose_generation_strategy_legacy( + experiment2.search_space + ) generation_strategy2.experiment = experiment2 save_experiment(experiment2) save_generation_strategy(generation_strategy=generation_strategy2) @@ -2288,7 +2290,7 @@ def test_delete_generation_strategy(self) -> None: def test_delete_generation_strategy_max_gs_to_delete(self) -> None: # GIVEN an experiment with a generation strategy experiment = get_branin_experiment() - generation_strategy = choose_generation_strategy(experiment.search_space) + generation_strategy = choose_generation_strategy_legacy(experiment.search_space) generation_strategy.experiment = experiment save_experiment(experiment) save_generation_strategy(generation_strategy=generation_strategy) diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index 710ce08ab37..2ae51d33318 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -21,7 +21,7 @@ ReductionCriterion, SingleDiagnosticBestModelSelector, ) -from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_node import GenerationNode from ax.generation_strategy.generation_node_input_constructors import ( @@ -204,7 +204,7 @@ def get_generation_strategy( get_sobol ) else: - gs = choose_generation_strategy( + gs = choose_generation_strategy_legacy( search_space=get_search_space(), should_deduplicate=True ) if with_callable_model_kwarg: From bdebffba2a9b18249ac8f99aebdc6fc30abb361b Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 5 Mar 2025 20:50:27 -0800 Subject: [PATCH 2/3] Move new dispatch utils out of preview Summary: As titled. Also refactored slightly such that we wont be importing from ax.api anywhere in the codebase. To keep our module structure easy to reason about it is very important to keep the ax.api module at the root of our dep tree. Differential Revision: D70647193 --- ax/generation_strategy/dispatch_utils.py | 203 ++++++++++++++++++ .../tests/test_dispatch_utils.py | 167 ++++++++++++++ ax/preview/api/client.py | 18 +- ax/preview/api/configs.py | 27 +-- ax/preview/modelbridge/__init__.py | 5 - ax/preview/modelbridge/dispatch_utils.py | 171 --------------- .../tests/test_preview_dispatch_utils.py | 189 ---------------- 7 files changed, 388 insertions(+), 392 deletions(-) delete mode 100644 ax/preview/modelbridge/__init__.py delete mode 100644 ax/preview/modelbridge/dispatch_utils.py delete mode 100644 ax/preview/modelbridge/tests/test_preview_dispatch_utils.py diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index 402c4d6f82d..5fe7c2c6cd4 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -8,6 +8,7 @@ import logging import warnings +from enum import Enum from math import ceil from typing import Any, cast @@ -16,10 +17,15 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace +from ax.core.trial_status import TrialStatus +from ax.exceptions.core import UnsupportedError from ax.generation_strategy.generation_strategy import ( + GenerationNode, GenerationStep, GenerationStrategy, ) +from ax.generation_strategy.model_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import MinTrials from ax.modelbridge.registry import ( Generators, MODEL_KEY_TO_MODEL_SETUP, @@ -30,10 +36,13 @@ from ax.models.torch.botorch_modular.model import ( BoTorchGenerator as ModularBoTorchGenerator, ) +from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from ax.models.types import TConfig from ax.models.winsorization_config import WinsorizationConfig from ax.utils.common.deprecation import _validate_force_random_search from ax.utils.common.logger import get_logger +from botorch.models.transforms.input import Normalize, Warp +from gpytorch.kernels.linear_kernel import LinearKernel from pyre_extensions import none_throws @@ -54,6 +63,200 @@ ) +class GenerationMethod(Enum): + """An enum to specify the desired candidate generation method for the experiment. + This is used in ``GenerationStrategyConfig``, along with the properties of the + experiment, to determine the generation strategy to use for candidate generation. + + NOTE: New options should be rarely added to this enum. This is not intended to be + a list of generation strategies for the user to choose from. Instead, this enum + should only provide high level guidance to the underlying generation strategy + dispatch logic, which is responsible for determinining the exact details. + + Available options are: + BALANCED: A balanced generation method that may utilize (per-metric) model + selection to achieve a good model accuracy. This method excludes expensive + methods, such as the fully Bayesian SAASBO model. Used by default. + FAST: A faster generation method that uses the built-in defaults from the + Modular BoTorch Model without any model selection. + RANDOM_SEARCH: Primarily intended for pure exploration experiments, this + method utilizes quasi-random Sobol sequences for candidate generation. + """ + + BALANCED = "balanced" + FAST = "fast" + RANDOM_SEARCH = "random_search" + + +def _get_sobol_node( + initialization_budget: int | None = None, + initialization_random_seed: int | None = None, + use_existing_trials_for_initialization: bool = True, + min_observed_initialization_trials: int | None = None, + allow_exceeding_initialization_budget: bool = False, +) -> GenerationNode: + """Constructs a Sobol node based on inputs from ``gs_config``. + The Sobol generator utilizes `initialization_random_seed` if specified. + + This node always transitions to "MBM", using the following transition criteria: + - MinTrials enforcing the initialization budget. + - If the initialization budget is not specified, it defaults to 5. + - The TC will not block generation if `allow_exceeding_initialization_budget` + is set to True. + - The TC is currently not restricted to any trial statuses and will + count all trials. + - `use_existing_trials_for_initialization` controls whether trials previously + attached to the experiment are counted as part of the initialization budget. + - MinTrials enforcing the minimum number of observed initialization trials. + - If `min_observed_initialization_trials` is not specified, it defaults + to `max(1, initialization_budget // 2)`. + - The TC currently only counts trials in status COMPLETED (with data attached) + as observed trials. + - `use_existing_trials_for_initialization` controls whether trials previously + attached to the experiment are counted as part of the required number of + observed initialization trials. + """ + # Set the default options. + if initialization_budget is None: + initialization_budget = 5 + if min_observed_initialization_trials is None: + min_observed_initialization_trials = max(1, initialization_budget // 2) + # Construct the transition criteria. + transition_criteria = [ + MinTrials( # This represents the initialization budget. + threshold=initialization_budget, + transition_to="MBM", + block_gen_if_met=(not allow_exceeding_initialization_budget), + block_transition_if_unmet=True, + use_all_trials_in_exp=use_existing_trials_for_initialization, + ), + MinTrials( # This represents minimum observed trials requirement. + threshold=min_observed_initialization_trials, + transition_to="MBM", + block_gen_if_met=False, + block_transition_if_unmet=True, + use_all_trials_in_exp=use_existing_trials_for_initialization, + only_in_statuses=[TrialStatus.COMPLETED], + count_only_trials_with_data=True, + ), + ] + return GenerationNode( + node_name="Sobol", + model_specs=[ + GeneratorSpec( + model_enum=Generators.SOBOL, + model_kwargs={"seed": initialization_random_seed}, + ) + ], + transition_criteria=transition_criteria, + should_deduplicate=True, + ) + + +def _get_mbm_node( + method: GenerationMethod = GenerationMethod.FAST, + torch_device: str | None = None, +) -> GenerationNode: + """Constructs an MBM node based on the method specified in ``gs_config``. + + The ``SurrogateSpec`` takes the following form for the given method: + - BALANCED: Two model configs: one with MBM defaults, the other with + linear kernel with input warping. + - FAST: An empty model config that utilizes MBM defaults. + """ + # Construct the surrogate spec. + if method == GenerationMethod.FAST: + model_configs = [ModelConfig(name="MBM defaults")] + elif method == GenerationMethod.BALANCED: + model_configs = [ + ModelConfig(name="MBM defaults"), + ModelConfig( + covar_module_class=LinearKernel, + input_transform_classes=[Warp, Normalize], + input_transform_options={"Normalize": {"center": 0.0}}, + name="LinearKernel with Warp", + ), + ] + else: + raise UnsupportedError(f"Unsupported generation method: {method}.") + + return GenerationNode( + node_name="MBM", + model_specs=[ + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, + model_kwargs={ + "surrogate_spec": SurrogateSpec(model_configs=model_configs), + "torch_device": None + if torch_device is None + else torch.device(torch_device), + }, + ) + ], + should_deduplicate=True, + ) + + +def choose_generation_strategy( + method: GenerationMethod = GenerationMethod.FAST, + # Initialization options + initialization_budget: int | None = None, + initialization_random_seed: int | None = None, + use_existing_trials_for_initialization: bool = True, + min_observed_initialization_trials: int | None = None, + allow_exceeding_initialization_budget: bool = False, + # Misc options + torch_device: str | None = None, +) -> GenerationStrategy: + """Choose a generation strategy based on the properties of the experiment + and the inputs provided in ``gs_config``. + + NOTE: The behavior of this function is subject to change. It will be updated to + produce best general purpose generation strategies based on benchmarking results. + + Args: + gs_config: A ``GenerationStrategyConfig`` object that informs + the choice of generation strategy. + + Returns: + A generation strategy. + """ + # Handle the random search case. + if method == GenerationMethod.RANDOM_SEARCH: + return GenerationStrategy( + name="QuasiRandomSearch", + nodes=[ + GenerationNode( + node_name="Sobol", + model_specs=[ + GeneratorSpec( + model_enum=Generators.SOBOL, + model_kwargs={"seed": initialization_random_seed}, + ) + ], + ) + ], + ) + # Construct the nodes. + sobol_node = _get_sobol_node( + initialization_budget=initialization_budget, + initialization_random_seed=initialization_random_seed, + use_existing_trials_for_initialization=use_existing_trials_for_initialization, + min_observed_initialization_trials=min_observed_initialization_trials, + allow_exceeding_initialization_budget=allow_exceeding_initialization_budget, + ) + # Construct the MBM node. + mbm_node = _get_mbm_node( + method=method, + torch_device=torch_device, + ) + + return GenerationStrategy( + name=f"Sobol+MBM:{method.value}", + nodes=[sobol_node, mbm_node], + ) + + def _make_sobol_step( num_trials: int = -1, min_trials_observed: int | None = None, diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index 38e274a189f..5fe3bb3ffc4 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -13,22 +13,30 @@ import torch from ax.core.objective import MultiObjective from ax.core.optimization_config import MultiObjectiveOptimizationConfig +from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus from ax.generation_strategy.dispatch_utils import ( _make_botorch_step, calculate_num_initialization_trials, + choose_generation_strategy, choose_generation_strategy_legacy, DEFAULT_BAYESIAN_PARALLELISM, + GenerationMethod, ) +from ax.generation_strategy.transition_criterion import MinTrials from ax.modelbridge.registry import Generators, MBM_X_trans, Mixed_transforms, Y_trans from ax.modelbridge.transforms.log_y import LogY from ax.modelbridge.transforms.winsorize import Winsorize from ax.models.random.sobol import SobolGenerator +from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from ax.models.winsorization_config import WinsorizationConfig from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( + get_branin_experiment, get_branin_search_space, get_discrete_search_space, get_experiment, + get_experiment_with_observations, get_factorial_search_space, get_large_factorial_search_space, get_large_ordinal_search_space, @@ -36,12 +44,171 @@ run_branin_experiment_with_generation_strategy, ) from ax.utils.testing.mock import mock_botorch_optimize +from ax.utils.testing.utils import run_trials_with_gs +from botorch.models.transforms.input import Normalize, Warp +from gpytorch.kernels.linear_kernel import LinearKernel from pyre_extensions import assert_is_instance, none_throws class TestDispatchUtils(TestCase): """Tests that dispatching utilities correctly select generation strategies.""" + def test_choose_gs_random_search(self) -> None: + gs = choose_generation_strategy(method=GenerationMethod.RANDOM_SEARCH) + self.assertEqual(len(gs._nodes), 1) + sobol_node = gs._nodes[0] + self.assertEqual(len(sobol_node.model_specs), 1) + sobol_spec = sobol_node.model_specs[0] + self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) + self.assertEqual(sobol_spec.model_kwargs, {"seed": None}) + self.assertEqual(sobol_node._transition_criteria, []) + # Make sure it generates. + run_trials_with_gs(experiment=get_branin_experiment(), gs=gs, num_trials=3) + + @mock_botorch_optimize + def test_choose_gs_fast_with_options(self) -> None: + gs = choose_generation_strategy( + method=GenerationMethod.FAST, + initialization_budget=3, + initialization_random_seed=0, + use_existing_trials_for_initialization=False, + min_observed_initialization_trials=4, + allow_exceeding_initialization_budget=True, + torch_device="cpu", + ) + self.assertEqual(len(gs._nodes), 2) + # Check the Sobol node & TC. + sobol_node = gs._nodes[0] + self.assertTrue(sobol_node.should_deduplicate) + self.assertEqual(len(sobol_node.model_specs), 1) + sobol_spec = sobol_node.model_specs[0] + self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) + self.assertEqual(sobol_spec.model_kwargs, {"seed": 0}) + expected_tc = [ + MinTrials( + threshold=3, + transition_to="MBM", + block_gen_if_met=False, + block_transition_if_unmet=True, + use_all_trials_in_exp=False, + ), + MinTrials( + threshold=4, + transition_to="MBM", + block_gen_if_met=False, + block_transition_if_unmet=True, + use_all_trials_in_exp=False, + only_in_statuses=[TrialStatus.COMPLETED], + count_only_trials_with_data=True, + ), + ] + self.assertEqual(sobol_node._transition_criteria, expected_tc) + # Check the MBM node. + mbm_node = gs._nodes[1] + self.assertTrue(mbm_node.should_deduplicate) + self.assertEqual(len(mbm_node.model_specs), 1) + mbm_spec = mbm_node.model_specs[0] + self.assertEqual(mbm_spec.model_enum, Generators.BOTORCH_MODULAR) + expected_ss = SurrogateSpec(model_configs=[ModelConfig(name="MBM defaults")]) + self.assertEqual( + mbm_spec.model_kwargs, + {"surrogate_spec": expected_ss, "torch_device": torch.device("cpu")}, + ) + self.assertEqual(mbm_node._transition_criteria, []) + # Experiment with 2 observations. We should still generate 4 Sobol trials. + experiment = get_experiment_with_observations([[1.0], [2.0]]) + # Mark the existing trials as manual to prevent them from counting for Sobol. + for trial in experiment.trials.values(): + none_throws( + assert_is_instance(trial, Trial).generator_run + )._model_key = "Manual" + # Generate 5 trials and make sure they're from the correct nodes. + run_trials_with_gs(experiment=experiment, gs=gs, num_trials=5) + self.assertEqual(len(experiment.trials), 7) + for trial in experiment.trials.values(): + model_key = none_throws( + assert_is_instance(trial, Trial).generator_run + )._model_key + if trial.index < 2: + self.assertEqual(model_key, "Manual") + elif trial.index < 6: + self.assertEqual(model_key, "Sobol") + else: + self.assertEqual(model_key, "BoTorch") + + @mock_botorch_optimize + def test_choose_gs_balanced(self) -> None: + gs = choose_generation_strategy(method=GenerationMethod.BALANCED) + self.assertEqual(len(gs._nodes), 2) + # Check the Sobol node & TC. + sobol_node = gs._nodes[0] + self.assertTrue(sobol_node.should_deduplicate) + self.assertEqual(len(sobol_node.model_specs), 1) + sobol_spec = sobol_node.model_specs[0] + self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) + self.assertEqual(sobol_spec.model_kwargs, {"seed": None}) + expected_tc = [ + MinTrials( + threshold=5, + transition_to="MBM", + block_gen_if_met=True, + block_transition_if_unmet=True, + use_all_trials_in_exp=True, + ), + MinTrials( + threshold=2, + transition_to="MBM", + block_gen_if_met=False, + block_transition_if_unmet=True, + use_all_trials_in_exp=True, + only_in_statuses=[TrialStatus.COMPLETED], + count_only_trials_with_data=True, + ), + ] + self.assertEqual(sobol_node._transition_criteria, expected_tc) + # Check the MBM node. + mbm_node = gs._nodes[1] + self.assertTrue(mbm_node.should_deduplicate) + self.assertEqual(len(mbm_node.model_specs), 1) + mbm_spec = mbm_node.model_specs[0] + self.assertEqual(mbm_spec.model_enum, Generators.BOTORCH_MODULAR) + expected_ss = SurrogateSpec( + model_configs=[ + ModelConfig(name="MBM defaults"), + ModelConfig( + covar_module_class=LinearKernel, + input_transform_classes=[Warp, Normalize], + input_transform_options={"Normalize": {"center": 0.0}}, + name="LinearKernel with Warp", + ), + ] + ) + self.assertEqual( + mbm_spec.model_kwargs, {"surrogate_spec": expected_ss, "torch_device": None} + ) + self.assertEqual(mbm_node._transition_criteria, []) + # Experiment with 2 observations. We should generate 3 more Sobol trials. + experiment = get_experiment_with_observations([[1.0], [2.0]]) + # Mark the existing trials as manual to prevent them from counting for Sobol. + # They'll still count for TC, since we use all trials in the experiment. + for trial in experiment.trials.values(): + none_throws( + assert_is_instance(trial, Trial).generator_run + )._model_key = "Manual" + # Generate 5 trials and make sure they're from the correct nodes. + run_trials_with_gs(experiment=experiment, gs=gs, num_trials=5) + self.assertEqual(len(experiment.trials), 7) + for trial in experiment.trials.values(): + model_key = none_throws( + assert_is_instance(trial, Trial).generator_run + )._model_key + if trial.index < 2: + self.assertEqual(model_key, "Manual") + elif trial.index < 5: + self.assertEqual(model_key, "Sobol") + else: + self.assertEqual(model_key, "BoTorch") + @mock_botorch_optimize def test_choose_generation_strategy_legacy(self) -> None: expected_transforms = [Winsorize] + MBM_X_trans + Y_trans diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index 207f6e38853..7cf3d02b45d 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -35,6 +35,7 @@ PercentileEarlyStoppingStrategy, ) from ax.exceptions.core import ObjectNotFoundError, UnsupportedError +from ax.generation_strategy.dispatch_utils import choose_generation_strategy from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.preview.api.configs import ( ExperimentConfig, @@ -50,7 +51,6 @@ optimization_config_from_string, ) from ax.preview.api.utils.storage import db_settings_from_storage_config -from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy from ax.service.scheduler import Scheduler, SchedulerOptions from ax.service.utils.best_point_mixin import BestPointMixin from ax.service.utils.with_db_settings_base import WithDBSettingsBase @@ -179,7 +179,21 @@ def configure_generation_strategy( """ generation_strategy = choose_generation_strategy( - gs_config=generation_strategy_config + method=generation_strategy_config.method, + initialization_budget=generation_strategy_config.initialization_budget, + initialization_random_seed=( + generation_strategy_config.initialization_random_seed + ), + use_existing_trials_for_initialization=( + generation_strategy_config.use_existing_trials_for_initialization + ), + min_observed_initialization_trials=( + generation_strategy_config.min_observed_initialization_trials + ), + allow_exceeding_initialization_budget=( + generation_strategy_config.allow_exceeding_initialization_budget + ), + torch_device=generation_strategy_config.torch_device, ) # Necessary for storage implications, may be removed in the future diff --git a/ax/preview/api/configs.py b/ax/preview/api/configs.py index dcc29a08e2e..927b0da25a6 100644 --- a/ax/preview/api/configs.py +++ b/ax/preview/api/configs.py @@ -10,6 +10,8 @@ from enum import Enum from typing import Any +from ax.generation_strategy.dispatch_utils import GenerationMethod + from ax.preview.api.types import TParameterValue from ax.storage.registry_bundle import RegistryBundleBase @@ -83,31 +85,6 @@ class ExperimentConfig: owner: str | None = None -class GenerationMethod(Enum): - """An enum to specify the desired candidate generation method for the experiment. - This is used in ``GenerationStrategyConfig``, along with the properties of the - experiment, to determine the generation strategy to use for candidate generation. - - NOTE: New options should be rarely added to this enum. This is not intended to be - a list of generation strategies for the user to choose from. Instead, this enum - should only provide high level guidance to the underlying generation strategy - dispatch logic, which is responsible for determinining the exact details. - - Available options are: - BALANCED: A balanced generation method that may utilize (per-metric) model - selection to achieve a good model accuracy. This method excludes expensive - methods, such as the fully Bayesian SAASBO model. Used by default. - FAST: A faster generation method that uses the built-in defaults from the - Modular BoTorch Model without any model selection. - RANDOM_SEARCH: Primarily intended for pure exploration experiments, this - method utilizes quasi-random Sobol sequences for candidate generation. - """ - - BALANCED = "balanced" - FAST = "fast" - RANDOM_SEARCH = "random_search" - - @dataclass class GenerationStrategyConfig: """A dataclass used to configure the generation strategy used in the experiment. diff --git a/ax/preview/modelbridge/__init__.py b/ax/preview/modelbridge/__init__.py deleted file mode 100644 index 4b87eb9e4d0..00000000000 --- a/ax/preview/modelbridge/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. diff --git a/ax/preview/modelbridge/dispatch_utils.py b/ax/preview/modelbridge/dispatch_utils.py deleted file mode 100644 index 84d3d0453aa..00000000000 --- a/ax/preview/modelbridge/dispatch_utils.py +++ /dev/null @@ -1,171 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import torch -from ax.core.trial_status import TrialStatus -from ax.exceptions.core import UnsupportedError -from ax.generation_strategy.generation_strategy import ( - GenerationNode, - GenerationStrategy, -) -from ax.generation_strategy.model_spec import GeneratorSpec -from ax.generation_strategy.transition_criterion import MinTrials -from ax.modelbridge.registry import Generators -from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec -from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig -from botorch.models.transforms.input import Normalize, Warp -from gpytorch.kernels.linear_kernel import LinearKernel - - -def _get_sobol_node( - gs_config: GenerationStrategyConfig, -) -> GenerationNode: - """Constructs a Sobol node based on inputs from ``gs_config``. - The Sobol generator utilizes `initialization_random_seed` if specified. - - This node always transitions to "MBM", using the following transition criteria: - - MinTrials enforcing the initialization budget. - - If the initialization budget is not specified, it defaults to 5. - - The TC will not block generation if `allow_exceeding_initialization_budget` - is set to True. - - The TC is currently not restricted to any trial statuses and will - count all trials. - - `use_existing_trials_for_initialization` controls whether trials previously - attached to the experiment are counted as part of the initialization budget. - - MinTrials enforcing the minimum number of observed initialization trials. - - If `min_observed_initialization_trials` is not specified, it defaults - to `max(1, initialization_budget // 2)`. - - The TC currently only counts trials in status COMPLETED (with data attached) - as observed trials. - - `use_existing_trials_for_initialization` controls whether trials previously - attached to the experiment are counted as part of the required number of - observed initialization trials. - """ - # Set the default options. - initialization_budget = gs_config.initialization_budget - if initialization_budget is None: - initialization_budget = 5 - min_observed_initialization_trials = gs_config.min_observed_initialization_trials - if min_observed_initialization_trials is None: - min_observed_initialization_trials = max(1, initialization_budget // 2) - # Construct the transition criteria. - transition_criteria = [ - MinTrials( # This represents the initialization budget. - threshold=initialization_budget, - transition_to="MBM", - block_gen_if_met=(not gs_config.allow_exceeding_initialization_budget), - block_transition_if_unmet=True, - use_all_trials_in_exp=gs_config.use_existing_trials_for_initialization, - ), - MinTrials( # This represents minimum observed trials requirement. - threshold=min_observed_initialization_trials, - transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, - use_all_trials_in_exp=gs_config.use_existing_trials_for_initialization, - only_in_statuses=[TrialStatus.COMPLETED], - count_only_trials_with_data=True, - ), - ] - return GenerationNode( - node_name="Sobol", - model_specs=[ - GeneratorSpec( - model_enum=Generators.SOBOL, - model_kwargs={"seed": gs_config.initialization_random_seed}, - ) - ], - transition_criteria=transition_criteria, - should_deduplicate=True, - ) - - -def _get_mbm_node( - gs_config: GenerationStrategyConfig, -) -> GenerationNode: - """Constructs an MBM node based on the method specified in ``gs_config``. - - The ``SurrogateSpec`` takes the following form for the given method: - - BALANCED: Two model configs: one with MBM defaults, the other with - linear kernel with input warping. - - FAST: An empty model config that utilizes MBM defaults. - """ - # Construct the surrogate spec. - if gs_config.method == GenerationMethod.FAST: - model_configs = [ModelConfig(name="MBM defaults")] - elif gs_config.method == GenerationMethod.BALANCED: - model_configs = [ - ModelConfig(name="MBM defaults"), - ModelConfig( - covar_module_class=LinearKernel, - input_transform_classes=[Warp, Normalize], - input_transform_options={"Normalize": {"center": 0.0}}, - name="LinearKernel with Warp", - ), - ] - else: - raise UnsupportedError(f"Unsupported generation method: {gs_config.method}.") - torch_device = ( - None if gs_config.torch_device is None else torch.device(gs_config.torch_device) - ) - return GenerationNode( - node_name="MBM", - model_specs=[ - GeneratorSpec( - model_enum=Generators.BOTORCH_MODULAR, - model_kwargs={ - "surrogate_spec": SurrogateSpec(model_configs=model_configs), - "torch_device": torch_device, - }, - ) - ], - should_deduplicate=True, - ) - - -def choose_generation_strategy( - gs_config: GenerationStrategyConfig, -) -> GenerationStrategy: - """Choose a generation strategy based on the properties of the experiment - and the inputs provided in ``gs_config``. - - NOTE: The behavior of this function is subject to change. It will be updated to - produce best general purpose generation strategies based on benchmarking results. - - Args: - gs_config: A ``GenerationStrategyConfig`` object that informs - the choice of generation strategy. - - Returns: - A generation strategy. - """ - # Handle the random search case. - if gs_config.method == GenerationMethod.RANDOM_SEARCH: - return GenerationStrategy( - name="QuasiRandomSearch", - nodes=[ - GenerationNode( - node_name="Sobol", - model_specs=[ - GeneratorSpec( - model_enum=Generators.SOBOL, - model_kwargs={"seed": gs_config.initialization_random_seed}, - ) - ], - ) - ], - ) - # Construct the nodes. - sobol_node = _get_sobol_node(gs_config) - # Construct the MBM node. - mbm_node = _get_mbm_node(gs_config) - method_str = gs_config.method.value - return GenerationStrategy( - name=f"Sobol+MBM:{method_str}", - nodes=[sobol_node, mbm_node], - ) diff --git a/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py b/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py deleted file mode 100644 index f0080b7d933..00000000000 --- a/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py +++ /dev/null @@ -1,189 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -from ax.core.trial import Trial -from ax.core.trial_status import TrialStatus -from ax.generation_strategy.transition_criterion import MinTrials -from ax.modelbridge.registry import Generators -from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec -from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig -from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy -from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import ( - get_branin_experiment, - get_experiment_with_observations, -) -from ax.utils.testing.mock import mock_botorch_optimize -from ax.utils.testing.utils import run_trials_with_gs -from botorch.models.transforms.input import Normalize, Warp -from gpytorch.kernels.linear_kernel import LinearKernel -from pyre_extensions import assert_is_instance, none_throws - - -class TestDispatchUtils(TestCase): - def test_choose_gs_random_search(self) -> None: - gs_config = GenerationStrategyConfig( - method=GenerationMethod.RANDOM_SEARCH, - ) - gs = choose_generation_strategy(gs_config=gs_config) - self.assertEqual(len(gs._nodes), 1) - sobol_node = gs._nodes[0] - self.assertEqual(len(sobol_node.model_specs), 1) - sobol_spec = sobol_node.model_specs[0] - self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) - self.assertEqual(sobol_spec.model_kwargs, {"seed": None}) - self.assertEqual(sobol_node._transition_criteria, []) - # Make sure it generates. - run_trials_with_gs(experiment=get_branin_experiment(), gs=gs, num_trials=3) - - @mock_botorch_optimize - def test_choose_gs_fast_with_options(self) -> None: - gs_config = GenerationStrategyConfig( - method=GenerationMethod.FAST, - initialization_budget=3, - initialization_random_seed=0, - use_existing_trials_for_initialization=False, - min_observed_initialization_trials=4, - allow_exceeding_initialization_budget=True, - torch_device="cpu", - ) - gs = choose_generation_strategy(gs_config=gs_config) - self.assertEqual(len(gs._nodes), 2) - # Check the Sobol node & TC. - sobol_node = gs._nodes[0] - self.assertTrue(sobol_node.should_deduplicate) - self.assertEqual(len(sobol_node.model_specs), 1) - sobol_spec = sobol_node.model_specs[0] - self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) - self.assertEqual(sobol_spec.model_kwargs, {"seed": 0}) - expected_tc = [ - MinTrials( - threshold=3, - transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, - use_all_trials_in_exp=False, - ), - MinTrials( - threshold=4, - transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, - use_all_trials_in_exp=False, - only_in_statuses=[TrialStatus.COMPLETED], - count_only_trials_with_data=True, - ), - ] - self.assertEqual(sobol_node._transition_criteria, expected_tc) - # Check the MBM node. - mbm_node = gs._nodes[1] - self.assertTrue(mbm_node.should_deduplicate) - self.assertEqual(len(mbm_node.model_specs), 1) - mbm_spec = mbm_node.model_specs[0] - self.assertEqual(mbm_spec.model_enum, Generators.BOTORCH_MODULAR) - expected_ss = SurrogateSpec(model_configs=[ModelConfig(name="MBM defaults")]) - self.assertEqual( - mbm_spec.model_kwargs, - {"surrogate_spec": expected_ss, "torch_device": torch.device("cpu")}, - ) - self.assertEqual(mbm_node._transition_criteria, []) - # Experiment with 2 observations. We should still generate 4 Sobol trials. - experiment = get_experiment_with_observations([[1.0], [2.0]]) - # Mark the existing trials as manual to prevent them from counting for Sobol. - for trial in experiment.trials.values(): - none_throws( - assert_is_instance(trial, Trial).generator_run - )._model_key = "Manual" - # Generate 5 trials and make sure they're from the correct nodes. - run_trials_with_gs(experiment=experiment, gs=gs, num_trials=5) - self.assertEqual(len(experiment.trials), 7) - for trial in experiment.trials.values(): - model_key = none_throws( - assert_is_instance(trial, Trial).generator_run - )._model_key - if trial.index < 2: - self.assertEqual(model_key, "Manual") - elif trial.index < 6: - self.assertEqual(model_key, "Sobol") - else: - self.assertEqual(model_key, "BoTorch") - - @mock_botorch_optimize - def test_choose_gs_balanced(self) -> None: - gs = choose_generation_strategy( - gs_config=GenerationStrategyConfig(method=GenerationMethod.BALANCED) - ) - self.assertEqual(len(gs._nodes), 2) - # Check the Sobol node & TC. - sobol_node = gs._nodes[0] - self.assertTrue(sobol_node.should_deduplicate) - self.assertEqual(len(sobol_node.model_specs), 1) - sobol_spec = sobol_node.model_specs[0] - self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) - self.assertEqual(sobol_spec.model_kwargs, {"seed": None}) - expected_tc = [ - MinTrials( - threshold=5, - transition_to="MBM", - block_gen_if_met=True, - block_transition_if_unmet=True, - use_all_trials_in_exp=True, - ), - MinTrials( - threshold=2, - transition_to="MBM", - block_gen_if_met=False, - block_transition_if_unmet=True, - use_all_trials_in_exp=True, - only_in_statuses=[TrialStatus.COMPLETED], - count_only_trials_with_data=True, - ), - ] - self.assertEqual(sobol_node._transition_criteria, expected_tc) - # Check the MBM node. - mbm_node = gs._nodes[1] - self.assertTrue(mbm_node.should_deduplicate) - self.assertEqual(len(mbm_node.model_specs), 1) - mbm_spec = mbm_node.model_specs[0] - self.assertEqual(mbm_spec.model_enum, Generators.BOTORCH_MODULAR) - expected_ss = SurrogateSpec( - model_configs=[ - ModelConfig(name="MBM defaults"), - ModelConfig( - covar_module_class=LinearKernel, - input_transform_classes=[Warp, Normalize], - input_transform_options={"Normalize": {"center": 0.0}}, - name="LinearKernel with Warp", - ), - ] - ) - self.assertEqual( - mbm_spec.model_kwargs, {"surrogate_spec": expected_ss, "torch_device": None} - ) - self.assertEqual(mbm_node._transition_criteria, []) - # Experiment with 2 observations. We should generate 3 more Sobol trials. - experiment = get_experiment_with_observations([[1.0], [2.0]]) - # Mark the existing trials as manual to prevent them from counting for Sobol. - # They'll still count for TC, since we use all trials in the experiment. - for trial in experiment.trials.values(): - none_throws( - assert_is_instance(trial, Trial).generator_run - )._model_key = "Manual" - # Generate 5 trials and make sure they're from the correct nodes. - run_trials_with_gs(experiment=experiment, gs=gs, num_trials=5) - self.assertEqual(len(experiment.trials), 7) - for trial in experiment.trials.values(): - model_key = none_throws( - assert_is_instance(trial, Trial).generator_run - )._model_key - if trial.index < 2: - self.assertEqual(model_key, "Manual") - elif trial.index < 5: - self.assertEqual(model_key, "Sobol") - else: - self.assertEqual(model_key, "BoTorch") From 2247ad26033239d362b93126cf1ceb143f4b0b07 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 5 Mar 2025 20:56:18 -0800 Subject: [PATCH 3/3] Move ax.preview.api to ax.api (#3466) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3466 Differential Revision: D70647192 --- ax/analysis/tests/test_metric_summary.py | 4 +-- .../tests/test_search_space_summary.py | 6 ++-- ax/analysis/tests/test_summary.py | 4 +-- ax/{preview => }/api/__init__.py | 6 ++-- ax/{preview => }/api/client.py | 26 ++++++++-------- ax/{preview => }/api/configs.py | 4 +-- ax/{preview => }/api/protocols/__init__.py | 4 +-- ax/{preview => }/api/protocols/metric.py | 2 +- ax/{preview => }/api/protocols/runner.py | 5 ++-- ax/{preview => }/api/protocols/utils.py | 2 +- ax/{preview => }/api/tests/test_client.py | 26 ++++++++-------- ax/{preview => }/api/types.py | 0 ax/{preview => api/utils}/__init__.py | 0 .../utils/instantiation}/__init__.py | 0 .../api/utils/instantiation/from_config.py | 16 +++++----- .../api/utils/instantiation/from_string.py | 0 .../instantiation/tests/test_from_config.py | 24 +++++++-------- .../instantiation/tests/test_from_string.py | 14 ++++----- ax/{preview => }/api/utils/storage.py | 2 +- .../api/utils/instantiation/__init__.py | 5 ---- sphinx/source/{preview.rst => api.rst} | 30 +++++++++---------- tutorials/ask_tell/ask_tell.ipynb | 4 +-- tutorials/automl/automl.ipynb | 4 +-- tutorials/closed_loop/closed_loop.ipynb | 10 +++---- tutorials/early_stopping/early_stopping.ipynb | 4 +-- .../human_in_the_loop/human_in_the_loop.ipynb | 4 +-- 26 files changed, 100 insertions(+), 106 deletions(-) rename ax/{preview => }/api/__init__.py (83%) rename ax/{preview => }/api/client.py (99%) rename ax/{preview => }/api/configs.py (99%) rename ax/{preview => }/api/protocols/__init__.py (71%) rename ax/{preview => }/api/protocols/metric.py (95%) rename ax/{preview => }/api/protocols/runner.py (94%) rename ax/{preview => }/api/protocols/utils.py (99%) rename ax/{preview => }/api/tests/test_client.py (99%) rename ax/{preview => }/api/types.py (100%) rename ax/{preview => api/utils}/__init__.py (100%) rename ax/{preview/api/utils => api/utils/instantiation}/__init__.py (100%) rename ax/{preview => }/api/utils/instantiation/from_config.py (97%) rename ax/{preview => }/api/utils/instantiation/from_string.py (100%) rename ax/{preview => }/api/utils/instantiation/tests/test_from_config.py (99%) rename ax/{preview => }/api/utils/instantiation/tests/test_from_string.py (99%) rename ax/{preview => }/api/utils/storage.py (95%) delete mode 100644 ax/preview/api/utils/instantiation/__init__.py rename sphinx/source/{preview.rst => api.rst} (60%) diff --git a/ax/analysis/tests/test_metric_summary.py b/ax/analysis/tests/test_metric_summary.py index 9412a8b7dec..5ca36f2dcda 100644 --- a/ax/analysis/tests/test_metric_summary.py +++ b/ax/analysis/tests/test_metric_summary.py @@ -8,10 +8,10 @@ import pandas as pd from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.metric_summary import MetricSummary +from ax.api.client import Client +from ax.api.configs import ExperimentConfig from ax.core.metric import Metric from ax.exceptions.core import UserInputError -from ax.preview.api.client import Client -from ax.preview.api.configs import ExperimentConfig from ax.utils.common.testutils import TestCase diff --git a/ax/analysis/tests/test_search_space_summary.py b/ax/analysis/tests/test_search_space_summary.py index 415f94951b2..4af107eac95 100644 --- a/ax/analysis/tests/test_search_space_summary.py +++ b/ax/analysis/tests/test_search_space_summary.py @@ -8,15 +8,15 @@ import pandas as pd from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.search_space_summary import SearchSpaceSummary -from ax.exceptions.core import UserInputError -from ax.preview.api.client import Client -from ax.preview.api.configs import ( +from ax.api.client import Client +from ax.api.configs import ( ChoiceParameterConfig, ExperimentConfig, ParameterScaling, ParameterType, RangeParameterConfig, ) +from ax.exceptions.core import UserInputError from ax.utils.common.testutils import TestCase diff --git a/ax/analysis/tests/test_summary.py b/ax/analysis/tests/test_summary.py index 26ef44c16f6..1260ed2b9d5 100644 --- a/ax/analysis/tests/test_summary.py +++ b/ax/analysis/tests/test_summary.py @@ -9,10 +9,10 @@ import pandas as pd from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel from ax.analysis.summary import Summary +from ax.api.client import Client +from ax.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig from ax.core.trial import Trial from ax.exceptions.core import UserInputError -from ax.preview.api.client import Client -from ax.preview.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig from ax.utils.common.testutils import TestCase from pyre_extensions import assert_is_instance, none_throws diff --git a/ax/preview/api/__init__.py b/ax/api/__init__.py similarity index 83% rename from ax/preview/api/__init__.py rename to ax/api/__init__.py index 984729ced93..6af52d74068 100644 --- a/ax/preview/api/__init__.py +++ b/ax/api/__init__.py @@ -6,8 +6,8 @@ # pyre-strict -from ax.preview.api.client import Client -from ax.preview.api.configs import ( +from ax.api.client import Client +from ax.api.configs import ( ChoiceParameterConfig, ExperimentConfig, GenerationStrategyConfig, @@ -17,7 +17,7 @@ RangeParameterConfig, StorageConfig, ) -from ax.preview.api.types import TOutcome, TParameterization +from ax.api.types import TOutcome, TParameterization __all__ = [ "Client", diff --git a/ax/preview/api/client.py b/ax/api/client.py similarity index 99% rename from ax/preview/api/client.py rename to ax/api/client.py index 7cf3d02b45d..dcd9e1e35fd 100644 --- a/ax/preview/api/client.py +++ b/ax/api/client.py @@ -21,6 +21,18 @@ markdown_analysis_card_from_analysis_e, ) from ax.analysis.utils import choose_analyses +from ax.api.configs import ( + ExperimentConfig, + GenerationStrategyConfig, + OrchestrationConfig, + StorageConfig, +) +from ax.api.protocols.metric import IMetric +from ax.api.protocols.runner import IRunner +from ax.api.types import TOutcome, TParameterization +from ax.api.utils.instantiation.from_config import experiment_from_config +from ax.api.utils.instantiation.from_string import optimization_config_from_string +from ax.api.utils.storage import db_settings_from_storage_config from ax.core.experiment import Experiment from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective, ScalarizedObjective @@ -37,20 +49,6 @@ from ax.exceptions.core import ObjectNotFoundError, UnsupportedError from ax.generation_strategy.dispatch_utils import choose_generation_strategy from ax.generation_strategy.generation_strategy import GenerationStrategy -from ax.preview.api.configs import ( - ExperimentConfig, - GenerationStrategyConfig, - OrchestrationConfig, - StorageConfig, -) -from ax.preview.api.protocols.metric import IMetric -from ax.preview.api.protocols.runner import IRunner -from ax.preview.api.types import TOutcome, TParameterization -from ax.preview.api.utils.instantiation.from_config import experiment_from_config -from ax.preview.api.utils.instantiation.from_string import ( - optimization_config_from_string, -) -from ax.preview.api.utils.storage import db_settings_from_storage_config from ax.service.scheduler import Scheduler, SchedulerOptions from ax.service.utils.best_point_mixin import BestPointMixin from ax.service.utils.with_db_settings_base import WithDBSettingsBase diff --git a/ax/preview/api/configs.py b/ax/api/configs.py similarity index 99% rename from ax/preview/api/configs.py rename to ax/api/configs.py index 927b0da25a6..f4e8c72bea9 100644 --- a/ax/preview/api/configs.py +++ b/ax/api/configs.py @@ -10,9 +10,9 @@ from enum import Enum from typing import Any -from ax.generation_strategy.dispatch_utils import GenerationMethod +from ax.api.types import TParameterValue -from ax.preview.api.types import TParameterValue +from ax.generation_strategy.dispatch_utils import GenerationMethod from ax.storage.registry_bundle import RegistryBundleBase diff --git a/ax/preview/api/protocols/__init__.py b/ax/api/protocols/__init__.py similarity index 71% rename from ax/preview/api/protocols/__init__.py rename to ax/api/protocols/__init__.py index 0a8213e7d18..18f7a0b7fa7 100644 --- a/ax/preview/api/protocols/__init__.py +++ b/ax/api/protocols/__init__.py @@ -6,8 +6,8 @@ # pyre-strict -from ax.preview.api.protocols.metric import IMetric -from ax.preview.api.protocols.runner import IRunner +from ax.api.protocols.metric import IMetric +from ax.api.protocols.runner import IRunner __all__ = [ "IMetric", diff --git a/ax/preview/api/protocols/metric.py b/ax/api/protocols/metric.py similarity index 95% rename from ax/preview/api/protocols/metric.py rename to ax/api/protocols/metric.py index d354408a28f..0d7f5a18a3a 100644 --- a/ax/preview/api/protocols/metric.py +++ b/ax/api/protocols/metric.py @@ -9,7 +9,7 @@ from collections.abc import Mapping from typing import Any -from ax.preview.api.protocols.utils import _APIMetric +from ax.api.protocols.utils import _APIMetric from pyre_extensions import override diff --git a/ax/preview/api/protocols/runner.py b/ax/api/protocols/runner.py similarity index 94% rename from ax/preview/api/protocols/runner.py rename to ax/api/protocols/runner.py index b6e9e426314..a289ace98e4 100644 --- a/ax/preview/api/protocols/runner.py +++ b/ax/api/protocols/runner.py @@ -9,9 +9,10 @@ from collections.abc import Mapping from typing import Any +from ax.api.protocols.utils import _APIRunner +from ax.api.types import TParameterization + from ax.core.trial_status import TrialStatus -from ax.preview.api.protocols.utils import _APIRunner -from ax.preview.api.types import TParameterization from pyre_extensions import override diff --git a/ax/preview/api/protocols/utils.py b/ax/api/protocols/utils.py similarity index 99% rename from ax/preview/api/protocols/utils.py rename to ax/api/protocols/utils.py index 55615a3f364..c6134aecc33 100644 --- a/ax/preview/api/protocols/utils.py +++ b/ax/api/protocols/utils.py @@ -13,6 +13,7 @@ from typing import Any import pandas as pd +from ax.api.types import TParameterization from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.map_data import MapData, MapKeyInfo @@ -21,7 +22,6 @@ from ax.core.runner import Runner from ax.core.trial import Trial from ax.exceptions.storage import JSONEncodeError -from ax.preview.api.types import TParameterization from ax.utils.common.result import Err, Ok from pyre_extensions import assert_is_instance, none_throws, override diff --git a/ax/preview/api/tests/test_client.py b/ax/api/tests/test_client.py similarity index 99% rename from ax/preview/api/tests/test_client.py rename to ax/api/tests/test_client.py index 77af5ab5998..dcc3acd1496 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/api/tests/test_client.py @@ -12,6 +12,19 @@ import pandas as pd from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot +from ax.api.client import Client +from ax.api.configs import ( + ChoiceParameterConfig, + ExperimentConfig, + GenerationStrategyConfig, + OrchestrationConfig, + ParameterType, + RangeParameterConfig, + StorageConfig, +) +from ax.api.protocols.metric import IMetric +from ax.api.protocols.runner import IRunner +from ax.api.types import TParameterization from ax.core.experiment import Experiment from ax.core.formatting_utils import DataType @@ -31,19 +44,6 @@ from ax.core.trial_status import TrialStatus from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy from ax.exceptions.core import UnsupportedError -from ax.preview.api.client import Client -from ax.preview.api.configs import ( - ChoiceParameterConfig, - ExperimentConfig, - GenerationStrategyConfig, - OrchestrationConfig, - ParameterType, - RangeParameterConfig, - StorageConfig, -) -from ax.preview.api.protocols.metric import IMetric -from ax.preview.api.protocols.runner import IRunner -from ax.preview.api.types import TParameterization from ax.storage.sqa_store.db import init_test_engine_and_session_factory from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( diff --git a/ax/preview/api/types.py b/ax/api/types.py similarity index 100% rename from ax/preview/api/types.py rename to ax/api/types.py diff --git a/ax/preview/__init__.py b/ax/api/utils/__init__.py similarity index 100% rename from ax/preview/__init__.py rename to ax/api/utils/__init__.py diff --git a/ax/preview/api/utils/__init__.py b/ax/api/utils/instantiation/__init__.py similarity index 100% rename from ax/preview/api/utils/__init__.py rename to ax/api/utils/instantiation/__init__.py diff --git a/ax/preview/api/utils/instantiation/from_config.py b/ax/api/utils/instantiation/from_config.py similarity index 97% rename from ax/preview/api/utils/instantiation/from_config.py rename to ax/api/utils/instantiation/from_config.py index b414a314413..c17d69dba77 100644 --- a/ax/preview/api/utils/instantiation/from_config.py +++ b/ax/api/utils/instantiation/from_config.py @@ -7,6 +7,14 @@ import numpy as np +from ax.api.configs import ( + ChoiceParameterConfig, + ExperimentConfig, + ParameterScaling, + ParameterType, + RangeParameterConfig, +) +from ax.api.utils.instantiation.from_string import parse_parameter_constraint from ax.core.experiment import Experiment @@ -21,14 +29,6 @@ from ax.core.parameter_constraint import validate_constraint_parameters from ax.core.search_space import HierarchicalSearchSpace, SearchSpace from ax.exceptions.core import UserInputError -from ax.preview.api.configs import ( - ChoiceParameterConfig, - ExperimentConfig, - ParameterScaling, - ParameterType, - RangeParameterConfig, -) -from ax.preview.api.utils.instantiation.from_string import parse_parameter_constraint def parameter_from_config( diff --git a/ax/preview/api/utils/instantiation/from_string.py b/ax/api/utils/instantiation/from_string.py similarity index 100% rename from ax/preview/api/utils/instantiation/from_string.py rename to ax/api/utils/instantiation/from_string.py diff --git a/ax/preview/api/utils/instantiation/tests/test_from_config.py b/ax/api/utils/instantiation/tests/test_from_config.py similarity index 99% rename from ax/preview/api/utils/instantiation/tests/test_from_config.py rename to ax/api/utils/instantiation/tests/test_from_config.py index 05ce56f239d..b5b2cda597f 100644 --- a/ax/preview/api/utils/instantiation/tests/test_from_config.py +++ b/ax/api/utils/instantiation/tests/test_from_config.py @@ -5,6 +5,18 @@ # pyre-strict +from ax.api.configs import ( + ChoiceParameterConfig, + ExperimentConfig, + ParameterScaling, + ParameterType, + RangeParameterConfig, +) +from ax.api.utils.instantiation.from_config import ( + _parameter_type_converter, + experiment_from_config, + parameter_from_config, +) from ax.core.experiment import Experiment from ax.core.formatting_utils import DataType from ax.core.parameter import ( @@ -16,18 +28,6 @@ from ax.core.parameter_constraint import ParameterConstraint from ax.core.search_space import HierarchicalSearchSpace, SearchSpace from ax.exceptions.core import UserInputError -from ax.preview.api.configs import ( - ChoiceParameterConfig, - ExperimentConfig, - ParameterScaling, - ParameterType, - RangeParameterConfig, -) -from ax.preview.api.utils.instantiation.from_config import ( - _parameter_type_converter, - experiment_from_config, - parameter_from_config, -) from ax.utils.common.testutils import TestCase diff --git a/ax/preview/api/utils/instantiation/tests/test_from_string.py b/ax/api/utils/instantiation/tests/test_from_string.py similarity index 99% rename from ax/preview/api/utils/instantiation/tests/test_from_string.py rename to ax/api/utils/instantiation/tests/test_from_string.py index 89e31201c39..bfee370a169 100644 --- a/ax/preview/api/utils/instantiation/tests/test_from_string.py +++ b/ax/api/utils/instantiation/tests/test_from_string.py @@ -5,6 +5,13 @@ # pyre-strict +from ax.api.utils.instantiation.from_string import ( + _sanitize_dot, + optimization_config_from_string, + parse_objective, + parse_outcome_constraint, + parse_parameter_constraint, +) from ax.core.map_metric import MapMetric from ax.core.objective import MultiObjective, Objective, ScalarizedObjective from ax.core.optimization_config import ( @@ -19,13 +26,6 @@ ) from ax.core.parameter_constraint import ParameterConstraint from ax.exceptions.core import UserInputError -from ax.preview.api.utils.instantiation.from_string import ( - _sanitize_dot, - optimization_config_from_string, - parse_objective, - parse_outcome_constraint, - parse_parameter_constraint, -) from ax.utils.common.testutils import TestCase diff --git a/ax/preview/api/utils/storage.py b/ax/api/utils/storage.py similarity index 95% rename from ax/preview/api/utils/storage.py rename to ax/api/utils/storage.py index 3e6f4254aad..9be4a9b9a8f 100644 --- a/ax/preview/api/utils/storage.py +++ b/ax/api/utils/storage.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from ax.preview.api.configs import StorageConfig +from ax.api.configs import StorageConfig from ax.storage.sqa_store.decoder import Decoder from ax.storage.sqa_store.encoder import Encoder from ax.storage.sqa_store.sqa_config import SQAConfig diff --git a/ax/preview/api/utils/instantiation/__init__.py b/ax/preview/api/utils/instantiation/__init__.py deleted file mode 100644 index 4b87eb9e4d0..00000000000 --- a/ax/preview/api/utils/instantiation/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. diff --git a/sphinx/source/preview.rst b/sphinx/source/api.rst similarity index 60% rename from sphinx/source/preview.rst rename to sphinx/source/api.rst index fcd1deb41e4..61a933d6cbb 100644 --- a/sphinx/source/preview.rst +++ b/sphinx/source/api.rst @@ -1,20 +1,20 @@ .. role:: hidden :class: hidden-section -ax.preview +ax.api ========== -.. automodule:: ax.preview -.. currentmodule:: ax.preview +.. automodule:: ax.api +.. currentmodule:: ax.api -A preview of future Ax API +The Ax API -------------------------- IMetric ~~~~~~~ -.. automodule:: ax.preview.api.protocols.metric +.. automodule:: ax.api.protocols.metric :members: :undoc-members: :show-inheritance: @@ -22,7 +22,7 @@ IMetric IRunner ~~~~~~~ -.. automodule:: ax.preview.api.protocols.runner +.. automodule:: ax.api.protocols.runner :members: :undoc-members: :show-inheritance: @@ -31,7 +31,7 @@ IRunner Utils ~~~~~~~ -.. automodule:: ax.preview.api.protocols.utils +.. automodule:: ax.api.protocols.utils :members: :undoc-members: :show-inheritance: @@ -40,7 +40,7 @@ Utils Client ~~~~~~ -.. automodule:: ax.preview.api.client +.. automodule:: ax.api.client :members: :undoc-members: :show-inheritance: @@ -49,7 +49,7 @@ Client Configs ~~~~~~~ -.. automodule:: ax.preview.api.configs +.. automodule:: ax.api.configs :members: :undoc-members: :show-inheritance: @@ -57,7 +57,7 @@ Configs Types ~~~~~ -.. automodule:: ax.preview.api.types +.. automodule:: ax.api.types :members: :undoc-members: :show-inheritance: @@ -65,7 +65,7 @@ Types From Config ~~~~~~~~~~~ -.. automodule:: ax.preview.api.utils.instantiation.from_config +.. automodule:: ax.api.utils.instantiation.from_config :members: :undoc-members: :show-inheritance: @@ -73,7 +73,7 @@ From Config From String ~~~~~~~~~~~ -.. automodule:: ax.preview.api.utils.instantiation.from_string +.. automodule:: ax.api.utils.instantiation.from_string :members: :undoc-members: :show-inheritance: @@ -82,7 +82,7 @@ From String Adapter ~~~~~~~~~~~ -.. automodule:: ax.preview.modelbridge +.. automodule:: ax.api.modelbridge :members: :undoc-members: :show-inheritance: @@ -90,7 +90,7 @@ Adapter Dispatch Utils ~~~~~~~~~~~~~~ -.. automodule:: ax.preview.modelbridge.dispatch_utils +.. automodule:: ax.api.modelbridge.dispatch_utils :members: :undoc-members: :show-inheritance: @@ -98,7 +98,7 @@ Dispatch Utils Storage Utils ~~~~~~~~~~~~~ -.. automodule:: ax.preview.api.utils.storage +.. automodule:: ax.api.utils.storage :members: :undoc-members: :show-inheritance: diff --git a/tutorials/ask_tell/ask_tell.ipynb b/tutorials/ask_tell/ask_tell.ipynb index 6530c86c690..b06a909bc1d 100644 --- a/tutorials/ask_tell/ask_tell.ipynb +++ b/tutorials/ask_tell/ask_tell.ipynb @@ -65,8 +65,8 @@ "outputs": [], "source": [ "import numpy as np\n", - "from ax.preview.api.client import Client\n", - "from ax.preview.api.configs import (\n", + "from ax.api.client import Client\n", + "from ax.api.configs import (\n", " ExperimentConfig,\n", " RangeParameterConfig,\n", " ParameterType,\n", diff --git a/tutorials/automl/automl.ipynb b/tutorials/automl/automl.ipynb index d3a963d8375..3d681e36dfb 100644 --- a/tutorials/automl/automl.ipynb +++ b/tutorials/automl/automl.ipynb @@ -76,8 +76,8 @@ "import sklearn.datasets\n", "import sklearn.linear_model\n", "import sklearn.model_selection\n", - "from ax.preview.api.client import Client\n", - "from ax.preview.api.configs import (\n", + "from ax.api.client import Client\n", + "from ax.api.configs import (\n", " ChoiceParameterConfig,\n", " ExperimentConfig,\n", " GenerationMethod,\n", diff --git a/tutorials/closed_loop/closed_loop.ipynb b/tutorials/closed_loop/closed_loop.ipynb index 1edd5cba8a2..3364d015b26 100644 --- a/tutorials/closed_loop/closed_loop.ipynb +++ b/tutorials/closed_loop/closed_loop.ipynb @@ -76,16 +76,16 @@ "from typing import Any, Mapping\n", "\n", "import numpy as np\n", - "from ax.preview.api.client import Client\n", - "from ax.preview.api.configs import (\n", + "from ax.api.client import Client\n", + "from ax.api.configs import (\n", " ExperimentConfig,\n", " OrchestrationConfig,\n", " ParameterType,\n", " RangeParameterConfig,\n", ")\n", - "from ax.preview.api.protocols.metric import IMetric\n", - "from ax.preview.api.protocols.runner import IRunner, TrialStatus\n", - "from ax.preview.api.types import TParameterization" + "from ax.api.protocols.metric import IMetric\n", + "from ax.api.protocols.runner import IRunner, TrialStatus\n", + "from ax.api.types import TParameterization" ] }, { diff --git a/tutorials/early_stopping/early_stopping.ipynb b/tutorials/early_stopping/early_stopping.ipynb index 3c6255c03c6..abf4c2f1530 100644 --- a/tutorials/early_stopping/early_stopping.ipynb +++ b/tutorials/early_stopping/early_stopping.ipynb @@ -73,8 +73,8 @@ "import plotly.express as px\n", "import plotly.graph_objects as go\n", "from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy\n", - "from ax.preview.api.client import Client\n", - "from ax.preview.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig" + "from ax.api.client import Client\n", + "from ax.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig" ] }, { diff --git a/tutorials/human_in_the_loop/human_in_the_loop.ipynb b/tutorials/human_in_the_loop/human_in_the_loop.ipynb index 132f8f4b86a..721cc141531 100644 --- a/tutorials/human_in_the_loop/human_in_the_loop.ipynb +++ b/tutorials/human_in_the_loop/human_in_the_loop.ipynb @@ -68,8 +68,8 @@ "source": [ "import pandas as pd\n", "\n", - "from ax.preview.api.client import Client\n", - "from ax.preview.api.configs import ExperimentConfig, RangeParameterConfig, ChoiceParameterConfig, ParameterType, GenerationStrategyConfig" + "from ax.api.client import Client\n", + "from ax.api.configs import ExperimentConfig, RangeParameterConfig, ChoiceParameterConfig, ParameterType, GenerationStrategyConfig" ] }, {