From d9830450a327e1940b3930a2f62d7ffd22b16054 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Fri, 1 Nov 2024 19:20:01 -0700 Subject: [PATCH] Add fine grained trial and data update methods (#3010) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3010 Some new methods to be used with reduced state loads so we can prevent data loss from resaving. Reviewed By: lena-kashtelyan Differential Revision: D65350558 fbshipit-source-id: a9c22333c5e7f0c70f609085b7ba25250b783476 --- ax/storage/sqa_store/save.py | 58 ++++++++++++++++++-- ax/storage/sqa_store/tests/test_sqa_store.py | 23 ++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index bbf44302fce..daa61ace128 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -226,7 +226,7 @@ def _save_or_update_trials( experiment_id: int = experiment._db_id - def add_experiment_id(sqa: SQATrial | SQAData) -> None: + def add_experiment_id(sqa: SQATrial) -> None: sqa.experiment_id = experiment_id if reduce_state_generator_runs: @@ -263,6 +263,29 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: batch_size=batch_size, ) + save_or_update_data_for_trials( + experiment=experiment, + trials=trials, + encoder=encoder, + decoder=decoder, + batch_size=batch_size, + ) + + +def save_or_update_data_for_trials( + experiment: Experiment, + trials: list[BaseTrial], + encoder: Encoder, + decoder: Decoder, + batch_size: int | None = None, + update_trial_statuses: bool = False, +) -> None: + if experiment.db_id is None: + raise ValueError("Must save experiment before saving/updating its data.") + + def add_experiment_id(sqa: SQAData) -> None: + sqa.experiment_id = experiment.db_id + datas, data_encode_args, datas_to_keep, trial_idcs = [], [], [], [] data_sqa_class: type[SQAData] = cast( type[SQAData], encoder.config.class_to_sqa_class[Data] @@ -282,11 +305,11 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: # For trials, for which we saved new data, we can first remove previously # saved data if it's no longer on the experiment. with session_scope() as session: - session.query(data_sqa_class).filter_by(experiment_id=experiment_id).filter( - data_sqa_class.trial_index.isnot(None) - ).filter(data_sqa_class.trial_index.in_(trial_idcs)).filter( - data_sqa_class.id not in datas_to_keep - ).delete() + session.query(data_sqa_class).filter_by( + experiment_id=experiment.db_id + ).filter(data_sqa_class.trial_index.isnot(None)).filter( + data_sqa_class.trial_index.in_(trial_idcs) + ).filter(data_sqa_class.id not in datas_to_keep).delete() _bulk_merge_into_session( objs=datas, @@ -301,6 +324,10 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: batch_size=batch_size, ) + if update_trial_statuses: + for trial in trials: + update_trial_status(trial_with_updated_status=trial, config=encoder.config) + def update_generation_strategy( generation_strategy: GenerationStrategy, @@ -488,6 +515,25 @@ def update_properties_on_trial( ) +def update_trial_status( + trial_with_updated_status: BaseTrial, + config: SQAConfig | None = None, +) -> None: + config = SQAConfig() if config is None else config + trial_sqa_class = config.class_to_sqa_class[Trial] + + trial_id = trial_with_updated_status.db_id + if trial_id is None: + raise ValueError("Trial must be saved before being updated.") + + with session_scope() as session: + session.query(trial_sqa_class).filter_by(id=trial_id).update( + { + "status": trial_with_updated_status.status, + } + ) + + def save_analysis_cards( analysis_cards: list[AnalysisCard], experiment: Experiment, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 18be298cef5..b3b6c3a7207 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -22,6 +22,7 @@ from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose +from ax.core.base_trial import TrialStatus from ax.core.batch_trial import BatchTrial, LifecycleStage from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun @@ -72,6 +73,7 @@ update_properties_on_experiment, update_properties_on_trial, update_runner_on_experiment, + update_trial_status, ) from ax.storage.sqa_store.sqa_classes import ( SQAAbandonedArm, @@ -2060,6 +2062,27 @@ def test_update_properties_on_trial_not_saved(self) -> None: trial_with_updated_properties=experiment.trials[0], ) + def test_update_trial_status(self) -> None: + experiment = get_experiment_with_batch_trial() + self.assertEqual(experiment.trials[0].status, TrialStatus.CANDIDATE) + save_experiment(experiment) + experiment.trials[0].mark_running(no_runner_required=False) + + update_trial_status( + trial_with_updated_status=experiment.trials[0], + ) + loaded_experiment = load_experiment(experiment.name) + self.assertEqual(loaded_experiment.trials[0].status, TrialStatus.RUNNING) + + def test_update_trial_status_not_saved(self) -> None: + experiment = get_experiment_with_batch_trial() + with self.assertRaisesRegex( + ValueError, "Trial must be saved before being updated." + ): + update_trial_status( + trial_with_updated_status=experiment.trials[0], + ) + def test_RepeatedArmStorage(self) -> None: experiment = get_experiment_with_batch_trial() save_experiment(experiment)