Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate GenerationStrategyInterface #3338

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import pandas as pd
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.utils.common.base import Base
from ax.utils.common.logger import get_logger
from ax.utils.common.result import Err, ExceptionE, Ok, Result
Expand Down Expand Up @@ -123,7 +123,7 @@ class Analysis(Protocol):
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> AnalysisCard:
# Note: when implementing compute always prefer experiment.lookup_data() to
# experiment.fetch_data() to avoid unintential data fetching within the report
Expand All @@ -133,7 +133,7 @@ def compute(
def compute_result(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> Result[AnalysisCard, AnalysisE]:
"""
Utility method to compute an AnalysisCard as a Result. This can be useful for
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/healthcheck/can_generate_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
HealthcheckStatus,
)
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.generation_strategy.generation_strategy import GenerationStrategy
from pyre_extensions import none_throws


Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> HealthcheckAnalysisCard:
status = HealthcheckStatus.PASS
subtitle = self.reason
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/healthcheck/constraints_feasibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm
from ax.analysis.plotly.utils import is_predictive
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.optimization_config import OptimizationConfig
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
Expand Down Expand Up @@ -48,7 +47,7 @@ def __init__(self, prob_threshold: float = 0.95) -> None:
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> HealthcheckAnalysisCard:
r"""
Compute the feasibility of the constraints for the experiment.
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/healthcheck/healthcheck_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ax.analysis.analysis import Analysis, AnalysisCard
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.generation_strategy.generation_strategy import GenerationStrategy


class HealthcheckStatus(IntEnum):
Expand All @@ -33,5 +33,5 @@ class HealthcheckAnalysis(Analysis):
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> HealthcheckAnalysisCard: ...
4 changes: 2 additions & 2 deletions ax/analysis/healthcheck/regression_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
detect_regressions_by_trial,
)
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from pyre_extensions import none_throws


Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(self, prob_threshold: float = 0.95) -> None:
def compute(
self,
experiment: Experiment | None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> HealthcheckAnalysisCard:
r"""
Detect the regressing arms for all trials that have data.
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/healthcheck/search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
HealthcheckStatus,
)
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface

from ax.core.parameter import ChoiceParameter, Parameter, RangeParameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.core.types import TParameterization
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from pyre_extensions import assert_is_instance


Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> HealthcheckAnalysisCard:
r"""
Args:
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/healthcheck/should_generate_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
HealthcheckStatus,
)
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.generation_strategy.generation_strategy import GenerationStrategy


class ShouldGenerateCandidates(HealthcheckAnalysis):
Expand All @@ -33,7 +33,7 @@ def __init__(
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> HealthcheckAnalysisCard:
status = (
HealthcheckStatus.PASS
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/markdown/markdown_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.generation_strategy.generation_strategy import GenerationStrategy
from IPython.display import display, Markdown


Expand All @@ -37,7 +37,7 @@ class MarkdownAnalysis(Analysis):
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> MarkdownAnalysisCard: ...

def _create_markdown_analysis_card(
Expand Down
5 changes: 2 additions & 3 deletions ax/analysis/plotly/arm_effects/insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import is_predictive
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.outcome_constraint import OutcomeConstraint
from ax.exceptions.core import DataRequiredError, UserInputError
Expand Down Expand Up @@ -77,7 +76,7 @@ def __init__(
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("InSampleEffectsPlot requires an Experiment.")
Expand Down Expand Up @@ -174,7 +173,7 @@ def _get_max_observed_trial_index(model: Adapter) -> int | None:

def _get_model(
experiment: Experiment,
generation_strategy: GenerationStrategyInterface | None,
generation_strategy: GenerationStrategy | None,
use_modeled_effects: bool,
trial_index: int,
metric_name: str,
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/plotly/arm_effects/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ax.core import OutcomeConstraint
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import Adapter
Expand Down Expand Up @@ -68,7 +67,7 @@ def __init__(self, metric_name: str) -> None:
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("PredictedEffectsPlot requires an Experiment.")
Expand Down
11 changes: 3 additions & 8 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import select_metric
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.cross_validation import cross_validate
from plotly import express as px, graph_objects as go
from pyre_extensions import assert_is_instance, none_throws
from pyre_extensions import none_throws


class CrossValidationPlot(PlotlyAnalysis):
Expand Down Expand Up @@ -82,7 +81,7 @@ def __init__(
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard:
if generation_strategy is None:
raise UserInputError("CrossValidation requires a GenerationStrategy")
Expand All @@ -92,11 +91,7 @@ def compute(
)

df = _prepare_data(
# CrossValidationPlot requires a native Ax GenerationStrategy and cannot be
# used with a GenerationStrategyInterface.
generation_strategy=assert_is_instance(
generation_strategy, GenerationStrategy
),
generation_strategy=generation_strategy,
metric_name=metric_name,
folds=self.folds,
untransform=self.untransform,
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from ax.analysis.plotly.surface.utils import is_axis_log_scale
from ax.analysis.plotly.utils import select_metric
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Generators
from ax.modelbridge.torch import TorchAdapter
from ax.models.torch.botorch_modular.surrogate import Surrogate
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard:
"""
Compute Sobol index sensitivity for one metric of an experiment. Sensitivity
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/parallel_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import select_metric
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from plotly import graph_objects as go


Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self, metric_name: str | None = None) -> None:
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ParallelCoordinatesPlot requires an Experiment")
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/plotly_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.generation_strategy.generation_strategy import GenerationStrategy
from IPython.display import display, Markdown
from plotly import graph_objects as go, io as pio

Expand Down Expand Up @@ -37,7 +37,7 @@ class PlotlyAnalysis(Analysis):
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> PlotlyAnalysisCard: ...

def _create_plotly_analysis_card(
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from plotly import express as px, graph_objects as go


Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
generation_strategy: Optional[GenerationStrategy] = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ScatterPlot requires an Experiment")
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/plotly/surface/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from ax.analysis.plotly.utils import select_metric
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
Expand Down Expand Up @@ -63,7 +62,7 @@ def __init__(
def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
generation_strategy: Optional[GenerationStrategy] = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("ContourPlot requires an Experiment")
Expand Down
3 changes: 1 addition & 2 deletions ax/analysis/plotly/surface/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from ax.analysis.plotly.utils import select_metric
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
Expand Down Expand Up @@ -57,7 +56,7 @@ def __init__(
def compute(
self,
experiment: Optional[Experiment] = None,
generation_strategy: Optional[GenerationStrategyInterface] = None,
generation_strategy: Optional[GenerationStrategy] = None,
) -> PlotlyAnalysisCard:
if experiment is None:
raise UserInputError("SlicePlot requires an Experiment")
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy


class Summary(Analysis):
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self, omit_empty_columns: bool = True) -> None:
def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
generation_strategy: GenerationStrategy | None = None,
) -> AnalysisCard:
if experiment is None:
raise UserInputError("`Summary` analysis requires an `Experiment` input")
Expand Down
Loading