Skip to content

Commit ccf3dea

Browse files
mpolson64facebook-github-bot
authored andcommitted
Serialize _trial_type_to_metric_names and use base class for SQA encoding (facebook#4999)
Summary: Phase 5 of moving MultiTypeExperiment features into base Experiment. JSON storage: - `experiment_to_dict` now serializes `default_trial_type` and `_trial_type_to_metric_names` (sets converted to sorted lists for JSON determinism). - `experiment_from_json` pops and restores `_trial_type_to_metric_names` (lists back to sets), defaulting to None for backward compat with old JSON. - `multi_type_experiment_from_json` pops `_trial_type_to_metric_names` (now present via `experiment_to_dict`) to prevent it being passed as a constructor kwarg. SQA storage: - Encoder replaces the `isinstance(experiment, MultiTypeExperiment)` main branch with `experiment._default_trial_type is not None`, using the base class `metric_to_trial_type` computed property instead of MTE-specific `_metric_to_trial_type`. Keeps `isinstance` only for MTE-specific fields (`Keys.SUBCLASS` marker, `_metric_to_canonical_name`). - No decoder changes needed — `_trial_type_to_metric_names` is already populated correctly through the `__init__` + `add_tracking_metric` call chain. Reviewed By: lena-kashtelyan Differential Revision: D94994120
1 parent 487a517 commit ccf3dea

3 files changed

Lines changed: 25 additions & 4 deletions

File tree

ax/storage/json_store/decoder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,9 @@ def multi_type_experiment_from_json(
702702
)
703703
# not relevant to multi type experiment
704704
del object_json["runner"]
705+
# Pop _trial_type_to_metric_names if present (not a constructor argument;
706+
# rebuilt from _metric_to_trial_type below).
707+
object_json.pop("_trial_type_to_metric_names", None)
705708

706709
kwargs = {
707710
k: object_from_json(
@@ -752,6 +755,11 @@ def experiment_from_json(
752755
if _trial_type_to_runner_json is not None
753756
else None
754757
)
758+
# Pop _trial_type_to_metric_names before constructing Experiment
759+
# (it's not a constructor argument). Convert lists back to sets.
760+
_trial_type_to_metric_names_json = object_json.pop(
761+
"_trial_type_to_metric_names", None
762+
)
755763

756764
experiment = Experiment(
757765
**{
@@ -766,6 +774,10 @@ def experiment_from_json(
766774
experiment._arms_by_name = {}
767775
if _trial_type_to_runner is not None:
768776
experiment._trial_type_to_runner = _trial_type_to_runner
777+
if _trial_type_to_metric_names_json is not None:
778+
experiment._trial_type_to_metric_names = {
779+
tt: set(names) for tt, names in _trial_type_to_metric_names_json.items()
780+
}
769781

770782
_load_experiment_info(
771783
exp=experiment,

ax/storage/json_store/encoders.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,17 +128,20 @@ def experiment_to_dict(experiment: Experiment) -> dict[str, Any]:
128128
"is_test": experiment.is_test,
129129
"data_by_trial": data_to_data_by_trial(data=experiment.data),
130130
"properties": experiment._properties,
131+
"default_trial_type": experiment._default_trial_type,
131132
"_trial_type_to_runner": experiment._trial_type_to_runner,
133+
"_trial_type_to_metric_names": {
134+
tt: sorted(names)
135+
for tt, names in experiment._trial_type_to_metric_names.items()
136+
},
132137
}
133138

134139

135140
def multi_type_experiment_to_dict(experiment: MultiTypeExperiment) -> dict[str, Any]:
136141
"""Convert AE multitype experiment to a dictionary."""
137142
multi_type_dict = {
138-
"default_trial_type": experiment._default_trial_type,
139143
"_metric_to_canonical_name": experiment._metric_to_canonical_name,
140144
"_metric_to_trial_type": experiment._metric_to_trial_type,
141-
"_trial_type_to_runner": experiment._trial_type_to_runner,
142145
}
143146
multi_type_dict.update(experiment_to_dict(experiment))
144147
return multi_type_dict

ax/storage/sqa_store/encoder.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,19 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
222222
runners = []
223223
if isinstance(experiment, MultiTypeExperiment):
224224
experiment._properties[Keys.SUBCLASS] = "MultiTypeExperiment"
225+
if experiment._default_trial_type is not None:
225226
for trial_type, runner in experiment._trial_type_to_runner.items():
226227
runner_sqa = self.runner_to_sqa(none_throws(runner), trial_type)
227228
runners.append(runner_sqa)
228229

230+
metric_to_tt = experiment.metric_to_trial_type
229231
for metric in tracking_metrics:
230-
metric.trial_type = experiment._metric_to_trial_type[metric.name]
231-
if metric.name in experiment._metric_to_canonical_name:
232+
if metric.name in metric_to_tt:
233+
metric.trial_type = metric_to_tt[metric.name]
234+
if (
235+
isinstance(experiment, MultiTypeExperiment)
236+
and metric.name in experiment._metric_to_canonical_name
237+
):
232238
metric.canonical_name = experiment._metric_to_canonical_name[
233239
metric.name
234240
]

0 commit comments

Comments
 (0)