diff --git a/ax/modelbridge/tests/test_prediction_utils.py b/ax/modelbridge/tests/test_prediction_utils.py
index 0354cdffe47..d1929867f1e 100644
--- a/ax/modelbridge/tests/test_prediction_utils.py
+++ b/ax/modelbridge/tests/test_prediction_utils.py
@@ -28,7 +28,7 @@ def test_predict_at_point(self) -> None:
         observation_features = ObservationFeatures(parameters={"x1": 0.3, "x2": 0.5})
 
         y_hat, se_hat = predict_at_point(
-            model=none_throws(ax_client.generation_strategy.model),
+            model=none_throws(ax_client.standard_generation_strategy.model),
             obsf=observation_features,
             metric_names={"test_metric1"},
         )
@@ -37,7 +37,7 @@ def test_predict_at_point(self) -> None:
         self.assertEqual(len(se_hat), 1)
 
         y_hat, se_hat = predict_at_point(
-            model=none_throws(ax_client.generation_strategy.model),
+            model=none_throws(ax_client.standard_generation_strategy.model),
             obsf=observation_features,
             metric_names={"test_metric1", "test_metric2", "test_metric:agg"},
             scalarized_metric_config=[
@@ -51,7 +51,7 @@ def test_predict_at_point(self) -> None:
         self.assertEqual(len(se_hat), 3)
 
         y_hat, se_hat = predict_at_point(
-            model=none_throws(ax_client.generation_strategy.model),
+            model=none_throws(ax_client.standard_generation_strategy.model),
             obsf=observation_features,
             metric_names={"test_metric1"},
             scalarized_metric_config=[
@@ -75,7 +75,7 @@ def test_predict_by_features(self) -> None:
             20: ObservationFeatures(parameters={"x1": 0.8, "x2": 0.5}),
         }
         predictions_map = predict_by_features(
-            model=none_throws(ax_client.generation_strategy.model),
+            model=none_throws(ax_client.standard_generation_strategy.model),
             label_to_feature_dict=observation_features_dict,
             metric_names={"test_metric1"},
         )
diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py
index dd63d35349f..67881c0d1d3 100644
--- a/ax/service/ax_client.py
+++ b/ax/service/ax_client.py
@@ -66,6 +66,7 @@
 from ax.plot.feature_importances import plot_feature_importance_by_feature
 from ax.plot.helper import _format_dict
 from ax.plot.trace import optimization_trace_single_method
+from ax.service.utils.analysis_base import AnalysisBase
 from ax.service.utils.best_point_mixin import BestPointMixin
 from ax.service.utils.instantiation import (
     FixedFeatures,
@@ -73,7 +74,7 @@
     ObjectiveProperties,
 )
 from ax.service.utils.report_utils import exp_to_df
-from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
+from ax.service.utils.with_db_settings_base import DBSettings
 from ax.storage.json_store.decoder import (
     generation_strategy_from_json,
     object_from_json,
@@ -108,7 +109,7 @@
 )
 
 
-class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase):
+class AxClient(AnalysisBase, BestPointMixin, InstantiationBase):
     """
     Convenience handler for management of experimentation cycle through a
     service-like API. External system manages scheduling of the cycle and makes
@@ -598,8 +599,8 @@ def get_next_trial(
         # TODO[T79183560]: Ensure correct handling of generator run when using
         # foreign keys.
        self._update_generation_strategy_in_db_if_possible(
-            generation_strategy=self.generation_strategy,
-            new_generator_runs=[self.generation_strategy._generator_runs[-1]],
+            generation_strategy=self.standard_generation_strategy,
+            new_generator_runs=[self.standard_generation_strategy._generator_runs[-1]],
         )
         return none_throws(trial.arm).parameters, trial.index
 
@@ -624,7 +625,7 @@ def get_current_trial_generation_limit(self) -> tuple[int, bool]:
         if self.generation_strategy._experiment is None:
             self.generation_strategy.experiment = self.experiment
 
-        return self.generation_strategy.current_generator_run_limit()
+        return self.standard_generation_strategy.current_generator_run_limit()
 
     def get_next_trials(
         self,
@@ -949,7 +950,7 @@ def get_max_parallelism(self) -> list[tuple[int, int]]:
             Mapping of form {num_trials -> max_parallelism_setting}.
         """
         parallelism_settings = []
-        for step in self.generation_strategy._steps:
+        for step in self.standard_generation_strategy._steps:
             parallelism_settings.append(
                 (step.num_trials, step.max_parallelism or step.num_trials)
             )
@@ -1070,7 +1071,7 @@ def get_contour_plot(
             raise ValueError(
                 f'Metric "{metric_name}" is not associated with this optimization.'
             )
-        if self.generation_strategy.model is not None:
+        if self.standard_generation_strategy.model is not None:
             try:
                 logger.info(
                     f"Retrieving contour plot with parameter '{param_x}' on X-axis "
@@ -1078,7 +1079,7 @@ def get_contour_plot(
                     "Remaining parameters are affixed to the middle of their range."
                 )
                 return plot_contour(
-                    model=none_throws(self.generation_strategy.model),
+                    model=none_throws(self.standard_generation_strategy.model),
                     param_x=param_x,
                     param_y=param_y,
                     metric_name=metric_name,
@@ -1088,8 +1089,8 @@ def get_contour_plot(
                 # Some models don't implement '_predict', which is needed
                 # for the contour plots.
                 logger.info(
-                    f"Model {self.generation_strategy.model} does not implement "
-                    "`predict`, so it cannot be used to generate a response "
+                    f"Model {self.standard_generation_strategy.model} does not "
+                    "implement `predict`, so it cannot be used to generate a response "
                     "surface plot."
                 )
                 raise UnsupportedPlotError(
@@ -1111,14 +1112,14 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig:
         """
         if not self.experiment.trials:
             raise ValueError("Cannot generate plot as there are no trials.")
-        cur_model = self.generation_strategy.model
+        cur_model = self.standard_generation_strategy.model
         if cur_model is not None:
             try:
                 return plot_feature_importance_by_feature(cur_model, relative=relative)
             except NotImplementedError:
                 logger.info(
-                    f"Model {self.generation_strategy.model} does not implement "
-                    "`feature_importances`, so it cannot be used to generate "
+                    f"Model {self.standard_generation_strategy.model} does not "
+                    "implement `feature_importances`, so it cannot be used to generate "
                     "this plot. Only certain models, implement feature importances."
                 )
 
@@ -1246,7 +1247,8 @@ def get_model_predictions(
             else set(none_throws(self.experiment.metrics).keys())
         )
         model = none_throws(
-            self.generation_strategy.model, "No model has been instantiated yet."
+            self.standard_generation_strategy.model,
+            "No model has been instantiated yet.",
         )
 
         # Construct a dictionary that maps from a label to an
@@ -1305,8 +1307,8 @@ def fit_model(self) -> None:
                 "At least one trial must be completed with data to fit a model."
             )
         # Check if we should transition before generating the next candidate.
-        self.generation_strategy._maybe_transition_to_next_node()
-        self.generation_strategy._fit_current_model(data=None)
+        self.standard_generation_strategy._maybe_transition_to_next_node()
+        self.standard_generation_strategy._fit_current_model(data=None)
 
     def verify_trial_parameterization(
         self, trial_index: int, parameterization: TParameterization
@@ -1495,29 +1497,10 @@ def from_json_snapshot(
 
     # ---------------------- Private helper methods. ---------------------
 
-    @property
-    def experiment(self) -> Experiment:
-        """Returns the experiment set on this Ax client."""
-        return none_throws(
-            self._experiment,
-            (
-                "Experiment not set on Ax client. Must first "
-                "call load_experiment or create_experiment to use handler functions."
-            ),
-        )
-
     def get_trial(self, trial_index: int) -> Trial:
         """Return a trial on experiment cast as Trial"""
         return checked_cast(Trial, self.experiment.trials[trial_index])
 
-    @property
-    def generation_strategy(self) -> GenerationStrategy:
-        """Returns the generation strategy, set on this experiment."""
-        return none_throws(
-            self._generation_strategy,
-            "No generation strategy has been set on this optimization yet.",
-        )
-
     @property
     def objective(self) -> Objective:
         return none_throws(self.experiment.optimization_config).objective
@@ -1585,7 +1568,7 @@ def get_best_trial(
     ) -> tuple[int, TParameterization, TModelPredictArm | None] | None:
         return self._get_best_trial(
             experiment=self.experiment,
-            generation_strategy=self.generation_strategy,
+            generation_strategy=self.standard_generation_strategy,
             trial_indices=trial_indices,
             use_model_predictions=use_model_predictions,
         )
@@ -1599,7 +1582,7 @@ def get_pareto_optimal_parameters(
     ) -> dict[int, tuple[TParameterization, TModelPredictArm]]:
         return self._get_pareto_optimal_parameters(
             experiment=self.experiment,
-            generation_strategy=self.generation_strategy,
+            generation_strategy=self.standard_generation_strategy,
             trial_indices=trial_indices,
             use_model_predictions=use_model_predictions,
         )
@@ -1613,7 +1596,7 @@ def get_hypervolume(
     ) -> float:
         return BestPointMixin._get_hypervolume(
             experiment=self.experiment,
-            generation_strategy=self.generation_strategy,
+            generation_strategy=self.standard_generation_strategy,
             optimization_config=optimization_config,
             trial_indices=trial_indices,
             use_model_predictions=use_model_predictions,
@@ -1816,7 +1799,7 @@ def _gen_new_generator_run(
             else None
         )
         with with_rng_seed(seed=self._random_seed):
-            return none_throws(self.generation_strategy).gen(
+            return none_throws(self.standard_generation_strategy).gen(
                 experiment=self.experiment,
                 n=n,
                 pending_observations=self._get_pending_observation_features(
diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py
index 2acf5f21e1e..025afd1368f 100644
--- a/ax/service/scheduler.py
+++ b/ax/service/scheduler.py
@@ -8,8 +8,6 @@
 
 from __future__ import annotations
 
-import traceback
-
 from collections.abc import Callable, Generator, Iterable, Mapping
 from copy import deepcopy
 from dataclasses import dataclass
@@ -20,10 +18,6 @@
 from typing import Any, cast, NamedTuple, Optional
 
 import ax.service.utils.early_stopping as early_stopping_utils
-import pandas as pd
-from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
-from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
-from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
 from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.experiment import Experiment
 from ax.core.generation_strategy_interface import GenerationStrategyInterface
@@ -57,6 +51,7 @@
 from ax.modelbridge.base import ModelBridge
 from ax.modelbridge.generation_strategy import GenerationStrategy
 from ax.modelbridge.modelbridge_utils import get_fixed_features_from_experiment
+from ax.service.utils.analysis_base import AnalysisBase
 from ax.service.utils.best_point_mixin import BestPointMixin
 from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
 from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
@@ -70,7 +65,6 @@
     set_ax_logger_levels,
 )
 from ax.utils.common.timeutils import current_timestamp_in_millis
-from ax.utils.common.typeutils import checked_cast
 from pyre_extensions import assert_is_instance, none_throws
 
 
@@ -151,7 +145,7 @@ def append(self, text: str) -> None:
         self.text += text
 
 
-class Scheduler(WithDBSettingsBase, BestPointMixin):
+class Scheduler(AnalysisBase, BestPointMixin):
     """Closed-loop manager class for Ax optimization.
 
     Attributes:
@@ -168,8 +162,6 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
         been saved, as otherwise experiment state could get corrupted.**
     """
 
-    experiment: Experiment
-    generation_strategy: GenerationStrategyInterface
     # pyre-fixme[24]: Generic type `LoggerAdapter` expects 1 type parameter.
     logger: LoggerAdapter
     # Mapping of form {short string identifier -> message to show in reported
@@ -497,21 +489,6 @@ def runner(self) -> Runner:
         )
         return runner
 
-    @property
-    def standard_generation_strategy(self) -> GenerationStrategy:
-        """Used for operations in the scheduler that can only be done with
-        and instance of ``GenerationStrategy``.
-        """
-        gs = self.generation_strategy
-        if not isinstance(gs, GenerationStrategy):
-            raise NotImplementedError(
-                "This functionality is only supported with instances of "
-                "`GenerationStrategy` (one that uses `GenerationStrategy` "
-                "class) and not yet with other types of "
-                "`GenerationStrategyInterface`."
-            )
-        return gs
-
     def __repr__(self) -> str:
         """Short user-friendly string representation."""
         if not hasattr(self, "experiment"):
@@ -679,62 +656,6 @@ def run_all_trials(
             idle_callback=idle_callback,
         )
 
-    def compute_analyses(
-        self, analyses: Iterable[Analysis] | None = None
-    ) -> list[AnalysisCard]:
-        """
-        Compute Analyses for the Experiment and GenerationStrategy associated with this
-        Scheduler instance and save them to the DB if possible. If an Analysis fails to
-        compute (e.g. due to a missing metric), it will be skipped and a warning will
-        be logged.
-
-        Args:
-            analyses: Analyses to compute. If None, the Scheduler will choose a set of
-                Analyses to compute based on the Experiment and GenerationStrategy.
- """ - analyses = analyses if analyses is not None else self._choose_analyses() - - results = [ - analysis.compute_result( - experiment=self.experiment, generation_strategy=self.generation_strategy - ) - for analysis in analyses - ] - - # TODO Accumulate Es into their own card, perhaps via unwrap_or_else - cards = [result.unwrap() for result in results if result.is_ok()] - - for result in results: - if result.is_err(): - e = checked_cast(AnalysisE, result.err) - traceback_str = "".join( - traceback.format_exception( - type(result.err.exception), - e.exception, - e.exception.__traceback__, - ) - ) - cards.append( - MarkdownAnalysisCard( - name=e.analysis.name, - # It would be better if we could reliably compute the title - # without risking another error - title=f"{e.analysis.name} Error", - subtitle=f"An error occurred while computing {e.analysis}", - attributes=e.analysis.attributes, - blob=traceback_str, - df=pd.DataFrame(), - level=AnalysisCardLevel.DEBUG, - ) - ) - - self._save_analysis_cards_to_db_if_possible( - analysis_cards=cards, - experiment=self.experiment, - ) - - return cards - def run_trials_and_yield_results( self, max_trials: int, @@ -1882,14 +1803,6 @@ def _get_next_trials( trials.append(trial) return trials, None - def _choose_analyses(self) -> list[Analysis]: - """ - Choose Analyses to compute based on the Experiment, GenerationStrategy, etc. - """ - - # TODO Create a useful heuristic for choosing analyses - return [ParallelCoordinatesPlot()] - def _gen_new_trials_from_generation_strategy( self, num_trials: int, diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 4cf6825f326..f6cb1ec1686 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -504,7 +504,10 @@ def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None: """ ax_client = get_branin_optimization() self.assertEqual( - [s.model for s in none_throws(ax_client.generation_strategy)._steps], + [ + s.model + for s in none_throws(ax_client.standard_generation_strategy)._steps + ], [Models.SOBOL, Models.BOTORCH_MODULAR], ) with self.assertRaisesRegex(ValueError, ".* no trials"): @@ -719,7 +722,7 @@ def test_default_generation_strategy_continuous_for_moo( }, ) self.assertEqual( - [s.model for s in none_throws(ax_client.generation_strategy)._steps], + [s.model for s in ax_client.standard_generation_strategy._steps], [Models.SOBOL, Models.BOTORCH_MODULAR], ) with self.assertRaisesRegex(ValueError, ".* no trials"): @@ -782,7 +785,7 @@ def test_create_experiment(self) -> None: steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] ) ) - with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): + with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"): ax_client.experiment ax_client.create_experiment( name="test_experiment", @@ -1019,7 +1022,7 @@ def test_create_single_objective_experiment_with_objectives_dict(self) -> None: steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] ) ) - with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): + with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"): ax_client.experiment ax_client.create_experiment( name="test_experiment", @@ -1080,7 +1083,7 @@ def test_create_single_objective_experiment_with_objectives_dict(self) -> None: def test_create_experiment_with_metric_definitions(self) -> None: """Test basic experiment creation.""" ax_client = AxClient() - with 
+        with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
             ax_client.experiment
 
         metric_definitions = {
@@ -1347,7 +1350,7 @@ def test_create_moo_experiment(self) -> None:
                 steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
             )
         )
-        with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
+        with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
             ax_client.experiment
         ax_client.create_experiment(
             name="test_experiment",
@@ -1581,10 +1584,9 @@ def test_keep_generating_without_data(self) -> None:
                 {"name": "y", "type": "range", "bounds": [0.0, 15.0]},
             ],
         )
