diff --git a/ax/analysis/analysis.py b/ax/analysis/analysis.py index c738b1d51e2..10fcff8e043 100644 --- a/ax/analysis/analysis.py +++ b/ax/analysis/analysis.py @@ -15,7 +15,7 @@ import pandas as pd from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.utils.common.base import Base from ax.utils.common.logger import get_logger from ax.utils.common.result import Err, ExceptionE, Ok, Result @@ -123,7 +123,7 @@ class Analysis(Protocol): def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> AnalysisCard: # Note: when implementing compute always prefer experiment.lookup_data() to # experiment.fetch_data() to avoid unintential data fetching within the report @@ -133,7 +133,7 @@ def compute( def compute_result( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> Result[AnalysisCard, AnalysisE]: """ Utility method to compute an AnalysisCard as a Result. This can be useful for diff --git a/ax/analysis/healthcheck/can_generate_candidates.py b/ax/analysis/healthcheck/can_generate_candidates.py index 5915aa77dc4..30eb127913e 100644 --- a/ax/analysis/healthcheck/can_generate_candidates.py +++ b/ax/analysis/healthcheck/can_generate_candidates.py @@ -18,7 +18,7 @@ HealthcheckStatus, ) from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.generation_strategy.generation_strategy import GenerationStrategy from pyre_extensions import none_throws @@ -47,7 +47,7 @@ def __init__( def compute( self, experiment: Optional[Experiment] = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> HealthcheckAnalysisCard: status = HealthcheckStatus.PASS subtitle = self.reason diff --git a/ax/analysis/healthcheck/constraints_feasibility.py b/ax/analysis/healthcheck/constraints_feasibility.py index 552a0843aaa..82399c186b0 100644 --- a/ax/analysis/healthcheck/constraints_feasibility.py +++ b/ax/analysis/healthcheck/constraints_feasibility.py @@ -20,7 +20,6 @@ from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm from ax.analysis.plotly.utils import is_predictive from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.optimization_config import OptimizationConfig from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy @@ -48,7 +47,7 @@ def __init__(self, prob_threshold: float = 0.95) -> None: def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> HealthcheckAnalysisCard: r""" Compute the feasibility of the constraints for the experiment. diff --git a/ax/analysis/healthcheck/healthcheck_analysis.py b/ax/analysis/healthcheck/healthcheck_analysis.py index 56e72a2862e..41345759099 100644 --- a/ax/analysis/healthcheck/healthcheck_analysis.py +++ b/ax/analysis/healthcheck/healthcheck_analysis.py @@ -9,7 +9,7 @@ from ax.analysis.analysis import Analysis, AnalysisCard from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.generation_strategy.generation_strategy import GenerationStrategy class HealthcheckStatus(IntEnum): @@ -33,5 +33,5 @@ class HealthcheckAnalysis(Analysis): def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> HealthcheckAnalysisCard: ... diff --git a/ax/analysis/healthcheck/regression_analysis.py b/ax/analysis/healthcheck/regression_analysis.py index 044f6862b4c..43f456e5a66 100644 --- a/ax/analysis/healthcheck/regression_analysis.py +++ b/ax/analysis/healthcheck/regression_analysis.py @@ -19,8 +19,8 @@ detect_regressions_by_trial, ) from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from pyre_extensions import none_throws @@ -49,7 +49,7 @@ def __init__(self, prob_threshold: float = 0.95) -> None: def compute( self, experiment: Experiment | None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> HealthcheckAnalysisCard: r""" Detect the regressing arms for all trials that have data. diff --git a/ax/analysis/healthcheck/search_space_analysis.py b/ax/analysis/healthcheck/search_space_analysis.py index ccb4d883606..1209ea80953 100644 --- a/ax/analysis/healthcheck/search_space_analysis.py +++ b/ax/analysis/healthcheck/search_space_analysis.py @@ -18,13 +18,13 @@ HealthcheckStatus, ) from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.parameter import ChoiceParameter, Parameter, RangeParameter from ax.core.parameter_constraint import ParameterConstraint from ax.core.search_space import SearchSpace from ax.core.types import TParameterization from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from pyre_extensions import assert_is_instance @@ -54,7 +54,7 @@ def __init__( def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> HealthcheckAnalysisCard: r""" Args: diff --git a/ax/analysis/healthcheck/should_generate_candidates.py b/ax/analysis/healthcheck/should_generate_candidates.py index 43ade80af94..3f359f64408 100644 --- a/ax/analysis/healthcheck/should_generate_candidates.py +++ b/ax/analysis/healthcheck/should_generate_candidates.py @@ -16,7 +16,7 @@ HealthcheckStatus, ) from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.generation_strategy.generation_strategy import GenerationStrategy class ShouldGenerateCandidates(HealthcheckAnalysis): @@ -33,7 +33,7 @@ def __init__( def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> HealthcheckAnalysisCard: status = ( HealthcheckStatus.PASS diff --git a/ax/analysis/markdown/markdown_analysis.py b/ax/analysis/markdown/markdown_analysis.py index c8958d59838..f8a7dd57b4a 100644 --- a/ax/analysis/markdown/markdown_analysis.py +++ b/ax/analysis/markdown/markdown_analysis.py @@ -11,7 +11,7 @@ import pandas as pd from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.generation_strategy.generation_strategy import GenerationStrategy from IPython.display import display, Markdown @@ -37,7 +37,7 @@ class MarkdownAnalysis(Analysis): def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> MarkdownAnalysisCard: ... def _create_markdown_analysis_card( diff --git a/ax/analysis/plotly/arm_effects/insample_effects.py b/ax/analysis/plotly/arm_effects/insample_effects.py index 666b6ba9a78..bf2234c936a 100644 --- a/ax/analysis/plotly/arm_effects/insample_effects.py +++ b/ax/analysis/plotly/arm_effects/insample_effects.py @@ -18,7 +18,6 @@ from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.utils import is_predictive from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.outcome_constraint import OutcomeConstraint from ax.exceptions.core import DataRequiredError, UserInputError @@ -77,7 +76,7 @@ def __init__( def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("InSampleEffectsPlot requires an Experiment.") @@ -174,7 +173,7 @@ def _get_max_observed_trial_index(model: Adapter) -> int | None: def _get_model( experiment: Experiment, - generation_strategy: GenerationStrategyInterface | None, + generation_strategy: GenerationStrategy | None, use_modeled_effects: bool, trial_index: int, metric_name: str, diff --git a/ax/analysis/plotly/arm_effects/predicted_effects.py b/ax/analysis/plotly/arm_effects/predicted_effects.py index 4e04e4e32e8..77e04d431ca 100644 --- a/ax/analysis/plotly/arm_effects/predicted_effects.py +++ b/ax/analysis/plotly/arm_effects/predicted_effects.py @@ -21,7 +21,6 @@ from ax.core import OutcomeConstraint from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter @@ -68,7 +67,7 @@ def __init__(self, metric_name: str) -> None: def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("PredictedEffectsPlot requires an Experiment.") diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py index f210dd1f438..c9dda53ede7 100644 --- a/ax/analysis/plotly/cross_validation.py +++ b/ax/analysis/plotly/cross_validation.py @@ -12,12 +12,11 @@ from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.utils import select_metric from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.cross_validation import cross_validate from plotly import express as px, graph_objects as go -from pyre_extensions import assert_is_instance, none_throws +from pyre_extensions import none_throws class CrossValidationPlot(PlotlyAnalysis): @@ -82,7 +81,7 @@ def __init__( def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: if generation_strategy is None: raise UserInputError("CrossValidation requires a GenerationStrategy") @@ -92,11 +91,7 @@ def compute( ) df = _prepare_data( - # CrossValidationPlot requires a native Ax GenerationStrategy and cannot be - # used with a GenerationStrategyInterface. - generation_strategy=assert_is_instance( - generation_strategy, GenerationStrategy - ), + generation_strategy=generation_strategy, metric_name=metric_name, folds=self.folds, untransform=self.untransform, diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py index 164e5f91ee6..cd29d38ae71 100644 --- a/ax/analysis/plotly/interaction.py +++ b/ax/analysis/plotly/interaction.py @@ -25,8 +25,8 @@ from ax.analysis.plotly.surface.utils import is_axis_log_scale from ax.analysis.plotly.utils import select_metric from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Generators from ax.modelbridge.torch import TorchAdapter from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -89,7 +89,7 @@ def __init__( def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: """ Compute Sobol index sensitivity for one metric of an experiment. Sensitivity diff --git a/ax/analysis/plotly/parallel_coordinates.py b/ax/analysis/plotly/parallel_coordinates.py index 634d9978166..f55ba2e8cf8 100644 --- a/ax/analysis/plotly/parallel_coordinates.py +++ b/ax/analysis/plotly/parallel_coordinates.py @@ -14,8 +14,8 @@ from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.utils import select_metric from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from plotly import graph_objects as go @@ -45,7 +45,7 @@ def __init__(self, metric_name: str | None = None) -> None: def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("ParallelCoordinatesPlot requires an Experiment") diff --git a/ax/analysis/plotly/plotly_analysis.py b/ax/analysis/plotly/plotly_analysis.py index e45a28310b9..e957065b19f 100644 --- a/ax/analysis/plotly/plotly_analysis.py +++ b/ax/analysis/plotly/plotly_analysis.py @@ -9,7 +9,7 @@ import pandas as pd from ax.analysis.analysis import Analysis, AnalysisCard from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.generation_strategy.generation_strategy import GenerationStrategy from IPython.display import display, Markdown from plotly import graph_objects as go, io as pio @@ -37,7 +37,7 @@ class PlotlyAnalysis(Analysis): def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> PlotlyAnalysisCard: ... def _create_plotly_analysis_card( diff --git a/ax/analysis/plotly/scatter.py b/ax/analysis/plotly/scatter.py index b40adebfc5b..e058b195202 100644 --- a/ax/analysis/plotly/scatter.py +++ b/ax/analysis/plotly/scatter.py @@ -12,8 +12,8 @@ from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import DataRequiredError, UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from plotly import express as px, graph_objects as go @@ -51,7 +51,7 @@ def __init__( def compute( self, experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + generation_strategy: Optional[GenerationStrategy] = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("ScatterPlot requires an Experiment") diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py index 47139c0ac56..216766c159b 100644 --- a/ax/analysis/plotly/surface/contour.py +++ b/ax/analysis/plotly/surface/contour.py @@ -19,7 +19,6 @@ ) from ax.analysis.plotly.utils import select_metric from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy @@ -63,7 +62,7 @@ def __init__( def compute( self, experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + generation_strategy: Optional[GenerationStrategy] = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("ContourPlot requires an Experiment") diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py index b031b022a94..b69754979be 100644 --- a/ax/analysis/plotly/surface/slice.py +++ b/ax/analysis/plotly/surface/slice.py @@ -19,7 +19,6 @@ ) from ax.analysis.plotly.utils import select_metric from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy @@ -57,7 +56,7 @@ def __init__( def compute( self, experiment: Optional[Experiment] = None, - generation_strategy: Optional[GenerationStrategyInterface] = None, + generation_strategy: Optional[GenerationStrategy] = None, ) -> PlotlyAnalysisCard: if experiment is None: raise UserInputError("SlicePlot requires an Experiment") diff --git a/ax/analysis/summary.py b/ax/analysis/summary.py index 66badbc55e5..da0958e47b0 100644 --- a/ax/analysis/summary.py +++ b/ax/analysis/summary.py @@ -7,8 +7,8 @@ from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy class Summary(Analysis): @@ -36,7 +36,7 @@ def __init__(self, omit_empty_columns: bool = True) -> None: def compute( self, experiment: Experiment | None = None, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> AnalysisCard: if experiment is None: raise UserInputError("`Summary` analysis requires an `Experiment` input") diff --git a/ax/core/generation_strategy_interface.py b/ax/core/generation_strategy_interface.py deleted file mode 100644 index 2486cd4756e..00000000000 --- a/ax/core/generation_strategy_interface.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from __future__ import annotations - -from abc import ABC, abstractmethod - -from typing import Any - -from ax.core.data import Data -from ax.core.experiment import Experiment -from ax.core.generator_run import GeneratorRun -from ax.core.observation import ObservationFeatures -from ax.exceptions.core import AxError, UnsupportedError -from ax.utils.common.base import Base -from pyre_extensions import none_throws - - -class GenerationStrategyInterface(ABC, Base): - """Interface for all generation strategies: standard Ax - ``GenerationStrategy``, as well as non-standard (e.g. remote, external) - generation strategies. - - NOTE: Currently in Beta; please do not use without discussion with the Ax - developers. - """ - - _name: str - # Experiment, for which this generation strategy has generated trials, if - # it exists. - _experiment: Experiment | None = None - - # Constant for default number of arms to generate if `n` is not specified in - # `gen` call and "total_concurrent_arms" is not set in experiment properties. - DEFAULT_N: int = 1 - - def __init__(self, name: str) -> None: - self._name = name - - @abstractmethod - def gen_for_multiple_trials_with_multiple_models( - self, - experiment: Experiment, - data: Data | None = None, - pending_observations: dict[str, list[ObservationFeatures]] | None = None, - n: int | None = None, - num_trials: int = 1, - arms_per_node: dict[str, int] | None = None, - ) -> list[list[GeneratorRun]]: - """Produce ``GeneratorRun``-s for multiple trials at once with the possibility - of joining ``GeneratorRun``-s from multiple models into one ``BatchTrial``. - - Args: - experiment: ``Experiment``, for which the generation strategy is producing - a new generator run in the course of ``gen``, and to which that - generator run will be added as trial(s). Information stored on the - experiment (e.g., trial statuses) is used to determine which model - will be used to produce the generator run returned from this method. - data: Optional data to be passed to the underlying model's ``gen``, which - is called within this method and actually produces the resulting - generator run. By default, data is all data on the ``experiment``. - pending_observations: A map from metric name to pending - observations for that metric, used by some models to avoid - resuggesting points that are currently being evaluated. - n: Integer representing how many total arms should be in the generator - runs produced by this method. NOTE: Some underlying models may ignore - the `n` and produce a model-determined number of arms. In that - case this method will also output generator runs with number of - arms that can differ from `n`. - num_trials: Number of trials to generate generator runs for in this call. - If not provided, defaults to 1. - arms_per_node: An optional map from node name to the number of arms to - generate from that node. If not provided, will default to the number - of arms specified in the node's ``InputConstructors`` or n if no - ``InputConstructors`` are defined on the node. We expect either n or - arms_per_node to be provided, but not both, and this is an advanced - argument that should only be used by advanced users. - - Returns: - A list of lists of ``GeneratorRun``-s. Each outer list item represents - a ``(Batch)Trial`` being suggested, with a list of ``GeneratorRun``-s for - that trial. - """ - # When implementing your subclass' override for this method, don't forget - # to consider using "pending points", corresponding to arms in trials that - # are currently running / being evaluated/ - ... - - def _gen_multiple( - self, - experiment: Experiment, - num_generator_runs: int, - data: Data | None = None, - n: int = 1, - pending_observations: dict[str, list[ObservationFeatures]] | None = None, - **model_gen_kwargs: Any, - ) -> list[GeneratorRun]: - """Produce multiple generator runs at once, to be made into multiple - trials on the experiment. - - NOTE: This is used to ensure that maximum parallelism and number - of trials per node are not violated when producing many generator - runs from this generation strategy in a row. Without this function, - if one generates multiple generator runs without first making any - of them into running trials, generation strategy cannot enforce that it only - produces as many generator runs as are allowed by the parallelism - limit and the limit on number of trials in current node. - - Args: - experiment: Experiment, for which the generation strategy is producing - a new generator run in the course of `gen`, and to which that - generator run will be added as trial(s). Information stored on the - experiment (e.g., trial statuses) is used to determine which model - will be used to produce the generator run returned from this method. - data: Optional data to be passed to the underlying model's `gen`, which - is called within this method and actually produces the resulting - generator run. By default, data is all data on the `experiment`. - n: Integer representing how many arms should be in the generator run - produced by this method. NOTE: Some underlying models may ignore - the ``n`` and produce a model-determined number of arms. In that - case this method will also output a generator run with number of - arms that can differ from ``n``. - pending_observations: A map from metric name to pending - observations for that metric, used by some models to avoid - resuggesting points that are currently being evaluated. - model_gen_kwargs: Keyword arguments that are passed through to - ``GenerationNode.gen``, which will pass them through to - ``GeneratorSpec.gen``, which will pass them to ``Adapter.gen``. - """ - ... - - @abstractmethod - def clone_reset(self) -> GenerationStrategyInterface: - """Returns a clone of this generation strategy with all state reset.""" - ... - - @property - def name(self) -> str: - """Name of this generation strategy.""" - return self._name - - @property - def experiment(self) -> Experiment: - """Experiment, currently set on this generation strategy.""" - if self._experiment is None: - raise AxError("No experiment set on generation strategy.") - return none_throws(self._experiment) - - @experiment.setter - def experiment(self, experiment: Experiment) -> None: - """If there is an experiment set on this generation strategy as the - experiment it has been generating generator runs for, check if the - experiment passed in is the same as the one saved and log an information - statement if its not. Set the new experiment on this generation strategy. - """ - if self._experiment is not None and experiment._name != self.experiment._name: - raise UnsupportedError( - "This generation strategy has been used for experiment " - f"{self.experiment._name} so far; cannot reset experiment" - f" to {experiment._name}. If this is a new experiment, " - "a new generation strategy should be created instead." - ) - self._experiment = experiment diff --git a/ax/core/tests/test_generation_strategy_interface.py b/ax/core/tests/test_generation_strategy_interface.py deleted file mode 100644 index cd8432b3e59..00000000000 --- a/ax/core/tests/test_generation_strategy_interface.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - - -from ax.core.data import Data -from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface -from ax.core.generator_run import GeneratorRun -from ax.core.observation import ObservationFeatures -from ax.exceptions.core import AxError, UnsupportedError -from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_experiment, SpecialGenerationStrategy - - -class MyGSI(GenerationStrategyInterface): - def gen_for_multiple_trials_with_multiple_models( - self, - experiment: Experiment, - data: Data | None = None, - pending_observations: dict[str, list[ObservationFeatures]] | None = None, - n: int | None = None, - num_trials: int = 1, - arms_per_node: dict[str, int] | None = None, - ) -> list[list[GeneratorRun]]: - raise NotImplementedError - - def clone_reset(self) -> "MyGSI": - raise NotImplementedError - - -class TestGenerationStrategyInterface(TestCase): - def setUp(self) -> None: - super().setUp() - self.exp = get_experiment() - self.gsi = MyGSI(name="my_GSI") - self.special_gsi = SpecialGenerationStrategy() - - def test_constructor(self) -> None: - with self.assertRaisesRegex(TypeError, ".* abstract"): - GenerationStrategyInterface(name="my_GSI") # pyre-ignore[45] - self.assertEqual(self.gsi.name, "my_GSI") - - def test_abstract(self) -> None: - with self.assertRaises(NotImplementedError): - self.gsi.gen_for_multiple_trials_with_multiple_models(experiment=self.exp) - - with self.assertRaises(NotImplementedError): - self.gsi.clone_reset() - - def test_experiment(self) -> None: - with self.assertRaisesRegex(AxError, "No experiment"): - self.gsi.experiment - self.gsi.experiment = self.exp - exp_2 = get_experiment() - exp_2.name = "exp_2" - with self.assertRaisesRegex(UnsupportedError, "has been used for"): - self.gsi.experiment = exp_2 diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index dcea426bc6e..23815afc929 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -19,7 +19,6 @@ from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.utils import extend_pending_observations, extract_pending_observations @@ -35,6 +34,7 @@ from ax.generation_strategy.model_spec import FactoryFunctionGeneratorSpec from ax.generation_strategy.transition_criterion import TrialBasedCriterion from ax.modelbridge.base import Adapter +from ax.utils.common.base import Base from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import assert_is_instance_list from pyre_extensions import none_throws @@ -70,7 +70,7 @@ def impl(self: GenerationStrategy, *args: list[Any], **kwargs: dict[str, Any]) - return impl -class GenerationStrategy(GenerationStrategyInterface): +class GenerationStrategy(Base): """GenerationStrategy describes which model should be used to generate new points for which trials, enabling and automating use of different models throughout the optimization process. For instance, it allows to use one @@ -93,6 +93,8 @@ class GenerationStrategy(GenerationStrategyInterface): strategy's name will be names of its nodes' models joined with '+'. """ + DEFAULT_N: int = 1 + _nodes: list[GenerationNode] _curr: GenerationNode # Current node in the strategy. # Whether all models in this GS are in Generators registry enum. @@ -102,6 +104,7 @@ class GenerationStrategy(GenerationStrategyInterface): _generator_runs: list[GeneratorRun] # Experiment, for which this generation strategy has generated trials, if # it exists. + _name: str _experiment: Experiment | None = None _model: Adapter | None = None # Current model. @@ -149,7 +152,7 @@ def __init__( self._generator_runs = [] # Set name to an explicit value ahead of time to avoid # adding properties during equality checks - super().__init__(name=name or self._make_default_name()) + self._name = name or self._make_default_name() @property def is_node_based(self) -> bool: @@ -179,6 +182,11 @@ def name(self, name: str) -> None: """Set generation strategy name.""" self._name = name + @property + def name(self) -> str: + """Name of this generation strategy.""" + return self._name + @property @step_based_gs_only def model_transitions(self) -> list[int]: diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 76d92b5c221..9f360c6463f 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -23,7 +23,6 @@ from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import DataType, Experiment from ax.core.formatting_utils import data_and_evaluations_from_raw_data -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData from ax.core.map_metric import MapMetric @@ -1783,7 +1782,7 @@ def _set_generation_strategy( def _save_generation_strategy_to_db_if_possible( self, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> bool: return super()._save_generation_strategy_to_db_if_possible( generation_strategy=generation_strategy or self.generation_strategy, diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 7e81c06ca46..bd9f85c8c70 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -20,7 +20,6 @@ import ax.service.utils.early_stopping as early_stopping_utils from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric, MetricFetchE, MetricFetchResult from ax.core.multi_type_experiment import ( @@ -163,7 +162,7 @@ class Scheduler(AnalysisBase, BestPointMixin): """ experiment: Experiment - generation_strategy: GenerationStrategyInterface + generation_strategy: GenerationStrategy # pyre-fixme[24]: Generic type `LoggerAdapter` expects 1 type parameter. logger: LoggerAdapter # Mapping of form {short string identifier -> message to show in reported @@ -212,7 +211,7 @@ class Scheduler(AnalysisBase, BestPointMixin): def __init__( self, experiment: Experiment, - generation_strategy: GenerationStrategyInterface, + generation_strategy: GenerationStrategy, options: SchedulerOptions, db_settings: Optional[DBSettings] = None, _skip_experiment_save: bool = False, @@ -228,7 +227,7 @@ def __init__( if not isinstance(experiment, Experiment): raise TypeError("{experiment} is not an Ax experiment.") - if not isinstance(generation_strategy, GenerationStrategyInterface): + if not isinstance(generation_strategy, GenerationStrategy): raise TypeError("{generation_strategy} is not a generation strategy.") # Initialize storage layer for the scheduler. diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index becba865fa3..4c2f40a6516 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -15,7 +15,7 @@ from random import randint from tempfile import NamedTemporaryFile from typing import Any, Callable, cast, Optional -from unittest.mock import call, Mock, patch, PropertyMock +from unittest.mock import call, Mock, patch import pandas as pd from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot @@ -24,7 +24,6 @@ from ax.core.batch_trial import BatchTrial from ax.core.data import Data from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.map_data import MapData from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment @@ -93,7 +92,6 @@ get_generator_run, get_online_sobol_mbm_generation_strategy, get_sobol, - SpecialGenerationStrategy, ) from ax.utils.testing.mock import mock_botorch_optimize from ax.utils.testing.modeling_stubs import get_generation_strategy @@ -291,9 +289,6 @@ class AxSchedulerTestCase(TestCase): inherit and run its associated tests. """ - GENERATION_STRATEGY_INTERFACE_CLASS: type[GenerationStrategyInterface] = ( - GenerationStrategy - ) # TODO[@mgarrard]: Change this to `str(GenerationStrategy.__module__)` # once we are no longer splitting which `GS.gen` to call into based on # `Trial` vs. `BatchTrial` @@ -385,7 +380,7 @@ def _get_generation_strategy_strategy_for_test( self, experiment: Experiment, generation_strategy: GenerationStrategy | None = None, - ) -> GenerationStrategyInterface: + ) -> GenerationStrategy: return none_throws(generation_strategy) @property @@ -1661,7 +1656,7 @@ def test_optimization_complete(self) -> None: db_settings=self.db_settings_if_always_needed, ) with patch.object( - self.GENERATION_STRATEGY_INTERFACE_CLASS, + GenerationStrategy, "_gen_multiple", side_effect=OptimizationComplete("test error"), ) as mock_gen_multiple: @@ -2203,19 +2198,6 @@ def _helper_path_that_refits_the_model_if_it_is_not_already_initialized( self, scheduler: Scheduler, ) -> None: - with patch.object( - self.GENERATION_STRATEGY_INTERFACE_CLASS, - "model", - new_callable=PropertyMock, - return_value=None, - ), patch.object( - self.GENERATION_STRATEGY_INTERFACE_CLASS, - "_fit_current_model", - wraps=scheduler.standard_generation_strategy._fit_current_model, - ) as fit_model: - get_fitted_model_bridge(scheduler) - fit_model.assert_called_once() - # testing get_fitted_model_bridge model_bridge = get_fitted_model_bridge(scheduler) @@ -2260,21 +2242,6 @@ def test_standard_generation_strategy(self) -> None: ) self.assertEqual(scheduler.standard_generation_strategy, self.sobol_MBM_GS) - with self.subTest("with a `SpecialGenerationStrategy`"): - scheduler = Scheduler( - experiment=self.branin_experiment, - generation_strategy=SpecialGenerationStrategy(), - options=SchedulerOptions( - **self.scheduler_options_kwargs, - ), - db_settings=self.db_settings_if_always_needed, - ) - with self.assertRaisesRegex( - NotImplementedError, - "only supported with instances of `GenerationStrategy`", - ): - scheduler.standard_generation_strategy - def test_get_improvement_over_baseline(self) -> None: n_total_trials = 8 diff --git a/ax/service/tests/test_with_db_settings_base.py b/ax/service/tests/test_with_db_settings_base.py index 4fa00749fbf..36cedc50634 100644 --- a/ax/service/tests/test_with_db_settings_base.py +++ b/ax/service/tests/test_with_db_settings_base.py @@ -30,11 +30,7 @@ ) from ax.storage.sqa_store.structs import DBSettings from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import ( - get_experiment, - get_generator_run, - SpecialGenerationStrategy, -) +from ax.utils.testing.core_stubs import get_experiment, get_generator_run from ax.utils.testing.modeling_stubs import get_generation_strategy @@ -134,13 +130,6 @@ def test_save_generation_strategy(self) -> None: self.assertIsNotNone(loaded_gs) self.assertEqual(loaded_gs.name, generation_strategy.name) - def test_save_non_standard_generation_strategy(self) -> None: - generation_strategy = SpecialGenerationStrategy() - saved = self.with_db_settings._save_generation_strategy_to_db_if_possible( - generation_strategy - ) - self.assertFalse(saved) - def test_save_load_experiment_and_generation_strategy(self) -> None: experiment, generation_strategy = self.init_experiment_and_generation_strategy( save_generation_strategy=False @@ -183,16 +172,6 @@ def test_update_generation_strategy(self) -> None: self.assertIsNotNone(generator_run.db_id) self.assertIsNotNone(generator_run.arms[0].db_id) - def test_update_non_standard_generation_strategy(self) -> None: - generation_strategy = SpecialGenerationStrategy() - generator_run = get_generator_run() - saved = self.with_db_settings._update_generation_strategy_in_db_if_possible( - generation_strategy, [generator_run] - ) - self.assertFalse(saved) - self.assertIsNone(generator_run.db_id) - self.assertIsNone(generator_run.arms[0].db_id) - @patch(f"{WithDBSettingsBase.__module__}.STORAGE_MINI_BATCH_SIZE", 2) def test_update_generation_strategy_mini_batches(self) -> None: _, generation_strategy = self.init_experiment_and_generation_strategy() diff --git a/ax/service/utils/analysis_base.py b/ax/service/utils/analysis_base.py index e55d775ff81..c2c0b612b22 100644 --- a/ax/service/utils/analysis_base.py +++ b/ax/service/utils/analysis_base.py @@ -13,7 +13,7 @@ from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.service.utils.with_db_settings_base import WithDBSettingsBase from pyre_extensions import assert_is_instance @@ -27,9 +27,9 @@ class AnalysisBase(WithDBSettingsBase): # `AnalysisBase` to have type `Experiment` but is never initialized experiment: Experiment # pyre-fixme[13]: Attribute `generation_strategy` is declared in class - # `AnalysisBase` to have type `GenerationStrategyInterface` but + # `AnalysisBase` to have type `GenerationStrategy` but # is never initialized - generation_strategy: GenerationStrategyInterface + generation_strategy: GenerationStrategy def _choose_analyses(self) -> list[Analysis]: """ diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index 31010da5cb6..f00a0ad25b2 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -17,7 +17,6 @@ from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.runner import Runner from ax.exceptions.core import ( @@ -158,7 +157,7 @@ def _get_experiment_and_generation_strategy_db_id( return exp_id, gs_id def _maybe_save_experiment_and_generation_strategy( - self, experiment: Experiment, generation_strategy: GenerationStrategyInterface + self, experiment: Experiment, generation_strategy: GenerationStrategy ) -> tuple[bool, bool]: """If DB settings are set on this `WithDBSettingsBase` instance, checks whether given experiment and generation strategy are already saved and @@ -294,7 +293,7 @@ def _save_or_update_trials_and_generation_strategy_if_possible( self, experiment: Experiment, trials: list[BaseTrial], - generation_strategy: GenerationStrategyInterface, + generation_strategy: GenerationStrategy, new_generator_runs: list[GeneratorRun], reduce_state_generator_runs: bool = False, ) -> None: @@ -376,14 +375,14 @@ def _save_or_update_trials_in_db_if_possible( def _save_generation_strategy_to_db_if_possible( self, - generation_strategy: GenerationStrategyInterface | None = None, + generation_strategy: GenerationStrategy | None = None, ) -> bool: """Saves given generation strategy if DB settings are set on this `WithDBSettingsBase` instance and the generation strategy is an instance of `GenerationStrategy`. Args: - generation_strategy: GenerationStrategyInterface to update in DB. + generation_strategy: GenerationStrategy to update in DB. For now, only instances of GenerationStrategy will be updated. Otherwise, this function is a no-op. @@ -393,19 +392,18 @@ def _save_generation_strategy_to_db_if_possible( if self.db_settings_set and generation_strategy is not None: # only local GenerationStrategies should need to be saved to # the database because only they make changes locally - if isinstance(generation_strategy, GenerationStrategy): - _save_generation_strategy_to_db_if_possible( - generation_strategy=generation_strategy, - encoder=self.db_settings.encoder, - decoder=self.db_settings.decoder, - suppress_all_errors=self._suppress_all_errors, - ) - return True + _save_generation_strategy_to_db_if_possible( + generation_strategy=generation_strategy, + encoder=self.db_settings.encoder, + decoder=self.db_settings.decoder, + suppress_all_errors=self._suppress_all_errors, + ) + return True return False def _update_generation_strategy_in_db_if_possible( self, - generation_strategy: GenerationStrategyInterface, + generation_strategy: GenerationStrategy, new_generator_runs: list[GeneratorRun], reduce_state_generator_runs: bool = False, ) -> bool: @@ -415,7 +413,7 @@ def _update_generation_strategy_in_db_if_possible( instance of `GenerationStrategy`. Args: - generation_strategy: GenerationStrategyInterface to update in DB. + generation_strategy: GenerationStrategy to update in DB. For now, only instances of GenerationStrategy will be updated. Otherwise, this function is a no-op. new_generator_runs: New generator runs of this generation strategy @@ -427,16 +425,15 @@ def _update_generation_strategy_in_db_if_possible( if self.db_settings_set: # only local GenerationStrategies should need to be saved to # the database because only they make changes locally - if isinstance(generation_strategy, GenerationStrategy): - _update_generation_strategy_in_db_if_possible( - generation_strategy=generation_strategy, - new_generator_runs=new_generator_runs, - encoder=self.db_settings.encoder, - decoder=self.db_settings.decoder, - suppress_all_errors=self._suppress_all_errors, - reduce_state_generator_runs=reduce_state_generator_runs, - ) - return True + _update_generation_strategy_in_db_if_possible( + generation_strategy=generation_strategy, + new_generator_runs=new_generator_runs, + encoder=self.db_settings.encoder, + decoder=self.db_settings.decoder, + suppress_all_errors=self._suppress_all_errors, + reduce_state_generator_runs=reduce_state_generator_runs, + ) + return True return False def _update_runner_on_experiment_in_db_if_possible( diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 065e41e0ac8..12a5d1a3bd3 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -26,7 +26,6 @@ from ax.core.batch_trial import AbandonedArm, BatchTrial from ax.core.data import Data from ax.core.experiment import DataType, Experiment -from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData, MapKeyInfo from ax.core.map_metric import MapMetric @@ -2524,29 +2523,3 @@ def __init__( ) -> None: self.test_attribute = test_attribute super().__init__(name=name, lower_is_better=lower_is_better) - - -class SpecialGenerationStrategy(GenerationStrategyInterface): - """A subclass of `GenerationStrategyInterface` to be used - for testing how methods respond to subtypes other than - `GenerationStrategy`.""" - - def __init__(self) -> None: - self._name = "special" - self._generator_runs: list[GeneratorRun] = [] - - def gen_for_multiple_trials_with_multiple_models( - self, - experiment: Experiment, - data: Data | None = None, - pending_observations: dict[str, list[ObservationFeatures]] | None = None, - n: int | None = None, - num_trials: int = 1, - arms_per_node: dict[str, int] | None = None, - ) -> list[list[GeneratorRun]]: - return [] - - def clone_reset(self) -> SpecialGenerationStrategy: - clone = SpecialGenerationStrategy() - clone._name = self._name - return clone