diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index dd63d35349f..4e48cddb621 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -66,6 +66,7 @@ from ax.plot.feature_importances import plot_feature_importance_by_feature from ax.plot.helper import _format_dict from ax.plot.trace import optimization_trace_single_method +from ax.service.utils.analysis_base import AnalysisBase from ax.service.utils.best_point_mixin import BestPointMixin from ax.service.utils.instantiation import ( FixedFeatures, @@ -73,7 +74,7 @@ ObjectiveProperties, ) from ax.service.utils.report_utils import exp_to_df -from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase +from ax.service.utils.with_db_settings_base import DBSettings from ax.storage.json_store.decoder import ( generation_strategy_from_json, object_from_json, @@ -108,7 +109,7 @@ ) -class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase): +class AxClient(AnalysisBase, BestPointMixin, InstantiationBase): """ Convenience handler for management of experimentation cycle through a service-like API. 
External system manages scheduling of the cycle and makes diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index f4634404cb6..dbc596b281b 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -8,8 +8,6 @@ from __future__ import annotations -import traceback - from collections.abc import Callable, Generator, Iterable, Mapping from copy import deepcopy from dataclasses import dataclass @@ -20,10 +18,6 @@ from typing import Any, cast, NamedTuple, Optional import ax.service.utils.early_stopping as early_stopping_utils -import pandas as pd -from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE -from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard -from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface @@ -57,6 +51,7 @@ from ax.modelbridge.base import ModelBridge from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.modelbridge_utils import get_fixed_features_from_experiment +from ax.service.utils.analysis_base import AnalysisBase from ax.service.utils.best_point_mixin import BestPointMixin from ax.service.utils.scheduler_options import SchedulerOptions, TrialType from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase @@ -70,7 +65,6 @@ set_ax_logger_levels, ) from ax.utils.common.timeutils import current_timestamp_in_millis -from ax.utils.common.typeutils import checked_cast from pyre_extensions import assert_is_instance, none_throws @@ -151,7 +145,7 @@ def append(self, text: str) -> None: self.text += text -class Scheduler(WithDBSettingsBase, BestPointMixin): +class Scheduler(AnalysisBase, BestPointMixin): """Closed-loop manager class for Ax optimization. 
Attributes: @@ -678,62 +672,6 @@ def run_all_trials( idle_callback=idle_callback, ) - def compute_analyses( - self, analyses: Iterable[Analysis] | None = None - ) -> list[AnalysisCard]: - """ - Compute Analyses for the Experiment and GenerationStrategy associated with this - Scheduler instance and save them to the DB if possible. If an Analysis fails to - compute (e.g. due to a missing metric), it will be skipped and a warning will - be logged. - - Args: - analyses: Analyses to compute. If None, the Scheduler will choose a set of - Analyses to compute based on the Experiment and GenerationStrategy. - """ - analyses = analyses if analyses is not None else self._choose_analyses() - - results = [ - analysis.compute_result( - experiment=self.experiment, generation_strategy=self.generation_strategy - ) - for analysis in analyses - ] - - # TODO Accumulate Es into their own card, perhaps via unwrap_or_else - cards = [result.unwrap() for result in results if result.is_ok()] - - for result in results: - if result.is_err(): - e = checked_cast(AnalysisE, result.err) - traceback_str = "".join( - traceback.format_exception( - type(result.err.exception), - e.exception, - e.exception.__traceback__, - ) - ) - cards.append( - MarkdownAnalysisCard( - name=e.analysis.name, - # It would be better if we could reliably compute the title - # without risking another error - title=f"{e.analysis.name} Error", - subtitle=f"An error occurred while computing {e.analysis}", - attributes=e.analysis.attributes, - blob=traceback_str, - df=pd.DataFrame(), - level=AnalysisCardLevel.DEBUG, - ) - ) - - self._save_analysis_cards_to_db_if_possible( - analysis_cards=cards, - experiment=self.experiment, - ) - - return cards - def run_trials_and_yield_results( self, max_trials: int, @@ -1906,14 +1844,6 @@ def _get_next_trials( trials.append(trial) return trials, None - def _choose_analyses(self) -> list[Analysis]: - """ - Choose Analyses to compute based on the Experiment, GenerationStrategy, etc. 
-        """
-
-        # TODO Create a useful heuristic for choosing analyses
-        return [ParallelCoordinatesPlot()]
-
     def _gen_new_trials_from_generation_strategy(
         self,
         num_trials: int,
diff --git a/ax/service/utils/analysis_base.py b/ax/service/utils/analysis_base.py
new file mode 100644
index 00000000000..30958ff7f6c
--- /dev/null
+++ b/ax/service/utils/analysis_base.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import traceback
+from typing import Iterable
+
+import pandas as pd
+
+from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
+from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
+from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
+from ax.core.experiment import Experiment
+from ax.core.generation_strategy_interface import GenerationStrategyInterface
+from ax.service.utils.with_db_settings_base import WithDBSettingsBase
+from ax.utils.common.typeutils import checked_cast
+
+
+class AnalysisBase(WithDBSettingsBase):
+    """
+    Base class for analysis functionality shared between AxClient and Scheduler.
+    """
+
+    # pyre-fixme[13]: Attribute `experiment` is declared in class
+    #  `AnalysisBase` to have type `Experiment` but is never initialized
+    experiment: Experiment
+    # pyre-fixme[13]: Attribute `generation_strategy` is declared in class
+    #  `AnalysisBase` to have type `GenerationStrategyInterface` but
+    #  is never initialized
+    generation_strategy: GenerationStrategyInterface
+
+    def _choose_analyses(self) -> list[Analysis]:
+        """
+        Choose Analyses to compute based on the Experiment, GenerationStrategy, etc.
+        """
+
+        # TODO Create a useful heuristic for choosing analyses
+        return [ParallelCoordinatesPlot()]
+
+    def compute_analyses(
+        self, analyses: Iterable[Analysis] | None = None
+    ) -> list[AnalysisCard]:
+        """
+        Compute Analyses for the Experiment and GenerationStrategy associated with this
+        instance and save the resulting cards to the DB if possible. If an Analysis
+        fails to compute (e.g. due to a missing metric), a DEBUG-level error card
+        holding the traceback is added to the returned cards in its place.
+
+        Args:
+            analyses: Analyses to compute. If None, a default set is chosen by
+                ``_choose_analyses``.
+
+        Returns:
+            Cards for the successful Analyses, followed by one error card per failure.
+        """
+        analyses = analyses if analyses is not None else self._choose_analyses()
+
+        results = [
+            analysis.compute_result(
+                experiment=self.experiment,
+                generation_strategy=self.generation_strategy,
+            )
+            for analysis in analyses
+        ]
+
+        # TODO Accumulate Es into their own card, perhaps via unwrap_or_else
+        cards = [result.unwrap() for result in results if result.is_ok()]
+
+        for result in results:
+            if result.is_err():
+                e = checked_cast(AnalysisE, result.err)
+                traceback_str = "".join(
+                    traceback.format_exception(
+                        type(e.exception),
+                        e.exception,
+                        e.exception.__traceback__,
+                    )
+                )
+                cards.append(
+                    MarkdownAnalysisCard(
+                        name=e.analysis.name,
+                        # It would be better if we could reliably compute the title
+                        # without risking another error
+                        title=f"{e.analysis.name} Error",
+                        subtitle=f"An error occurred while computing {e.analysis}",
+                        attributes=e.analysis.attributes,
+                        blob=traceback_str,
+                        df=pd.DataFrame(),
+                        level=AnalysisCardLevel.DEBUG,
+                    )
+                )
+
+        self._save_analysis_cards_to_db_if_possible(
+            analysis_cards=cards,
+            experiment=self.experiment,
+        )
+
+        return cards
diff --git a/sphinx/source/service.rst b/sphinx/source/service.rst
index fdbb67fb320..e66454a97bb 100644
--- a/sphinx/source/service.rst
+++ b/sphinx/source/service.rst
@@ -49,6 +49,15 @@ Scheduler
 Utils
 -----
 
+Analysis
+~~~~~~~~
+
+.. 
automodule:: ax.service.utils.analysis_base + :members: + :undoc-members: + :show-inheritance: + + Best Point Identification ~~~~~~~~~~~~~~~~~~~~~~~~~