-        self.assertFalse(
-            ax_client.generation_strategy._steps[0].enforce_num_trials, False
-        )
-        self.assertFalse(ax_client.generation_strategy._steps[1].max_parallelism, None)
+        gs = ax_client.standard_generation_strategy
+        self.assertFalse(gs._steps[0].enforce_num_trials)
+        self.assertFalse(gs._steps[1].max_parallelism)
 
         for _ in range(10):
             parameterization, trial_index = ax_client.get_next_trial()
@@ -2100,14 +2102,14 @@ def test_sqa_storage(self) -> None:
                 # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U...
                 raw_data=branin(*parameters.values()),
             )
-        gs = ax_client.generation_strategy
+        gs = ax_client.standard_generation_strategy
         ax_client = AxClient(db_settings=db_settings)
         ax_client.load_experiment_from_database("test_experiment")
         # Some fields of the reloaded GS are not expected to be set (both will be
         # set during next model fitting call), so we unset them on the original GS as
         # well.
         gs._unset_non_persistent_state_fields()
-        ax_client.generation_strategy._unset_non_persistent_state_fields()
+        ax_client.standard_generation_strategy._unset_non_persistent_state_fields()
         self.assertEqual(gs, ax_client.generation_strategy)
         with self.assertRaises(ValueError):  # Overwriting existing experiment.
