Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move ax.preview.api to ax.api #3466

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ax.core.observation import ObservationFeatures
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.generation_strategy.dispatch_utils import choose_generation_strategy
from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_compute_for_requires_a_gs(self) -> None:
def test_compute_for_requires_trials(self) -> None:
analysis = PredictedEffectsPlot(metric_name="branin")
experiment = get_branin_experiment()
generation_strategy = choose_generation_strategy(
generation_strategy = choose_generation_strategy_legacy(
search_space=experiment.search_space,
experiment=experiment,
)
Expand All @@ -62,7 +62,7 @@ def test_compute_for_requires_trials(self) -> None:
def test_compute_for_requires_a_model_that_predicts(self) -> None:
analysis = PredictedEffectsPlot(metric_name="branin")
experiment = get_branin_experiment(with_batch=True, with_completed_batch=True)
generation_strategy = choose_generation_strategy(
generation_strategy = choose_generation_strategy_legacy(
search_space=experiment.search_space,
experiment=experiment,
)
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_it_does_not_plot_abandoned_trials(self) -> None:
def test_it_works_for_non_batch_experiments(self) -> None:
# GIVEN an experiment with the default generation strategy
experiment = get_branin_experiment(with_batch=False)
generation_strategy = choose_generation_strategy(
generation_strategy = choose_generation_strategy_legacy(
search_space=experiment.search_space,
experiment=experiment,
)
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/tests/test_metric_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import pandas as pd
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.metric_summary import MetricSummary
from ax.api.client import Client
from ax.api.configs import ExperimentConfig
from ax.core.metric import Metric
from ax.exceptions.core import UserInputError
from ax.preview.api.client import Client
from ax.preview.api.configs import ExperimentConfig
from ax.utils.common.testutils import TestCase


Expand Down
6 changes: 3 additions & 3 deletions ax/analysis/tests/test_search_space_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
import pandas as pd
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.search_space_summary import SearchSpaceSummary
from ax.exceptions.core import UserInputError
from ax.preview.api.client import Client
from ax.preview.api.configs import (
from ax.api.client import Client
from ax.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
ParameterScaling,
ParameterType,
RangeParameterConfig,
)
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase


Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/tests/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import pandas as pd
from ax.analysis.analysis import AnalysisCardCategory, AnalysisCardLevel
from ax.analysis.summary import Summary
from ax.api.client import Client
from ax.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.preview.api.client import Client
from ax.preview.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig
from ax.utils.common.testutils import TestCase
from pyre_extensions import assert_is_instance, none_throws

Expand Down
6 changes: 3 additions & 3 deletions ax/preview/api/__init__.py → ax/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

# pyre-strict

from ax.preview.api.client import Client
from ax.preview.api.configs import (
from ax.api.client import Client
from ax.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
GenerationStrategyConfig,
Expand All @@ -17,7 +17,7 @@
RangeParameterConfig,
StorageConfig,
)
from ax.preview.api.types import TOutcome, TParameterization
from ax.api.types import TOutcome, TParameterization

__all__ = [
"Client",
Expand Down
44 changes: 28 additions & 16 deletions ax/preview/api/client.py → ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
markdown_analysis_card_from_analysis_e,
)
from ax.analysis.utils import choose_analyses
from ax.api.configs import (
ExperimentConfig,
GenerationStrategyConfig,
OrchestrationConfig,
StorageConfig,
)
from ax.api.protocols.metric import IMetric
from ax.api.protocols.runner import IRunner
from ax.api.types import TOutcome, TParameterization
from ax.api.utils.instantiation.from_config import experiment_from_config
from ax.api.utils.instantiation.from_string import optimization_config_from_string
from ax.api.utils.storage import db_settings_from_storage_config
from ax.core.experiment import Experiment
from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
Expand All @@ -35,22 +47,8 @@
PercentileEarlyStoppingStrategy,
)
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
from ax.generation_strategy.dispatch_utils import choose_generation_strategy
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.preview.api.configs import (
ExperimentConfig,
GenerationStrategyConfig,
OrchestrationConfig,
StorageConfig,
)
from ax.preview.api.protocols.metric import IMetric
from ax.preview.api.protocols.runner import IRunner
from ax.preview.api.types import TOutcome, TParameterization
from ax.preview.api.utils.instantiation.from_config import experiment_from_config
from ax.preview.api.utils.instantiation.from_string import (
optimization_config_from_string,
)
from ax.preview.api.utils.storage import db_settings_from_storage_config
from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy
from ax.service.scheduler import Scheduler, SchedulerOptions
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.with_db_settings_base import WithDBSettingsBase
Expand Down Expand Up @@ -179,7 +177,21 @@ def configure_generation_strategy(
"""

generation_strategy = choose_generation_strategy(
gs_config=generation_strategy_config
method=generation_strategy_config.method,
initialization_budget=generation_strategy_config.initialization_budget,
initialization_random_seed=(
generation_strategy_config.initialization_random_seed
),
use_existing_trials_for_initialization=(
generation_strategy_config.use_existing_trials_for_initialization
),
min_observed_initialization_trials=(
generation_strategy_config.min_observed_initialization_trials
),
allow_exceeding_initialization_budget=(
generation_strategy_config.allow_exceeding_initialization_budget
),
torch_device=generation_strategy_config.torch_device,
)

# Necessary for storage implications, may be removed in the future
Expand Down
29 changes: 3 additions & 26 deletions ax/preview/api/configs.py → ax/api/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from enum import Enum
from typing import Any

from ax.preview.api.types import TParameterValue
from ax.api.types import TParameterValue

from ax.generation_strategy.dispatch_utils import GenerationMethod
from ax.storage.registry_bundle import RegistryBundleBase


Expand Down Expand Up @@ -83,31 +85,6 @@ class ExperimentConfig:
owner: str | None = None


class GenerationMethod(Enum):
"""An enum to specify the desired candidate generation method for the experiment.
This is used in ``GenerationStrategyConfig``, along with the properties of the
experiment, to determine the generation strategy to use for candidate generation.

NOTE: New options should be rarely added to this enum. This is not intended to be
a list of generation strategies for the user to choose from. Instead, this enum
should only provide high level guidance to the underlying generation strategy
dispatch logic, which is responsible for determinining the exact details.

Available options are:
BALANCED: A balanced generation method that may utilize (per-metric) model
selection to achieve a good model accuracy. This method excludes expensive
methods, such as the fully Bayesian SAASBO model. Used by default.
FAST: A faster generation method that uses the built-in defaults from the
Modular BoTorch Model without any model selection.
RANDOM_SEARCH: Primarily intended for pure exploration experiments, this
method utilizes quasi-random Sobol sequences for candidate generation.
"""

BALANCED = "balanced"
FAST = "fast"
RANDOM_SEARCH = "random_search"


@dataclass
class GenerationStrategyConfig:
"""A dataclass used to configure the generation strategy used in the experiment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

# pyre-strict

from ax.preview.api.protocols.metric import IMetric
from ax.preview.api.protocols.runner import IRunner
from ax.api.protocols.metric import IMetric
from ax.api.protocols.runner import IRunner

__all__ = [
"IMetric",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Mapping
from typing import Any

from ax.preview.api.protocols.utils import _APIMetric
from ax.api.protocols.utils import _APIMetric
from pyre_extensions import override


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from collections.abc import Mapping
from typing import Any

from ax.api.protocols.utils import _APIRunner
from ax.api.types import TParameterization

from ax.core.trial_status import TrialStatus
from ax.preview.api.protocols.utils import _APIRunner
from ax.preview.api.types import TParameterization
from pyre_extensions import override


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any

import pandas as pd
from ax.api.types import TParameterization

from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.map_data import MapData, MapKeyInfo
Expand All @@ -21,7 +22,6 @@
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.exceptions.storage import JSONEncodeError
from ax.preview.api.types import TParameterization
from ax.utils.common.result import Err, Ok
from pyre_extensions import assert_is_instance, none_throws, override

Expand Down
26 changes: 13 additions & 13 deletions ax/preview/api/tests/test_client.py → ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@

import pandas as pd
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.api.client import Client
from ax.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
GenerationStrategyConfig,
OrchestrationConfig,
ParameterType,
RangeParameterConfig,
StorageConfig,
)
from ax.api.protocols.metric import IMetric
from ax.api.protocols.runner import IRunner
from ax.api.types import TParameterization

from ax.core.experiment import Experiment
from ax.core.formatting_utils import DataType
Expand All @@ -31,19 +44,6 @@
from ax.core.trial_status import TrialStatus
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
from ax.exceptions.core import UnsupportedError
from ax.preview.api.client import Client
from ax.preview.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
GenerationStrategyConfig,
OrchestrationConfig,
ParameterType,
RangeParameterConfig,
StorageConfig,
)
from ax.preview.api.protocols.metric import IMetric
from ax.preview.api.protocols.runner import IRunner
from ax.preview.api.types import TParameterization
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@


import numpy as np
from ax.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
ParameterScaling,
ParameterType,
RangeParameterConfig,
)
from ax.api.utils.instantiation.from_string import parse_parameter_constraint

from ax.core.experiment import Experiment

Expand All @@ -21,14 +29,6 @@
from ax.core.parameter_constraint import validate_constraint_parameters
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.exceptions.core import UserInputError
from ax.preview.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
ParameterScaling,
ParameterType,
RangeParameterConfig,
)
from ax.preview.api.utils.instantiation.from_string import parse_parameter_constraint


