4848)
4949from ax .core .search_space import SearchSpace
5050from ax .core .trial import BaseTrial , Trial
51+ from ax .core .trial_status import TrialStatus
5152from ax .core .types import TParameterization , TParamValue
5253from ax .core .utils import get_model_times
5354from ax .early_stopping .strategies .base import BaseEarlyStoppingStrategy
@@ -157,6 +158,7 @@ def get_benchmark_runner(
157158def get_oracle_experiment_from_params (
158159 problem : BenchmarkProblem ,
159160 dict_of_dict_of_params : Mapping [int , Mapping [str , Mapping [str , TParamValue ]]],
161+ trial_statuses : Mapping [int , TrialStatus ] | None = None ,
160162) -> Experiment :
161163 """
162164 Get a new experiment with the same search space and optimization config
@@ -170,6 +172,12 @@ def get_oracle_experiment_from_params(
170172 config for generating an experiment.
171173 dict_of_dict_of_params: Keys are trial indices, values are Mappings
172174 (e.g. dicts) that map arm names to parameterizations.
175+ trial_statuses: Optional mapping from trial indices to their statuses.
176+ If provided, trials in oracle experiments will be set to the
177+ specified status.
178+ This helps preserve the trial status from the original experiment,
179+ especially if we want to take `ABANDONED` trials into account.
180+ If not provided, trials will be set to completed.
173181
174182 Example:
175183 >>> get_oracle_experiment_from_params(
@@ -215,11 +223,33 @@ def get_oracle_experiment_from_params(
215223 trial = experiment .trials [trial_index ]
216224 metadata = runner .run (trial = trial )
217225 trial .update_run_metadata (metadata = metadata )
218- trial .mark_completed ()
226+
227+ # Determine the status for the trial in the oracle experiment.
228+ # Mark ABANDONED and FAILED immediately (they don't require data).
229+ # EARLY_STOPPED requires data, so mark as completed for now and
230+ # defer the status change until after fetch_data().
231+ if trial_statuses is not None :
232+ status = trial_statuses [trial_index ]
233+ else :
234+ status = TrialStatus .COMPLETED
235+
236+ if status == TrialStatus .ABANDONED :
237+ trial .mark_abandoned ()
238+ elif status == TrialStatus .FAILED :
239+ trial .mark_failed ()
240+ else :
241+ trial .mark_completed ()
219242
220243 logger .setLevel (level = original_log_level )
221244
222245 experiment .fetch_data ()
246+
247+ # Apply EARLY_STOPPED status after data is available, since
248+ # mark_early_stopped() requires data on the trial.
249+ if trial_statuses is not None :
250+ for trial_index , status in trial_statuses .items ():
251+ if status == TrialStatus .EARLY_STOPPED :
252+ experiment .trials [trial_index ].mark_early_stopped (unsafe = True )
223253 return experiment
224254
225255
@@ -451,13 +481,22 @@ def get_benchmark_result_from_experiment_and_gs(
451481 for new_trial_index , trials in enumerate (trial_completion_order )
452482 }
453483
484+ # Create trial_statuses mapping to preserve trial status in oracle experiment
485+ trial_statuses = {
486+ trial_index : experiment .trials [trial_index ].status
487+ for trial_index in dict_of_dict_of_params .keys ()
488+ }
489+
454490 actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params (
455- problem = problem , dict_of_dict_of_params = dict_of_dict_of_params
491+ problem = problem ,
492+ dict_of_dict_of_params = dict_of_dict_of_params ,
493+ trial_statuses = trial_statuses ,
456494 )
457495 oracle_trace = np .array (
458496 get_trace (
459497 experiment = actual_params_oracle_dummy_experiment ,
460498 optimization_config = problem .optimization_config ,
499+ include_abandoned = True ,
461500 )
462501 )
463502 is_feasible_trace = np .array (
0 commit comments