Skip to content

Commit

Permalink
Support skip_runners_and_metrics option (facebook#3105)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3105

Support options for not loading runners and metric types for MultiTypeExperiment

Reviewed By: sdaulton

Differential Revision: D66280618

fbshipit-source-id: 1a5c0acc6b9737b1ed0eaf056a079fb49554fe9d
  • Loading branch information
andycylmeta authored and facebook-github-bot committed Nov 22, 2024
1 parent 8fe3fed commit 1d25234
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
4 changes: 2 additions & 2 deletions ax/core/multi_type_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
name: str,
search_space: SearchSpace,
default_trial_type: str,
default_runner: Runner,
default_runner: Runner | None,
optimization_config: OptimizationConfig | None = None,
tracking_metrics: list[Metric] | None = None,
status_quo: Arm | None = None,
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(
self._default_trial_type = default_trial_type

# Map from trial type to default runner of that type
self._trial_type_to_runner: dict[str, Runner] = {
self._trial_type_to_runner: dict[str, Runner | None] = {
default_trial_type: default_runner
}

Expand Down
24 changes: 21 additions & 3 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def _init_experiment_from_sqa(
)

def _init_mt_experiment_from_sqa(
self, experiment_sqa: SQAExperiment
self,
experiment_sqa: SQAExperiment,
) -> MultiTypeExperiment:
"""First step of conversion within experiment_from_sqa."""
opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
Expand All @@ -217,24 +218,41 @@ def _init_mt_experiment_from_sqa(
if experiment_sqa.status_quo_parameters is not None
else None
)

default_trial_type = none_throws(experiment_sqa.default_trial_type)
trial_type_to_runner = {
none_throws(sqa_runner.trial_type): self.runner_from_sqa(sqa_runner)
for sqa_runner in experiment_sqa.runners
}
default_trial_type = none_throws(experiment_sqa.default_trial_type)
if len(trial_type_to_runner) == 0:
trial_type_to_runner = {default_trial_type: None}
trial_types_with_metrics = {
metric.trial_type
for metric in experiment_sqa.metrics
if metric.trial_type
}
# trial_type_to_runner is instantiated to map all trial types to None,
# so the trial types are associated with the expeirment. This is
# important for adding metrics.
trial_type_to_runner.update(
{t_type: None for t_type in trial_types_with_metrics}
)
properties = dict(experiment_sqa.properties or {})
default_data_type = experiment_sqa.default_data_type
experiment = MultiTypeExperiment(
name=experiment_sqa.name,
description=experiment_sqa.description,
search_space=search_space,
default_trial_type=default_trial_type,
default_runner=trial_type_to_runner[default_trial_type],
default_runner=trial_type_to_runner.get(default_trial_type),
optimization_config=opt_config,
status_quo=status_quo,
properties=properties,
default_data_type=default_data_type,
)
# pyre-ignore Imcompatible attribute type [8]: attribute _trial_type_to_runner
# has type Dict[str, Optional[Runner]] but is used as type
# Uniont[Dict[str, Optional[Runner]], Dict[str, None]]
experiment._trial_type_to_runner = trial_type_to_runner
sqa_metric_dict = {metric.name: metric for metric in experiment_sqa.metrics}
for tracking_metric in tracking_metrics:
Expand Down
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
if isinstance(experiment, MultiTypeExperiment):
properties[Keys.SUBCLASS] = "MultiTypeExperiment"
for trial_type, runner in experiment._trial_type_to_runner.items():
runner_sqa = self.runner_to_sqa(runner, trial_type)
runner_sqa = self.runner_to_sqa(none_throws(runner), trial_type)
runners.append(runner_sqa)

for metric in tracking_metrics:
Expand Down
17 changes: 17 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,23 @@ def test_MTExperimentSaveAndLoad(self) -> None:
self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1")
self.assertEqual(len(loaded_experiment.trials), 2)

def test_MTExperimentSaveAndLoadSkipRunnersAndMetrics(self) -> None:
experiment = get_multi_type_experiment(add_trials=True)
save_experiment(experiment)
loaded_experiment = load_experiment(
experiment.name, skip_runners_and_metrics=True
)
self.assertEqual(loaded_experiment.default_trial_type, "type1")
# pyre-fixme[16]: `Experiment` has no attribute `_trial_type_to_runner`.
self.assertIsNone(loaded_experiment._trial_type_to_runner["type1"])
self.assertIsNone(loaded_experiment._trial_type_to_runner["type2"])
# pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`.
self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1")
self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2")
# pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`.
self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1")
self.assertEqual(len(loaded_experiment.trials), 2)

def test_ExperimentNewTrial(self) -> None:
# Create a new trial without data
save_experiment(self.experiment)
Expand Down

0 comments on commit 1d25234

Please sign in to comment.