Skip to content

Commit

Permalink
Add fine grained trial and data update methods (facebook#3010)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3010

Some new methods to be used with reduced state loads so we can prevent data loss from resaving.

Reviewed By: lena-kashtelyan

Differential Revision: D65350558

fbshipit-source-id: a9c22333c5e7f0c70f609085b7ba25250b783476
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Nov 2, 2024
1 parent 16440d2 commit d983045
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 6 deletions.
58 changes: 52 additions & 6 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def _save_or_update_trials(

experiment_id: int = experiment._db_id

def add_experiment_id(sqa: SQATrial | SQAData) -> None:
def add_experiment_id(sqa: SQATrial) -> None:
sqa.experiment_id = experiment_id

if reduce_state_generator_runs:
Expand Down Expand Up @@ -263,6 +263,29 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
batch_size=batch_size,
)

save_or_update_data_for_trials(
experiment=experiment,
trials=trials,
encoder=encoder,
decoder=decoder,
batch_size=batch_size,
)


def save_or_update_data_for_trials(
experiment: Experiment,
trials: list[BaseTrial],
encoder: Encoder,
decoder: Decoder,
batch_size: int | None = None,
update_trial_statuses: bool = False,
) -> None:
if experiment.db_id is None:
raise ValueError("Must save experiment before saving/updating its data.")

def add_experiment_id(sqa: SQAData) -> None:
sqa.experiment_id = experiment.db_id

datas, data_encode_args, datas_to_keep, trial_idcs = [], [], [], []
data_sqa_class: type[SQAData] = cast(
type[SQAData], encoder.config.class_to_sqa_class[Data]
Expand All @@ -282,11 +305,11 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
# For trials, for which we saved new data, we can first remove previously
# saved data if it's no longer on the experiment.
with session_scope() as session:
session.query(data_sqa_class).filter_by(experiment_id=experiment_id).filter(
data_sqa_class.trial_index.isnot(None)
).filter(data_sqa_class.trial_index.in_(trial_idcs)).filter(
data_sqa_class.id not in datas_to_keep
).delete()
session.query(data_sqa_class).filter_by(
experiment_id=experiment.db_id
).filter(data_sqa_class.trial_index.isnot(None)).filter(
data_sqa_class.trial_index.in_(trial_idcs)
).filter(data_sqa_class.id not in datas_to_keep).delete()

_bulk_merge_into_session(
objs=datas,
Expand All @@ -301,6 +324,10 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
batch_size=batch_size,
)

if update_trial_statuses:
for trial in trials:
update_trial_status(trial_with_updated_status=trial, config=encoder.config)


def update_generation_strategy(
generation_strategy: GenerationStrategy,
Expand Down Expand Up @@ -488,6 +515,25 @@ def update_properties_on_trial(
)


def update_trial_status(
trial_with_updated_status: BaseTrial,
config: SQAConfig | None = None,
) -> None:
config = SQAConfig() if config is None else config
trial_sqa_class = config.class_to_sqa_class[Trial]

trial_id = trial_with_updated_status.db_id
if trial_id is None:
raise ValueError("Trial must be saved before being updated.")

with session_scope() as session:
session.query(trial_sqa_class).filter_by(id=trial_id).update(
{
"status": trial_with_updated_status.status,
}
)


def save_analysis_cards(
analysis_cards: list[AnalysisCard],
experiment: Experiment,
Expand Down
23 changes: 23 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
Expand Down Expand Up @@ -72,6 +73,7 @@
update_properties_on_experiment,
update_properties_on_trial,
update_runner_on_experiment,
update_trial_status,
)
from ax.storage.sqa_store.sqa_classes import (
SQAAbandonedArm,
Expand Down Expand Up @@ -2060,6 +2062,27 @@ def test_update_properties_on_trial_not_saved(self) -> None:
trial_with_updated_properties=experiment.trials[0],
)

def test_update_trial_status(self) -> None:
experiment = get_experiment_with_batch_trial()
self.assertEqual(experiment.trials[0].status, TrialStatus.CANDIDATE)
save_experiment(experiment)
experiment.trials[0].mark_running(no_runner_required=False)

update_trial_status(
trial_with_updated_status=experiment.trials[0],
)
loaded_experiment = load_experiment(experiment.name)
self.assertEqual(loaded_experiment.trials[0].status, TrialStatus.RUNNING)

def test_update_trial_status_not_saved(self) -> None:
experiment = get_experiment_with_batch_trial()
with self.assertRaisesRegex(
ValueError, "Trial must be saved before being updated."
):
update_trial_status(
trial_with_updated_status=experiment.trials[0],
)

def test_RepeatedArmStorage(self) -> None:
experiment = get_experiment_with_batch_trial()
save_experiment(experiment)
Expand Down

0 comments on commit d983045

Please sign in to comment.