def parameter_from_config(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@

# pyre-strict

from ax.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
ParameterScaling,
ParameterType,
RangeParameterConfig,
)
from ax.api.utils.instantiation.from_config import (
_parameter_type_converter,
experiment_from_config,
parameter_from_config,
)
from ax.core.experiment import Experiment
from ax.core.formatting_utils import DataType
from ax.core.parameter import (
Expand All @@ -16,18 +28,6 @@
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.exceptions.core import UserInputError
from ax.preview.api.configs import (
ChoiceParameterConfig,
ExperimentConfig,
ParameterScaling,
ParameterType,
RangeParameterConfig,
)
from ax.preview.api.utils.instantiation.from_config import (
_parameter_type_converter,
experiment_from_config,
parameter_from_config,
)
from ax.utils.common.testutils import TestCase


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@

# pyre-strict

from ax.api.utils.instantiation.from_string import (
_sanitize_dot,
optimization_config_from_string,
parse_objective,
parse_outcome_constraint,
parse_parameter_constraint,
)
from ax.core.map_metric import MapMetric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.optimization_config import (
Expand All @@ -19,13 +26,6 @@
)
from ax.core.parameter_constraint import ParameterConstraint
from ax.exceptions.core import UserInputError
from ax.preview.api.utils.instantiation.from_string import (
_sanitize_dot,
optimization_config_from_string,
parse_objective,
parse_outcome_constraint,
parse_parameter_constraint,
)
from ax.utils.common.testutils import TestCase


Expand Down
Loading