|
11 | 11 | from enum import Enum
|
12 | 12 | from io import StringIO
|
13 | 13 | from logging import Logger
|
14 |
| -from typing import cast, Union |
| 14 | +from typing import Any, cast, Union |
15 | 15 |
|
16 | 16 | import pandas as pd
|
17 | 17 | from ax.analysis.analysis import AnalysisCard
|
@@ -108,29 +108,25 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa(
|
108 | 108 | ) -> dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] | None:
|
109 | 109 | auxiliary_experiments_by_purpose = None
|
110 | 110 | if experiment_sqa.auxiliary_experiments_by_purpose:
|
111 |
| - from ax.storage.sqa_store.load import load_experiment |
112 |
| - |
113 | 111 | auxiliary_experiments_by_purpose = {}
|
114 |
| - aux_exp_name_dict = none_throws( |
115 |
| - experiment_sqa.auxiliary_experiments_by_purpose |
116 |
| - ) |
117 |
| - for aux_exp_purpose_str, aux_exp_names in aux_exp_name_dict.items(): |
| 112 | + aux_exps_dict = none_throws(experiment_sqa.auxiliary_experiments_by_purpose) |
| 113 | + for aux_exp_purpose_str, aux_exps_json in aux_exps_dict.items(): |
118 | 114 | aux_exp_purpose = next(
|
119 | 115 | member
|
120 | 116 | for member in self.config.auxiliary_experiment_purpose_enum
|
121 | 117 | if member.value == aux_exp_purpose_str
|
122 | 118 | )
|
123 | 119 | auxiliary_experiments_by_purpose[aux_exp_purpose] = []
|
124 |
| - for aux_exp_name in aux_exp_names: |
| 120 | + for aux_exp_json in aux_exps_json: |
| 121 | + # keeping this for backward compatibility since previously |
| 122 | + # we used to save only the experiment name |
| 123 | + if isinstance(aux_exp_json, str): |
| 124 | + aux_exp_json = {"experiment_name": aux_exp_json} |
| 125 | + aux_experiment = auxiliary_experiment_from_json( |
| 126 | + json=aux_exp_json, config=self.config |
| 127 | + ) |
125 | 128 | auxiliary_experiments_by_purpose[aux_exp_purpose].append(
|
126 |
| - AuxiliaryExperiment( |
127 |
| - experiment=load_experiment( |
128 |
| - aux_exp_name, |
129 |
| - config=self.config, |
130 |
| - skip_runners_and_metrics=True, |
131 |
| - load_auxiliary_experiments=False, |
132 |
| - ) |
133 |
| - ) |
| 129 | + aux_experiment |
134 | 130 | )
|
135 | 131 | return auxiliary_experiments_by_purpose
|
136 | 132 |
|
@@ -1321,3 +1317,28 @@ def _get_scalarized_outcome_constraint_children_metrics(
|
1321 | 1317 | )
|
1322 | 1318 | metrics_sqa = query.all()
|
1323 | 1319 | return metrics_sqa
|
| 1320 | + |
| 1321 | + |
| 1322 | +def auxiliary_experiment_from_json( |
| 1323 | + json: dict[str, Any], config: SQAConfig |
| 1324 | +) -> AuxiliaryExperiment: |
| 1325 | + """ |
| 1326 | + Load an ``AuxiliaryExperiment`` from JSON. |
| 1327 | +
|
| 1328 | + Args: |
| 1329 | + json: A dictionary containing the JSON representation of an AuxiliaryExperiment. |
| 1330 | + config: The SQAConfig object used to load the experiment. |
| 1331 | +
|
| 1332 | + Returns: |
| 1333 | + An AuxiliaryExperiment object constructed from the JSON representation. |
| 1334 | + """ |
| 1335 | + |
| 1336 | + from ax.storage.sqa_store.load import load_experiment |
| 1337 | + |
| 1338 | + experiment = load_experiment( |
| 1339 | + json.get("experiment_name"), |
| 1340 | + config=config, |
| 1341 | + skip_runners_and_metrics=True, |
| 1342 | + load_auxiliary_experiments=False, |
| 1343 | + ) |
| 1344 | + return AuxiliaryExperiment(experiment) |
0 commit comments