Skip to content

Commit e3fe2fa

Browse files
Lena Kashtelyanfacebook-github-bot
Lena Kashtelyan
authored andcommitted
Propagate reduced_state setting when loading the target experiment, to how the auxiliary ones are loaded (facebook#3597)
Summary: Pull Request resolved: facebook#3597 This not being a propagated was a bug. Reviewed By: Cesar-Cardoso Differential Revision: D72266152 fbshipit-source-id: a6f2937555db9cd34f627e450b0c469db054099b
1 parent 9b442ef commit e3fe2fa

File tree

2 files changed

+76
-8
lines changed

2 files changed

+76
-8
lines changed

ax/storage/sqa_store/decoder.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def get_enum_name(
103103
raise SQADecodeError(f"Value {value} is invalid for enum {enum}.")
104104

105105
def _auxiliary_experiments_by_purpose_from_experiment_sqa(
106-
self, experiment_sqa: SQAExperiment
106+
self, experiment_sqa: SQAExperiment, reduced_state: bool = False
107107
) -> dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] | None:
108108
auxiliary_experiments_by_purpose = None
109109
if experiment_sqa.auxiliary_experiments_by_purpose:
@@ -122,7 +122,9 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa(
122122
if isinstance(aux_exp_json, str):
123123
aux_exp_json = {"experiment_name": aux_exp_json}
124124
aux_experiment = auxiliary_experiment_from_json(
125-
json=aux_exp_json, config=self.config
125+
json=aux_exp_json,
126+
config=self.config,
127+
reduced_state=reduced_state,
126128
)
127129
auxiliary_experiments_by_purpose[aux_exp_purpose].append(
128130
aux_experiment
@@ -133,6 +135,7 @@ def _init_experiment_from_sqa(
133135
self,
134136
experiment_sqa: SQAExperiment,
135137
load_auxiliary_experiments: bool = True,
138+
reduced_state: bool = False,
136139
) -> Experiment:
137140
"""First step of conversion within experiment_from_sqa."""
138141
opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
@@ -170,7 +173,8 @@ def _init_experiment_from_sqa(
170173
auxiliary_experiments_by_purpose = (
171174
(
172175
self._auxiliary_experiments_by_purpose_from_experiment_sqa(
173-
experiment_sqa=experiment_sqa
176+
experiment_sqa=experiment_sqa,
177+
reduced_state=reduced_state,
174178
)
175179
)
176180
if load_auxiliary_experiments
@@ -281,6 +285,7 @@ def experiment_from_sqa(
281285
experiment = self._init_experiment_from_sqa(
282286
experiment_sqa,
283287
load_auxiliary_experiments=load_auxiliary_experiments,
288+
reduced_state=reduced_state,
284289
)
285290
trials = [
286291
self.trial_from_sqa(
@@ -1333,7 +1338,9 @@ def _get_scalarized_outcome_constraint_children_metrics(
13331338

13341339

13351340
def auxiliary_experiment_from_json(
1336-
json: dict[str, Any], config: SQAConfig
1341+
json: dict[str, Any],
1342+
config: SQAConfig,
1343+
reduced_state: bool = False,
13371344
) -> AuxiliaryExperiment:
13381345
"""
13391346
Load an ``AuxiliaryExperiment`` from JSON.
@@ -1353,5 +1360,6 @@ def auxiliary_experiment_from_json(
13531360
config=config,
13541361
skip_runners_and_metrics=True,
13551362
load_auxiliary_experiments=False,
1363+
reduced_state=reduced_state,
13561364
)
13571365
return AuxiliaryExperiment(experiment)

ax/storage/sqa_store/tests/test_sqa_store.py

+64-4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ax.core.outcome_constraint import OutcomeConstraint
3131
from ax.core.parameter import ParameterType, RangeParameter
3232
from ax.core.runner import Runner
33+
from ax.core.trial import Trial
3334
from ax.core.trial_status import TrialStatus
3435
from ax.core.types import ComparisonOp
3536
from ax.exceptions.core import ObjectNotFoundError
@@ -275,6 +276,8 @@ def test_saving_and_loading_experiment_with_aux_exp(self) -> None:
275276
is_test=True,
276277
)
277278
save_experiment(aux_experiment, config=self.config)
279+
# pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute
280+
purpose = self.config.auxiliary_experiment_purpose_enum.MyAuxExpPurpose
278281

279282
experiment_w_aux_exp = Experiment(
280283
name="test_experiment_w_aux_exp_in_SQAStoreTest",
@@ -284,10 +287,7 @@ def test_saving_and_loading_experiment_with_aux_exp(self) -> None:
284287
tracking_metrics=[Metric(name="tracking")],
285288
is_test=True,
286289
auxiliary_experiments_by_purpose={
287-
# pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute
288-
self.config.auxiliary_experiment_purpose_enum.MyAuxExpPurpose: [
289-
AuxiliaryExperiment(experiment=aux_experiment)
290-
]
290+
purpose: [AuxiliaryExperiment(experiment=aux_experiment)]
291291
},
292292
)
293293
self.assertIsNone(experiment_w_aux_exp.db_id)
@@ -299,6 +299,66 @@ def test_saving_and_loading_experiment_with_aux_exp(self) -> None:
299299
self.assertEqual(experiment_w_aux_exp, loaded_experiment)
300300
self.assertEqual(len(loaded_experiment.auxiliary_experiments_by_purpose), 1)
301301

302+
def test_saving_and_loading_experiment_with_aux_exp_reduced_state(self) -> None:
303+
aux_exp = Experiment(
304+
name="test_aux_exp_in_SQAStoreTest_reduced_state",
305+
search_space=get_search_space(),
306+
optimization_config=get_optimization_config(),
307+
description="test description",
308+
tracking_metrics=[Metric(name="tracking")],
309+
is_test=True,
310+
)
311+
aux_exp_gs = get_generation_strategy(with_callable_model_kwarg=False)
312+
aux_exp.new_trial(aux_exp_gs.gen(experiment=aux_exp))
313+
save_experiment(aux_exp, config=self.config)
314+
# pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute
315+
purpose = self.config.auxiliary_experiment_purpose_enum.MyAuxExpPurpose
316+
317+
target_exp = Experiment(
318+
name="test_experiment_w_aux_exp_in_SQAStoreTest_reduced_state",
319+
search_space=get_search_space(),
320+
optimization_config=get_optimization_config(),
321+
description="test description",
322+
tracking_metrics=[Metric(name="tracking")],
323+
is_test=True,
324+
auxiliary_experiments_by_purpose={
325+
purpose: [AuxiliaryExperiment(experiment=aux_exp)]
326+
},
327+
)
328+
target_exp_gs = get_generation_strategy(with_callable_model_kwarg=False)
329+
target_exp.new_trial(target_exp_gs.gen(experiment=target_exp))
330+
self.assertIsNone(target_exp.db_id)
331+
save_experiment(target_exp, config=self.config)
332+
self.assertIsNotNone(target_exp.db_id)
333+
loaded_target_exp = load_experiment(
334+
target_exp.name, config=self.config, reduced_state=True
335+
)
336+
self.assertNotEqual(target_exp, loaded_target_exp)
337+
self.assertIsNotNone( # State of the original aux experiment is not reduced.
338+
none_throws(
339+
assert_is_instance(aux_exp.trials[0], Trial).generator_run
340+
).gen_metadata
341+
)
342+
self.assertIsNotNone( # State of the original target experiment is not reduced.
343+
none_throws(
344+
assert_is_instance(target_exp.trials[0], Trial).generator_run
345+
).gen_metadata
346+
)
347+
self.assertIsNone( # State of the loaded target experiment *is reduced*.
348+
none_throws(
349+
assert_is_instance(loaded_target_exp.trials[0], Trial).generator_run
350+
).gen_metadata
351+
)
352+
loaded_aux_exp = loaded_target_exp.auxiliary_experiments_by_purpose[purpose][0]
353+
self.assertIsNone( # State of the loaded target experiment *is reduced*.
354+
none_throws(
355+
assert_is_instance(
356+
loaded_aux_exp.experiment.trials[0], Trial
357+
).generator_run
358+
).gen_metadata
359+
)
360+
self.assertEqual(len(loaded_target_exp.auxiliary_experiments_by_purpose), 1)
361+
302362
def test_saving_with_aux_exp_not_in_db(self) -> None:
303363
aux_experiment = Experiment(
304364
name="aux_experiment_not_in_db", search_space=get_search_space()

0 commit comments

Comments
 (0)