4848)
4949from ax .core .search_space import SearchSpace
5050from ax .core .trial import BaseTrial , Trial
51+ from ax .core .trial_status import TrialStatus
5152from ax .core .types import TParameterization , TParamValue
5253from ax .core .utils import get_model_times
5354from ax .early_stopping .strategies .base import BaseEarlyStoppingStrategy
@@ -161,6 +162,7 @@ def get_benchmark_runner(
161162def get_oracle_experiment_from_params (
162163 problem : BenchmarkProblem ,
163164 dict_of_dict_of_params : Mapping [int , Mapping [str , Mapping [str , TParamValue ]]],
165+ trial_statuses : Mapping [int , TrialStatus ] | None = None ,
164166) -> Experiment :
165167 """
166168 Get a new experiment with the same search space and optimization config
@@ -174,6 +176,12 @@ def get_oracle_experiment_from_params(
174176 config for generating an experiment.
175177 dict_of_dict_of_params: Keys are trial indices, values are Mappings
176178 (e.g. dicts) that map arm names to parameterizations.
179+ trial_statuses: Optional mapping from trial indices to their statuses.
180+ If provided, trials in oracle experiments will be set to the
181+ specified status.
182+ This helps preserve the trial status from the original experiment,
183+ especially if we want to take `ABANDONED` trials into account.
184+ If not provided, trials will be set to completed.
177185
178186 Example:
179187 >>> get_oracle_experiment_from_params(
@@ -219,11 +227,33 @@ def get_oracle_experiment_from_params(
219227 trial = experiment .trials [trial_index ]
220228 metadata = runner .run (trial = trial )
221229 trial .update_run_metadata (metadata = metadata )
222- trial .mark_completed ()
230+
231+ # Determine the status for the trial in the oracle experiment.
232+ # Mark ABANDONED and FAILED immediately (they don't require data).
233+ # EARLY_STOPPED requires data, so mark as completed for now and
234+ # defer the status change until after fetch_data().
235+ if trial_statuses is not None :
236+ status = trial_statuses [trial_index ]
237+ else :
238+ status = TrialStatus .COMPLETED
239+
240+ if status == TrialStatus .ABANDONED :
241+ trial .mark_abandoned ()
242+ elif status == TrialStatus .FAILED :
243+ trial .mark_failed ()
244+ else :
245+ trial .mark_completed ()
223246
224247 logger .setLevel (level = original_log_level )
225248
226249 experiment .fetch_data ()
250+
251+ # Apply EARLY_STOPPED status after data is available, since
252+ # mark_early_stopped() requires data on the trial.
253+ if trial_statuses is not None :
254+ for trial_index , status in trial_statuses .items ():
255+ if status == TrialStatus .EARLY_STOPPED :
256+ experiment .trials [trial_index ].mark_early_stopped (unsafe = True )
227257 return experiment
228258
229259
@@ -342,14 +372,15 @@ def get_inference_trace(
342372
343373def get_is_feasible_trace (
344374 experiment : Experiment , optimization_config : OptimizationConfig
345- ) -> list [float ]:
375+ ) -> list [bool ]:
346376 """Get a trace of feasibility for the experiment.
347377
348378 For batch trials we return True if any arm in a given batch is feasible.
379+ Trials without data (e.g. abandoned or failed) default to False.
349380 """
350381 df = experiment .lookup_data ().df .copy () # Let's not modify the original df
351382 if len (df ) == 0 :
352- return []
383+ return [False ] * len ( experiment . trials )
353384 # Derelativize the optimization config if needed.
354385 optimization_config = derelativize_opt_config (
355386 optimization_config = optimization_config ,
@@ -358,7 +389,11 @@ def get_is_feasible_trace(
358389 # Compute feasibility and return feasibility per group
359390 df = _prepare_data_for_trace (df = df , optimization_config = optimization_config )
360391 trial_grouped = df .groupby ("trial_index" )["feasible" ]
361- return trial_grouped .any ().tolist ()
392+ feasibility_by_trial = trial_grouped .any ().to_dict ()
393+ return [
394+ feasibility_by_trial .get (trial_index , False )
395+ for trial_index in sorted (experiment .trials .keys ())
396+ ]
362397
363398
364399def get_best_parameters (
@@ -455,8 +490,20 @@ def get_benchmark_result_from_experiment_and_gs(
455490 for new_trial_index , trials in enumerate (trial_completion_order )
456491 }
457492
493+ # Create trial_statuses mapping to preserve trial status in oracle experiment
494+ trial_statuses = {
495+ new_trial_index : (
496+ experiment .trials [next (iter (old_trial_indices ))].status
497+ if len (old_trial_indices ) == 1
498+ else TrialStatus .COMPLETED
499+ )
500+ for new_trial_index , old_trial_indices in enumerate (trial_completion_order )
501+ }
502+
458503 actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params (
459- problem = problem , dict_of_dict_of_params = dict_of_dict_of_params
504+ problem = problem ,
505+ dict_of_dict_of_params = dict_of_dict_of_params ,
506+ trial_statuses = trial_statuses ,
460507 )
461508 oracle_trace = np .array (
462509 get_trace (
0 commit comments