4848)
4949from ax .core .search_space import SearchSpace
5050from ax .core .trial import BaseTrial , Trial
51+ from ax .core .trial_status import TrialStatus
5152from ax .core .types import TParameterization , TParamValue
5253from ax .core .utils import get_model_times
5354from ax .early_stopping .strategies .base import BaseEarlyStoppingStrategy
@@ -157,6 +158,7 @@ def get_benchmark_runner(
157158def get_oracle_experiment_from_params (
158159 problem : BenchmarkProblem ,
159160 dict_of_dict_of_params : Mapping [int , Mapping [str , Mapping [str , TParamValue ]]],
161+ trial_statuses : Mapping [int , TrialStatus ] | None = None ,
160162) -> Experiment :
161163 """
162164 Get a new experiment with the same search space and optimization config
@@ -170,6 +172,12 @@ def get_oracle_experiment_from_params(
170172 config for generating an experiment.
171173 dict_of_dict_of_params: Keys are trial indices, values are Mappings
172174 (e.g. dicts) that map arm names to parameterizations.
175+ trial_statuses: Optional mapping from trial indices to their statuses.
176+ If provided, trials in oracle experiments will be set to the
177+ specified status.
178+ This helps preserve the trial status from the original experiment,
179+ especially if we want to take `ABANDONED` trials into account.
180+ If not provided, trials will be set to completed.
173181
174182 Example:
175183 >>> get_oracle_experiment_from_params(
@@ -215,11 +223,33 @@ def get_oracle_experiment_from_params(
215223 trial = experiment .trials [trial_index ]
216224 metadata = runner .run (trial = trial )
217225 trial .update_run_metadata (metadata = metadata )
218- trial .mark_completed ()
226+
227+ # Determine the status for the trial in the oracle experiment.
228+ # Mark ABANDONED and FAILED immediately (they don't require data).
229+ # EARLY_STOPPED requires data, so mark as completed for now and
230+ # defer the status change until after fetch_data().
231+ if trial_statuses is not None :
232+ status = trial_statuses [trial_index ]
233+ else :
234+ status = TrialStatus .COMPLETED
235+
236+ if status == TrialStatus .ABANDONED :
237+ trial .mark_abandoned ()
238+ elif status == TrialStatus .FAILED :
239+ trial .mark_failed ()
240+ else :
241+ trial .mark_completed ()
219242
220243 logger .setLevel (level = original_log_level )
221244
222245 experiment .fetch_data ()
246+
247+ # Apply EARLY_STOPPED status after data is available, since
248+ # mark_early_stopped() requires data on the trial.
249+ if trial_statuses is not None :
250+ for trial_index , status in trial_statuses .items ():
251+ if status == TrialStatus .EARLY_STOPPED :
252+ experiment .trials [trial_index ].mark_early_stopped (unsafe = True )
223253 return experiment
224254
225255
@@ -338,14 +368,15 @@ def get_inference_trace(
338368
339369def get_is_feasible_trace (
340370 experiment : Experiment , optimization_config : OptimizationConfig
341- ) -> list [float ]:
371+ ) -> list [bool ]:
342372 """Get a trace of feasibility for the experiment.
343373
344374 For batch trials we return True if any arm in a given batch is feasible.
375+ Trials without data (e.g. abandoned or failed) default to False.
345376 """
346377 df = experiment .lookup_data ().df .copy () # Let's not modify the original df
347378 if len (df ) == 0 :
348- return []
379+ return [False ] * len ( experiment . trials )
349380 # Derelativize the optimization config if needed.
350381 optimization_config = derelativize_opt_config (
351382 optimization_config = optimization_config ,
@@ -354,7 +385,11 @@ def get_is_feasible_trace(
354385 # Compute feasibility and return feasibility per group
355386 df = _prepare_data_for_trace (df = df , optimization_config = optimization_config )
356387 trial_grouped = df .groupby ("trial_index" )["feasible" ]
357- return trial_grouped .any ().tolist ()
388+ feasibility_by_trial = trial_grouped .any ().to_dict ()
389+ return [
390+ feasibility_by_trial .get (trial_index , False )
391+ for trial_index in sorted (experiment .trials .keys ())
392+ ]
358393
359394
360395def get_best_parameters (
@@ -451,8 +486,20 @@ def get_benchmark_result_from_experiment_and_gs(
451486 for new_trial_index , trials in enumerate (trial_completion_order )
452487 }
453488
489+ # Create trial_statuses mapping to preserve trial status in oracle experiment
490+ trial_statuses = {
491+ new_trial_index : (
492+ experiment .trials [next (iter (old_trial_indices ))].status
493+ if len (old_trial_indices ) == 1
494+ else TrialStatus .COMPLETED
495+ )
496+ for new_trial_index , old_trial_indices in enumerate (trial_completion_order )
497+ }
498+
454499 actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params (
455- problem = problem , dict_of_dict_of_params = dict_of_dict_of_params
500+ problem = problem ,
501+ dict_of_dict_of_params = dict_of_dict_of_params ,
502+ trial_statuses = trial_statuses ,
456503 )
457504 oracle_trace = np .array (
458505 get_trace (
0 commit comments