@@ -197,7 +197,8 @@ def _init_experiment_from_sqa(
197
197
)
198
198
199
199
def _init_mt_experiment_from_sqa (
200
- self , experiment_sqa : SQAExperiment
200
+ self ,
201
+ experiment_sqa : SQAExperiment ,
201
202
) -> MultiTypeExperiment :
202
203
"""First step of conversion within experiment_from_sqa."""
203
204
opt_config , tracking_metrics = self .opt_config_and_tracking_metrics_from_sqa (
@@ -217,24 +218,41 @@ def _init_mt_experiment_from_sqa(
217
218
if experiment_sqa .status_quo_parameters is not None
218
219
else None
219
220
)
221
+
222
+ default_trial_type = none_throws (experiment_sqa .default_trial_type )
220
223
trial_type_to_runner = {
221
224
none_throws (sqa_runner .trial_type ): self .runner_from_sqa (sqa_runner )
222
225
for sqa_runner in experiment_sqa .runners
223
226
}
224
- default_trial_type = none_throws (experiment_sqa .default_trial_type )
227
+ if len (trial_type_to_runner ) == 0 :
228
+ trial_type_to_runner = {default_trial_type : None }
229
+ trial_types_with_metrics = {
230
+ metric .trial_type
231
+ for metric in experiment_sqa .metrics
232
+ if metric .trial_type
233
+ }
234
+ # trial_type_to_runner is instantiated to map all trial types to None,
235
+ # so the trial types are associated with the expeirment. This is
236
+ # important for adding metrics.
237
+ trial_type_to_runner .update (
238
+ {t_type : None for t_type in trial_types_with_metrics }
239
+ )
225
240
properties = dict (experiment_sqa .properties or {})
226
241
default_data_type = experiment_sqa .default_data_type
227
242
experiment = MultiTypeExperiment (
228
243
name = experiment_sqa .name ,
229
244
description = experiment_sqa .description ,
230
245
search_space = search_space ,
231
246
default_trial_type = default_trial_type ,
232
- default_runner = trial_type_to_runner [ default_trial_type ] ,
247
+ default_runner = trial_type_to_runner . get ( default_trial_type ) ,
233
248
optimization_config = opt_config ,
234
249
status_quo = status_quo ,
235
250
properties = properties ,
236
251
default_data_type = default_data_type ,
237
252
)
253
+ # pyre-ignore Imcompatible attribute type [8]: attribute _trial_type_to_runner
254
+ # has type Dict[str, Optional[Runner]] but is used as type
255
+ # Uniont[Dict[str, Optional[Runner]], Dict[str, None]]
238
256
experiment ._trial_type_to_runner = trial_type_to_runner
239
257
sqa_metric_dict = {metric .name : metric for metric in experiment_sqa .metrics }
240
258
for tracking_metric in tracking_metrics :
0 commit comments