Skip to content

Commit 34b9dc5

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Serialize _trial_type_to_metric_names and use base class for SQA encoding
Summary: Phase 5 of moving MultiTypeExperiment features into base Experiment. JSON storage: - `experiment_to_dict` now serializes `default_trial_type` and `_trial_type_to_metric_names` (sets converted to sorted lists for JSON determinism). - `experiment_from_json` pops and restores `_trial_type_to_metric_names` (lists back to sets), defaulting to None for backward compat with old JSON. - `multi_type_experiment_from_json` pops `_trial_type_to_metric_names` (now present via `experiment_to_dict`) to prevent it being passed as a constructor kwarg. SQA storage: - Encoder replaces the `isinstance(experiment, MultiTypeExperiment)` main branch with `experiment._default_trial_type is not None`, using the base class `metric_to_trial_type` computed property instead of MTE-specific `_metric_to_trial_type`. Keeps `isinstance` only for MTE-specific fields (`Keys.SUBCLASS` marker, `_metric_to_canonical_name`). - No decoder changes needed — `_trial_type_to_metric_names` is already populated correctly through the `__init__` + `add_tracking_metric` call chain. Differential Revision: D94994120
1 parent 379bf31 commit 34b9dc5

3 files changed

Lines changed: 25 additions & 2 deletions

File tree

ax/storage/json_store/decoder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,9 @@ def multi_type_experiment_from_json(
702702
)
703703
# not relevant to multi type experiment
704704
del object_json["runner"]
705+
# Pop _trial_type_to_metric_names if present (not a constructor argument;
706+
# rebuilt from _metric_to_trial_type below).
707+
object_json.pop("_trial_type_to_metric_names", None)
705708

706709
kwargs = {
707710
k: object_from_json(
@@ -752,6 +755,11 @@ def experiment_from_json(
752755
if _trial_type_to_runner_json is not None
753756
else None
754757
)
758+
# Pop _trial_type_to_metric_names before constructing Experiment
759+
# (it's not a constructor argument). Convert lists back to sets.
760+
_trial_type_to_metric_names_json = object_json.pop(
761+
"_trial_type_to_metric_names", None
762+
)
755763

756764
experiment = Experiment(
757765
**{
@@ -766,6 +774,10 @@ def experiment_from_json(
766774
experiment._arms_by_name = {}
767775
if _trial_type_to_runner is not None:
768776
experiment._trial_type_to_runner = _trial_type_to_runner
777+
if _trial_type_to_metric_names_json is not None:
778+
experiment._trial_type_to_metric_names = {
779+
tt: set(names) for tt, names in _trial_type_to_metric_names_json.items()
780+
}
769781

770782
_load_experiment_info(
771783
exp=experiment,

ax/storage/json_store/encoders.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,12 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]:
128128
"is_test": experiment.is_test,
129129
"data_by_trial": data_to_data_by_trial(data=experiment.data),
130130
"properties": experiment._properties,
131+
"default_trial_type": experiment._default_trial_type,
131132
"_trial_type_to_runner": experiment._trial_type_to_runner,
133+
"_trial_type_to_metric_names": {
134+
tt: sorted(names)
135+
for tt, names in experiment._trial_type_to_metric_names.items()
136+
},
132137
}
133138

134139

ax/storage/sqa_store/encoder.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,13 +219,19 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
219219
runners = []
220220
if isinstance(experiment, MultiTypeExperiment):
221221
experiment._properties[Keys.SUBCLASS] = "MultiTypeExperiment"
222+
if experiment._default_trial_type is not None:
222223
for trial_type, runner in experiment._trial_type_to_runner.items():
223224
runner_sqa = self.runner_to_sqa(none_throws(runner), trial_type)
224225
runners.append(runner_sqa)
225226

227+
metric_to_tt = experiment.metric_to_trial_type
226228
for metric in tracking_metrics:
227-
metric.trial_type = experiment._metric_to_trial_type[metric.name]
228-
if metric.name in experiment._metric_to_canonical_name:
229+
if metric.name in metric_to_tt:
230+
metric.trial_type = metric_to_tt[metric.name]
231+
if (
232+
isinstance(experiment, MultiTypeExperiment)
233+
and metric.name in experiment._metric_to_canonical_name
234+
):
229235
metric.canonical_name = experiment._metric_to_canonical_name[
230236
metric.name
231237
]

0 commit comments

Comments
 (0)