@@ -2461,8 +2463,9 @@ def helper_test_get_pareto_optimal_points(
             num_trials=20, outcome_constraints=outcome_constraints
         )
         ax_client.fit_model()
+        gs = ax_client.standard_generation_strategy
         self.assertEqual(
-            ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key,
+            gs._curr.model_spec_to_gen_from.model_key,
             "BoTorch",
         )
 
@@ -2487,7 +2490,7 @@ def helper_test_get_pareto_optimal_points(
         # This overwrites the `predict` call to return the original observations,
         # while testing the rest of the code as if we're using predictions.
         # pyre-fixme[16]: `Optional` has no attribute `model`.
-        model = ax_client.generation_strategy.model.model
+        model = ax_client.standard_generation_strategy.model.model
         ys = model.surrogate.training_data[0].Y
         with patch.object(
             model, "predict", return_value=(ys, torch.zeros(*ys.shape, ys.shape[-1]))
@@ -2531,8 +2534,9 @@ def helper_test_get_pareto_optimal_points_from_sobol_step(
         ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials(
             num_trials=20, minimize=minimize, outcome_constraints=outcome_constraints
         )
+        gs = ax_client.standard_generation_strategy
         self.assertEqual(
-            ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key,
+            gs._curr.model_spec_to_gen_from.model_key,
             "Sobol",
         )
 
@@ -2643,8 +2647,8 @@ def test_get_pareto_optimal_points_objective_threshold_inference(
         ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials(
             num_trials=20, include_objective_thresholds=False
         )
-        ax_client.generation_strategy._maybe_transition_to_next_node()
-        ax_client.generation_strategy._fit_current_model(
+        ax_client.standard_generation_strategy._maybe_transition_to_next_node()
+        ax_client.standard_generation_strategy._fit_current_model(
            data=ax_client.experiment.lookup_data()
        )
 
@@ -2855,7 +2859,8 @@ def test_with_hss(self) -> None:
         # Make sure we actually tried a Botorch iteration and all the transforms it
         # applies.
         self.assertEqual(
-            ax_client.generation_strategy._generator_runs[-1]._model_key, "BoTorch"
+            ax_client.standard_generation_strategy._generator_runs[-1]._model_key,
+            "BoTorch",
         )
         self.assertEqual(len(ax_client.experiment.trials), 6)
         ax_client.attach_trial(
@@ -2970,7 +2975,7 @@ def test_torch_device(self) -> None:
             torch_device=device,
         )
         ax_client = get_branin_optimization(torch_device=device)
-        gpei_step_kwargs = ax_client.generation_strategy._steps[1].model_kwargs
+        gpei_step_kwargs = ax_client.standard_generation_strategy._steps[1].model_kwargs
         self.assertEqual(gpei_step_kwargs["torch_device"], device)
 
     def test_repr_function(
@@ -2999,7 +3004,7 @@ def test_gen_fixed_features(self) -> None:
             name="fixed_features",
         )
         with mock.patch.object(
-            GenerationStrategy, "gen", wraps=ax_client.generation_strategy.gen
+            GenerationStrategy, "gen", wraps=ax_client.standard_generation_strategy.gen
         ) as mock_gen:
             with self.subTest("fixed_features is None"):
                 params, idx = ax_client.get_next_trial()
diff --git a/ax/service/utils/analysis_base.py b/ax/service/utils/analysis_base.py
new file mode 100644
index 00000000000..3f79c3c0c9b
--- /dev/null
+++ b/ax/service/utils/analysis_base.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import traceback
+from typing import Iterable
+
+import pandas as pd
+
+from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
+from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
+from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
+from ax.core.experiment import Experiment
+from ax.core.generation_strategy_interface import GenerationStrategyInterface
+from ax.modelbridge.generation_strategy import GenerationStrategy
+from ax.service.utils.with_db_settings_base import WithDBSettingsBase
+from ax.utils.common.typeutils import checked_cast
+from pyre_extensions import none_throws
+
+
+class AnalysisBase(WithDBSettingsBase):
+    """
+    Base class for analysis functionality shared between AxClient and Scheduler.
+    It also manages the experiment and generation strategy associated with the
+    instance.
+    """
+
+    # pyre-fixme[13]: Attribute `experiment` is declared in class
+    # `AnalysisBase` to have type `Experiment` but is never initialized
+    _experiment: Experiment | None
+    # pyre-fixme[13]: Attribute `generation_strategy` is declared in class
+    # `AnalysisBase` to have type `GenerationStrategyInterface` but
+    # is never initialized
+    _generation_strategy: GenerationStrategyInterface | None
+
+    def _choose_analyses(self) -> list[Analysis]:
+        """
+        Choose Analyses to compute based on the Experiment, GenerationStrategy, etc.
+        """
+
+        # TODO Create a useful heuristic for choosing analyses
+        return [ParallelCoordinatesPlot()]
+
+    def compute_analyses(
+        self, analyses: Iterable[Analysis] | None = None
+    ) -> list[AnalysisCard]:
+        """
+        Compute Analyses for the Experiment and GenerationStrategy associated with
+        this instance and save them to the DB if possible. If an Analysis fails to
+        compute (e.g. due to a missing metric), it will be skipped and a warning
+        will be logged.
+
+        Args:
+            analyses: Analyses to compute. If None, a default set of Analyses will
+                be chosen based on the Experiment and GenerationStrategy.
+        """
+        analyses = analyses if analyses is not None else self._choose_analyses()
+
+        results = [
+            analysis.compute_result(
+                experiment=self.experiment,
+                generation_strategy=self.generation_strategy,
+            )
+            for analysis in analyses
+        ]
+
+        # TODO Accumulate Es into their own card, perhaps via unwrap_or_else
+        cards = [result.unwrap() for result in results if result.is_ok()]
+
+        for result in results:
+            if result.is_err():
+                e = checked_cast(AnalysisE, result.err)
+                traceback_str = "".join(
+                    traceback.format_exception(
+                        type(result.err.exception),
+                        e.exception,
+                        e.exception.__traceback__,
+                    )
+                )
+                cards.append(
+                    MarkdownAnalysisCard(
+                        name=e.analysis.name,
+                        # It would be better if we could reliably compute the title
+                        # without risking another error
+                        title=f"{e.analysis.name} Error",
+                        subtitle=f"An error occurred while computing {e.analysis}",
+                        attributes=e.analysis.attributes,
+                        blob=traceback_str,
+                        df=pd.DataFrame(),
+                        level=AnalysisCardLevel.DEBUG,
+                    )
+                )
+
+        self._save_analysis_cards_to_db_if_possible(
+            analysis_cards=cards,
+            experiment=self.experiment,
+        )
+
+        return cards
+
+    @property
+    def experiment(self) -> Experiment:
+        """Returns the experiment set on this instance."""
+        return none_throws(
+            self._experiment,
+            (
+                f"Experiment not set on {self.__class__.__name__}. Must first "
+                "call load_experiment or create_experiment to use handler functions."
+            ),
+        )
+
+    @experiment.setter
+    def experiment(self, experiment: Experiment) -> None:
+        """Sets the experiment on this instance."""
+        self._experiment = experiment
+
+    @property
+    def generation_strategy(self) -> GenerationStrategyInterface:
+        """Returns the generation strategy set on this instance."""
+        return none_throws(
+            self._generation_strategy,
+            "No generation strategy has been set on this optimization yet.",
+        )
+
+    @generation_strategy.setter
+    def generation_strategy(
+        self, generation_strategy: GenerationStrategyInterface
+    ) -> None:
+        """Sets the generation strategy on this instance."""
+        self._generation_strategy = generation_strategy
+
+    @property
+    def standard_generation_strategy(self) -> GenerationStrategy:
+        """Used for operations that can only be done with
+        an instance of ``GenerationStrategy``.
+ """ + gs = self.generation_strategy + if not isinstance(gs, GenerationStrategy): + raise NotImplementedError( + "This functionality is only supported with instances of " + "`GenerationStrategy` (one that uses `GenerationStrategy` " + "class) and not yet with other types of " + "`GenerationStrategyInterface`." + ) + return gs diff --git a/sphinx/source/service.rst b/sphinx/source/service.rst index fdbb67fb320..e66454a97bb 100644 --- a/sphinx/source/service.rst +++ b/sphinx/source/service.rst @@ -49,6 +49,15 @@ Scheduler Utils ----- +Analysis +~~~~~~~~ + +.. automodule:: ax.service.utils.analysis_base + :members: + :undoc-members: + :show-inheritance: + + Best Point Identification ~~~~~~~~~~~~~~~~~~~~~~~~~