Skip to content

Commit 3d20881

Browse files
CristianLarafacebook-github-bot
authored andcommitted
Store ExperimentStatus on Experiment class (facebook#4738)
Summary: # AOSC note Land xdb schema changes first D88096914 # Summary Store ExperimentStatus on Experiment class as `status` and introduce setters/getters, update SQA classes+encoder/decoder, and add tests. I've gone ahead and also defined methods in `storage/sqa_store/` to enable efficiently updating the property. Not currently used but will be used later in the stack by the orchestrator to set this status. Differential Revision: D90089265 Privacy Context Container: L1307644
1 parent ddea79e commit 3d20881

8 files changed

Lines changed: 152 additions & 0 deletions

File tree

ax/core/experiment.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Data,
3535
sort_by_trial_index_and_arm_name,
3636
)
37+
from ax.core.experiment_status import ExperimentStatus
3738
from ax.core.generator_run import GeneratorRun
3839
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
3940
from ax.core.objective import MultiObjective
@@ -151,6 +152,7 @@ def __init__(
151152
self._optimization_config: OptimizationConfig | None = None
152153
self._tracking_metrics: dict[str, Metric] = {}
153154
self._time_created: datetime = datetime.now()
155+
self._status: ExperimentStatus | None = None
154156
self._trials: dict[int, BaseTrial] = {}
155157
self._properties: dict[str, Any] = properties or {}
156158

@@ -236,6 +238,27 @@ def experiment_type(self, experiment_type: str | None) -> None:
236238
"""Set the type of the experiment."""
237239
self._experiment_type = experiment_type
238240

241+
@property
242+
def status(self) -> ExperimentStatus | None:
243+
"""The current status of the experiment.
244+
245+
Status tracks the high-level lifecycle phase of the experiment:
246+
DRAFT, INITIALIZATION, OPTIMIZATION, COMPLETED.
247+
248+
For new experiments, status defaults to DRAFT. For legacy experiments
249+
that were created before the status field was added, status may be None.
250+
"""
251+
return self._status
252+
253+
@status.setter
254+
def status(self, status: ExperimentStatus | None) -> None:
255+
"""Set the status of the experiment.
256+
257+
Args:
258+
status: The new status for the experiment.
259+
"""
260+
self._status = status
261+
239262
@property
240263
def search_space(self) -> SearchSpace:
241264
"""The search space for this experiment.
@@ -620,6 +643,18 @@ def remove_tracking_metric(self, metric_name: str) -> Experiment:
620643
del self._tracking_metrics[metric_name]
621644
return self
622645

646+
def set_status(self, status: ExperimentStatus) -> Experiment:
647+
"""Set the status of the experiment.
648+
649+
Args:
650+
status: The new status of the experiment.
651+
652+
Returns:
653+
The experiment instance.
654+
"""
655+
self._status = status
656+
return self
657+
623658
@property
624659
def metrics(self) -> dict[str, Metric]:
625660
"""The metrics attached to the experiment."""

ax/core/tests/test_experiment.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ax.core.base_trial import BaseTrial, TrialStatus
1818
from ax.core.data import Data, sort_by_trial_index_and_arm_name
1919
from ax.core.evaluations_to_data import raw_evaluations_to_data
20+
from ax.core.experiment_status import ExperimentStatus
2021
from ax.core.map_metric import MapMetric
2122
from ax.core.metric import Metric
2223
from ax.core.objective import MultiObjective, Objective
@@ -1841,6 +1842,19 @@ def test_to_df_with_relativize(self) -> None:
18411842
"relativized value",
18421843
)
18431844

1845+
def test_experiment_status_default(self) -> None:
1846+
"""Test that new experiments have None status for backward compatibility."""
1847+
exp = get_experiment()
1848+
self.assertIsNone(exp.status)
1849+
1850+
def test_experiment_status_property(self) -> None:
1851+
"""Test the experiment status property getter and setter."""
1852+
exp = get_experiment()
1853+
1854+
# Set and get status
1855+
exp.status = ExperimentStatus.DRAFT
1856+
self.assertEqual(exp.status, ExperimentStatus.DRAFT)
1857+
18441858

18451859
class ExperimentWithMapDataTest(TestCase):
18461860
def setUp(self) -> None:

ax/storage/sqa_store/decoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def experiment_from_sqa(
384384
sq = none_throws(experiment.status_quo)
385385
experiment._register_arm(sq)
386386
experiment._time_created = experiment_sqa.time_created
387+
experiment._status = experiment_sqa.status
387388
experiment._experiment_type = self.get_enum_name(
388389
value=experiment_sqa.experiment_type, enum=self.config.experiment_type_enum
389390
)

ax/storage/sqa_store/encoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
246246
status_quo_name=status_quo_name,
247247
status_quo_parameters=status_quo_parameters,
248248
time_created=experiment.time_created,
249+
status=experiment.status,
249250
experiment_type=experiment_type,
250251
metrics=optimization_metrics + tracking_metrics,
251252
parameters=parameters,

ax/storage/sqa_store/save.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,42 @@ def update_properties_on_experiment(
536536
)
537537

538538

539+
def update_experiment_status(
540+
experiment_with_updated_status: Experiment,
541+
config: SQAConfig | None = None,
542+
) -> None:
543+
"""Update experiment status in the database.
544+
545+
This function provides an efficient way to update only the experiment's status
546+
field without re-saving the entire experiment. Use this when you need to persist
547+
status changes immediately after calling status transition methods
548+
(e.g., mark_initialization(), mark_optimization()).
549+
550+
Note: save_experiment() already handles status updates, so this function is
551+
optional. Use it when you need status-only updates for efficiency.
552+
553+
Args:
554+
experiment_with_updated_status: Experiment with updated status.
555+
config: SQAConfig to use for database operations.
556+
557+
Raises:
558+
ValueError: If experiment has not been saved to the database yet.
559+
"""
560+
config = SQAConfig() if config is None else config
561+
exp_sqa_class = config.class_to_sqa_class[Experiment]
562+
563+
exp_id = experiment_with_updated_status.db_id
564+
if exp_id is None:
565+
raise ValueError("Experiment must be saved before being updated.")
566+
567+
with session_scope() as session:
568+
session.query(exp_sqa_class).filter_by(id=exp_id).update(
569+
{
570+
"status": experiment_with_updated_status.status,
571+
}
572+
)
573+
574+
539575
def update_properties_on_trial(
540576
trial_with_updated_properties: BaseTrial,
541577
config: SQAConfig | None = None,

ax/storage/sqa_store/sqa_classes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ax.core.evaluations_to_data import DataType
1616

17+
from ax.core.experiment_status import ExperimentStatus
1718
from ax.core.parameter import ParameterType
1819

1920
from ax.core.trial_status import TrialStatus
@@ -379,6 +380,9 @@ class SQAExperiment(Base):
379380
JSONEncodedTextDict
380381
)
381382
time_created: Column[datetime] = Column(IntTimestamp, nullable=False)
383+
status: Column[ExperimentStatus | None] = Column(
384+
IntEnum(ExperimentStatus), nullable=True
385+
)
382386
default_trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
383387
default_data_type: Column[DataType] = Column(IntEnum(DataType), nullable=True)
384388
# pyre-fixme[8]: Incompatible attribute type [8]: Attribute

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,29 @@ def test_saving_an_experiment_with_type_errors_with_missing_enum_value(
527527
config=SQAConfig(experiment_type_enum=TestExperimentTypeEnum),
528528
)
529529

530+
def test_experiment_status_save_load(self) -> None:
531+
"""Test that experiment status is correctly saved and loaded."""
532+
from ax.core.experiment_status import ExperimentStatus
533+
534+
# Test None status (backward compatibility)
535+
with self.subTest(status=None):
536+
exp = get_experiment()
537+
exp._name = "test_exp_status_none"
538+
exp.status = None
539+
save_experiment(exp)
540+
loaded_exp = load_experiment(exp.name)
541+
self.assertEqual(loaded_exp.status, None)
542+
543+
# Test all ExperimentStatus enum values
544+
for status in ExperimentStatus:
545+
with self.subTest(status=status):
546+
exp = get_experiment()
547+
exp._name = f"test_exp_status_{status.name.lower()}"
548+
exp.status = status
549+
save_experiment(exp)
550+
loaded_exp = load_experiment(exp.name)
551+
self.assertEqual(loaded_exp.status, status)
552+
530553
def test_LoadExperimentTrialsInBatches(self) -> None:
531554
for _ in range(4):
532555
self.experiment.new_trial()

ax/storage/sqa_store/with_db_settings_base.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
_save_or_update_trials,
6767
_update_generation_strategy,
6868
save_analysis_card,
69+
update_experiment_status,
6970
update_properties_on_experiment,
7071
update_runner_on_experiment,
7172
)
@@ -468,6 +469,27 @@ def _update_experiment_properties_in_db(
468469
return True
469470
return False
470471

472+
def _update_experiment_status_in_db_if_possible(
473+
self,
474+
experiment_with_updated_status: Experiment,
475+
) -> bool:
476+
"""Update experiment status in the database if DB settings are configured.
477+
478+
Args:
479+
experiment_with_updated_status: Experiment with updated status.
480+
481+
Returns:
482+
True if the update was performed, False if DB settings are not configured.
483+
"""
484+
if self.db_settings_set:
485+
_update_experiment_status_in_db(
486+
experiment_with_updated_status=experiment_with_updated_status,
487+
sqa_config=self.db_settings.encoder.config,
488+
suppress_all_errors=self._suppress_all_errors,
489+
)
490+
return True
491+
return False
492+
471493
def _save_analysis_card_to_db_if_possible(
472494
self,
473495
experiment: Experiment,
@@ -625,6 +647,22 @@ def _update_experiment_properties_in_db(
625647
)
626648

627649

650+
@retry_on_exception(
651+
retries=3,
652+
default_return_on_suppression=False,
653+
exception_types=RETRY_EXCEPTION_TYPES,
654+
)
655+
def _update_experiment_status_in_db(
656+
experiment_with_updated_status: Experiment,
657+
sqa_config: SQAConfig,
658+
suppress_all_errors: bool, # Used by the decorator.
659+
) -> None:
660+
update_experiment_status(
661+
experiment_with_updated_status=experiment_with_updated_status,
662+
config=sqa_config,
663+
)
664+
665+
628666
@retry_on_exception(
629667
retries=3,
630668
default_return_on_suppression=False,

0 commit comments

Comments
 (0)