Skip to content

Commit 71209a1

Browse files
CristianLarameta-codesync[bot]
authored andcommitted
Store ExperimentStatus on Experiment class (facebook#4738)
Summary: Pull Request resolved: facebook#4738 # AOSC note Land xdb schema changes first D88096914 # Summary Store ExperimentStatus on Experiment class as `status` and introduce setters/getters, update SQA classes+encoder/decoder, and add tests. Differential Revision: D90089265
1 parent cecc7be commit 71209a1

6 files changed

Lines changed: 61 additions & 0 deletions

File tree

ax/core/experiment.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ax.core.base_trial import BaseTrial
3030
from ax.core.batch_trial import BatchTrial
3131
from ax.core.data import combine_data_rows_favoring_recent, Data
32+
from ax.core.experiment_status import ExperimentStatus
3233
from ax.core.generator_run import GeneratorRun
3334
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
3435
from ax.core.objective import MultiObjective
@@ -146,6 +147,7 @@ def __init__(
146147
self._optimization_config: OptimizationConfig | None = None
147148
self._tracking_metrics: dict[str, Metric] = {}
148149
self._time_created: datetime = datetime.now()
150+
self._status: ExperimentStatus | None = None
149151
self._trials: dict[int, BaseTrial] = {}
150152
self._properties: dict[str, Any] = properties or {}
151153

@@ -231,6 +233,27 @@ def experiment_type(self, experiment_type: str | None) -> None:
231233
"""Set the type of the experiment."""
232234
self._experiment_type = experiment_type
233235

236+
@property
237+
def status(self) -> ExperimentStatus | None:
238+
"""The current status of the experiment.
239+
240+
Status tracks the high-level lifecycle phase of the experiment:
241+
DRAFT, INITIALIZATION, OPTIMIZATION, COMPLETED.
242+
243+
For new experiments, status defaults to DRAFT. For legacy experiments
244+
that were created before the status field was added, status may be None.
245+
"""
246+
return self._status
247+
248+
@status.setter
249+
def status(self, status: ExperimentStatus | None) -> None:
250+
"""Set the status of the experiment.
251+
252+
Args:
253+
status: The new status for the experiment.
254+
"""
255+
self._status = status
256+
234257
@property
235258
def search_space(self) -> SearchSpace:
236259
"""The search space for this experiment.

ax/core/tests/test_experiment.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ax.core.base_trial import BaseTrial, TrialStatus
1818
from ax.core.data import Data, sort_by_trial_index_and_arm_name
1919
from ax.core.evaluations_to_data import raw_evaluations_to_data
20+
from ax.core.experiment_status import ExperimentStatus
2021
from ax.core.map_metric import MapMetric
2122
from ax.core.metric import Metric
2223
from ax.core.objective import MultiObjective, Objective
@@ -1853,6 +1854,15 @@ def test_to_df_with_relativize(self) -> None:
18531854
"relativized value",
18541855
)
18551856

1857+
def test_experiment_status_default(self) -> None:
1858+
"""Test that new experiments have None status for backward compatibility."""
1859+
self.assertIsNone(self.experiment.status)
1860+
1861+
def test_experiment_status_property(self) -> None:
1862+
"""Test the experiment status property getter and setter."""
1863+
self.experiment.status = ExperimentStatus.DRAFT
1864+
self.assertEqual(self.experiment.status, ExperimentStatus.DRAFT)
1865+
18561866

18571867
class ExperimentWithMapDataTest(TestCase):
18581868
def setUp(self) -> None:

ax/storage/sqa_store/decoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ def experiment_from_sqa(
408408
_cast_arm_parameters(sq, experiment.search_space)
409409
experiment._register_arm(sq)
410410
experiment._time_created = experiment_sqa.time_created
411+
experiment._status = experiment_sqa.status
411412
experiment._experiment_type = self.get_enum_name(
412413
value=experiment_sqa.experiment_type, enum=self.config.experiment_type_enum
413414
)

ax/storage/sqa_store/encoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
246246
status_quo_name=status_quo_name,
247247
status_quo_parameters=status_quo_parameters,
248248
time_created=experiment.time_created,
249+
status=experiment.status,
249250
experiment_type=experiment_type,
250251
metrics=optimization_metrics + tracking_metrics,
251252
parameters=parameters,

ax/storage/sqa_store/sqa_classes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import Any
1414

1515
from ax.core.evaluations_to_data import DataType
16+
from ax.core.experiment_status import ExperimentStatus
1617
from ax.core.parameter import ParameterType
1718
from ax.core.trial_status import TrialStatus
1819
from ax.core.types import (
@@ -375,6 +376,9 @@ class SQAExperiment(Base):
375376
JSONEncodedTextDict
376377
)
377378
time_created: Column[datetime] = Column(IntTimestamp, nullable=False)
379+
status: Column[ExperimentStatus | None] = Column(
380+
IntEnum(ExperimentStatus), nullable=True
381+
)
378382
default_trial_type: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
379383
default_data_type: Column[DataType] = Column(IntEnum(DataType), nullable=True)
380384
# pyre-fixme[8]: Incompatible attribute type [8]: Attribute

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
TransferLearningMetadata,
3333
)
3434
from ax.core.experiment import Experiment
35+
from ax.core.experiment_status import ExperimentStatus
3536
from ax.core.generator_run import GeneratorRun
3637
from ax.core.metric import Metric
3738
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
@@ -614,6 +615,27 @@ def test_saving_an_experiment_with_type_errors_with_missing_enum_value(
614615
config=SQAConfig(experiment_type_enum=MockExperimentTypeEnum),
615616
)
616617

618+
def test_experiment_status_save_load(self) -> None:
619+
"""Test that experiment status is correctly saved and loaded."""
620+
# Test None status (backward compatibility)
621+
with self.subTest(status=None):
622+
exp = get_experiment()
623+
exp._name = "test_exp_status_none"
624+
exp.status = None
625+
save_experiment(exp)
626+
loaded_exp = load_experiment(exp.name)
627+
self.assertEqual(loaded_exp.status, None)
628+
629+
# Test all ExperimentStatus enum values
630+
for status in ExperimentStatus:
631+
with self.subTest(status=status):
632+
exp = get_experiment()
633+
exp._name = f"test_exp_status_{status.name.lower()}"
634+
exp.status = status
635+
save_experiment(exp)
636+
loaded_exp = load_experiment(exp.name)
637+
self.assertEqual(loaded_exp.status, status)
638+
617639
def test_load_experiment_trials_in_batches(self) -> None:
618640
for _ in range(4):
619641
self.experiment.new_trial()

0 commit comments

Comments
 (0)