Skip to content

Commit 1d25234

Browse files
andycylmetafacebook-github-bot
authored andcommitted
Support skip_runners_and_metrics option (#3105)
Summary: Pull Request resolved: #3105 Support options for not loading runners and metric types for MultiTypeExperiment Reviewed By: sdaulton Differential Revision: D66280618 fbshipit-source-id: 1a5c0acc6b9737b1ed0eaf056a079fb49554fe9d
1 parent 8fe3fed commit 1d25234

File tree

4 files changed

+41
-6
lines changed

4 files changed

+41
-6
lines changed

ax/core/multi_type_experiment.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(
4747
name: str,
4848
search_space: SearchSpace,
4949
default_trial_type: str,
50-
default_runner: Runner,
50+
default_runner: Runner | None,
5151
optimization_config: OptimizationConfig | None = None,
5252
tracking_metrics: list[Metric] | None = None,
5353
status_quo: Arm | None = None,
@@ -79,7 +79,7 @@ def __init__(
7979
self._default_trial_type = default_trial_type
8080

8181
# Map from trial type to default runner of that type
82-
self._trial_type_to_runner: dict[str, Runner] = {
82+
self._trial_type_to_runner: dict[str, Runner | None] = {
8383
default_trial_type: default_runner
8484
}
8585

ax/storage/sqa_store/decoder.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ def _init_experiment_from_sqa(
197197
)
198198

199199
def _init_mt_experiment_from_sqa(
200-
self, experiment_sqa: SQAExperiment
200+
self,
201+
experiment_sqa: SQAExperiment,
201202
) -> MultiTypeExperiment:
202203
"""First step of conversion within experiment_from_sqa."""
203204
opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
@@ -217,24 +218,41 @@ def _init_mt_experiment_from_sqa(
217218
if experiment_sqa.status_quo_parameters is not None
218219
else None
219220
)
221+
222+
default_trial_type = none_throws(experiment_sqa.default_trial_type)
220223
trial_type_to_runner = {
221224
none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner)
222225
for sqa_runner in experiment_sqa.runners
223226
}
224-
default_trial_type = none_throws(experiment_sqa.default_trial_type)
227+
if len(trial_type_to_runner) == 0:
228+
trial_type_to_runner = {default_trial_type: None}
229+
trial_types_with_metrics = {
230+
metric.trial_type
231+
for metric in experiment_sqa.metrics
232+
if metric.trial_type
233+
}
234+
# trial_type_to_runner is instantiated to map all trial types to None,
235+
# so the trial types are associated with the expeirment. This is
236+
# important for adding metrics.
237+
trial_type_to_runner.update(
238+
{t_type: None for t_type in trial_types_with_metrics}
239+
)
225240
properties = dict(experiment_sqa.properties or {})
226241
default_data_type = experiment_sqa.default_data_type
227242
experiment = MultiTypeExperiment(
228243
name=experiment_sqa.name,
229244
description=experiment_sqa.description,
230245
search_space=search_space,
231246
default_trial_type=default_trial_type,
232-
default_runner=trial_type_to_runner[default_trial_type],
247+
default_runner=trial_type_to_runner.get(default_trial_type),
233248
optimization_config=opt_config,
234249
status_quo=status_quo,
235250
properties=properties,
236251
default_data_type=default_data_type,
237252
)
253+
# pyre-ignore Imcompatible attribute type [8]: attribute _trial_type_to_runner
254+
# has type Dict[str, Optional[Runner]] but is used as type
255+
# Uniont[Dict[str, Optional[Runner]], Dict[str, None]]
238256
experiment._trial_type_to_runner = trial_type_to_runner
239257
sqa_metric_dict = {metric.name: metric for metric in experiment_sqa.metrics}
240258
for tracking_metric in tracking_metrics:

ax/storage/sqa_store/encoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
195195
if isinstance(experiment, MultiTypeExperiment):
196196
properties[Keys.SUBCLASS] = "MultiTypeExperiment"
197197
for trial_type, runner in experiment._trial_type_to_runner.items():
198-
runner_sqa = self.runner_to_sqa(runner, trial_type)
198+
runner_sqa = self.runner_to_sqa(none_throws(runner), trial_type)
199199
runners.append(runner_sqa)
200200

201201
for metric in tracking_metrics:

ax/storage/sqa_store/tests/test_sqa_store.py

+17
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,23 @@ def test_MTExperimentSaveAndLoad(self) -> None:
609609
self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1")
610610
self.assertEqual(len(loaded_experiment.trials), 2)
611611

612+
def test_MTExperimentSaveAndLoadSkipRunnersAndMetrics(self) -> None:
613+
experiment = get_multi_type_experiment(add_trials=True)
614+
save_experiment(experiment)
615+
loaded_experiment = load_experiment(
616+
experiment.name, skip_runners_and_metrics=True
617+
)
618+
self.assertEqual(loaded_experiment.default_trial_type, "type1")
619+
# pyre-fixme[16]: `Experiment` has no attribute `_trial_type_to_runner`.
620+
self.assertIsNone(loaded_experiment._trial_type_to_runner["type1"])
621+
self.assertIsNone(loaded_experiment._trial_type_to_runner["type2"])
622+
# pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`.
623+
self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1")
624+
self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2")
625+
# pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`.
626+
self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1")
627+
self.assertEqual(len(loaded_experiment.trials), 2)
628+
612629
def test_ExperimentNewTrial(self) -> None:
613630
# Create a new trial without data
614631
save_experiment(self.experiment)

0 commit comments

Comments
 (0)