From cb1a304e7d8ffd71bea7cc72d9ee52728087df9f Mon Sep 17 00:00:00 2001 From: Sunny Shen Date: Wed, 11 Mar 2026 13:15:15 -0700 Subject: [PATCH] Allow Oracle Experiment to take ABANDONED trials into account (#4953) Summary: Include ABANDONED and FAILED trials in the trace by carrying forward the last best value. This ensures the trace has one value for each such trial, reflecting that these trials consumed resources but didn't improve optimization. Oracle experiments can now also preserve the original trial statuses (ABANDONED, FAILED, EARLY_STOPPED) via the new optional `trial_statuses` argument. Reviewed By: saitcakmak Differential Revision: D86833965 --- ax/benchmark/benchmark.py | 57 +++++++++++++++++++++++++--- ax/benchmark/tests/test_benchmark.py | 47 +++++++++++++++++++++++ ax/service/tests/test_best_point.py | 45 +++++++++++++++++++++- ax/service/utils/best_point.py | 33 +++++++++++++++- 4 files changed, 174 insertions(+), 8 deletions(-) diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index 2b3edced8fa..b698bfdbe24 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -48,6 +48,7 @@ ) from ax.core.search_space import SearchSpace from ax.core.trial import BaseTrial, Trial +from ax.core.trial_status import TrialStatus from ax.core.types import TParameterization, TParamValue from ax.core.utils import get_model_times from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy @@ -161,6 +162,7 @@ def get_benchmark_runner( def get_oracle_experiment_from_params( problem: BenchmarkProblem, dict_of_dict_of_params: Mapping[int, Mapping[str, Mapping[str, TParamValue]]], + trial_statuses: Mapping[int, TrialStatus] | None = None, ) -> Experiment: """ Get a new experiment with the same search space and optimization config @@ -174,6 +176,12 @@ def get_oracle_experiment_from_params( config for generating an experiment. dict_of_dict_of_params: Keys are trial indices, values are Mappings (e.g. dicts) that map arm names to parameterizations. + trial_statuses: Optional mapping from trial indices to their statuses. 
+ If provided, trials in oracle experiments will be set to the + specified status. + This helps preserve the trial status from the original experiment, + especially if we want to take `ABANDONED` trials into account. + If not provided, trials will be set to completed. Example: >>> get_oracle_experiment_from_params( @@ -219,11 +227,33 @@ def get_oracle_experiment_from_params( trial = experiment.trials[trial_index] metadata = runner.run(trial=trial) trial.update_run_metadata(metadata=metadata) - trial.mark_completed() + + # Determine the status for the trial in the oracle experiment. + # Mark ABANDONED and FAILED immediately (they don't require data). + # EARLY_STOPPED requires data, so mark as completed for now and + # defer the status change until after fetch_data(). + if trial_statuses is not None: + status = trial_statuses[trial_index] + else: + status = TrialStatus.COMPLETED + + if status == TrialStatus.ABANDONED: + trial.mark_abandoned() + elif status == TrialStatus.FAILED: + trial.mark_failed() + else: + trial.mark_completed() logger.setLevel(level=original_log_level) experiment.fetch_data() + + # Apply EARLY_STOPPED status after data is available, since + # mark_early_stopped() requires data on the trial. + if trial_statuses is not None: + for trial_index, status in trial_statuses.items(): + if status == TrialStatus.EARLY_STOPPED: + experiment.trials[trial_index].mark_early_stopped(unsafe=True) return experiment @@ -342,14 +372,15 @@ def get_inference_trace( def get_is_feasible_trace( experiment: Experiment, optimization_config: OptimizationConfig -) -> list[float]: +) -> list[bool]: """Get a trace of feasibility for the experiment. For batch trials we return True if any arm in a given batch is feasible. + Trials without data (e.g. abandoned or failed) default to False. 
""" df = experiment.lookup_data().df.copy() # Let's not modify the original df if len(df) == 0: - return [] + return [False] * len(experiment.trials) # Derelativize the optimization config if needed. optimization_config = derelativize_opt_config( optimization_config=optimization_config, @@ -358,7 +389,11 @@ def get_is_feasible_trace( # Compute feasibility and return feasibility per group df = _prepare_data_for_trace(df=df, optimization_config=optimization_config) trial_grouped = df.groupby("trial_index")["feasible"] - return trial_grouped.any().tolist() + feasibility_by_trial = trial_grouped.any().to_dict() + return [ + feasibility_by_trial.get(trial_index, False) + for trial_index in sorted(experiment.trials.keys()) + ] def get_best_parameters( @@ -455,8 +490,20 @@ def get_benchmark_result_from_experiment_and_gs( for new_trial_index, trials in enumerate(trial_completion_order) } + # Create trial_statuses mapping to preserve trial status in oracle experiment. + # If all trials in a completion group share the same status, use that status; + # otherwise default to COMPLETED. 
+ trial_statuses = {} + for new_trial_index, old_trial_indices in enumerate(trial_completion_order): + statuses = {experiment.trials[idx].status for idx in old_trial_indices} + trial_statuses[new_trial_index] = ( + next(iter(statuses)) if len(statuses) == 1 else TrialStatus.COMPLETED + ) + actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params( - problem=problem, dict_of_dict_of_params=dict_of_dict_of_params + problem=problem, + dict_of_dict_of_params=dict_of_dict_of_params, + trial_statuses=trial_statuses, ) oracle_trace = np.array( get_trace( diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index b10b181a2e7..37e01fdf186 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -70,6 +70,7 @@ ) from ax.core.experiment import Experiment from ax.core.objective import MultiObjective +from ax.core.trial_status import TrialStatus from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy from ax.generation_strategy.external_generation_node import ExternalGenerationNode from ax.generation_strategy.generation_strategy import ( @@ -1014,6 +1015,52 @@ def test_get_oracle_experiment_from_params(self) -> None: problem=problem, dict_of_dict_of_params={0: {}} ) + with self.subTest("trial_statuses"): + trial_statuses = { + 0: TrialStatus.COMPLETED, + 1: TrialStatus.ABANDONED, + } + experiment = get_oracle_experiment_from_params( + problem=problem, + dict_of_dict_of_params={ + 0: {"0": near_opt_params}, + 1: {"1": other_params}, + }, + trial_statuses=trial_statuses, + ) + self.assertEqual(len(experiment.trials), 2) + self.assertTrue(experiment.trials[0].status.is_completed) + self.assertEqual(experiment.trials[1].status, TrialStatus.ABANDONED) + + with self.subTest("trial_statuses with FAILED and EARLY_STOPPED"): + trial_statuses = { + 0: TrialStatus.FAILED, + 1: TrialStatus.EARLY_STOPPED, + } + experiment = get_oracle_experiment_from_params( + 
problem=problem, + dict_of_dict_of_params={ + 0: {"0": near_opt_params}, + 1: {"1": other_params}, + }, + trial_statuses=trial_statuses, + ) + self.assertEqual(experiment.trials[0].status, TrialStatus.FAILED) + self.assertEqual(experiment.trials[1].status, TrialStatus.EARLY_STOPPED) + + with self.subTest("trial_statuses=None defaults to COMPLETED"): + experiment = get_oracle_experiment_from_params( + problem=problem, + dict_of_dict_of_params={ + 0: {"0": near_opt_params}, + 1: {"1": other_params}, + }, + trial_statuses=None, + ) + self.assertTrue( + all(t.status.is_completed for t in experiment.trials.values()) + ) + def _test_multi_fidelity_or_multi_task( self, fidelity_or_task: Literal["fidelity", "task"] ) -> None: diff --git a/ax/service/tests/test_best_point.py b/ax/service/tests/test_best_point.py index 8972dd292a8..33ab5338512 100644 --- a/ax/service/tests/test_best_point.py +++ b/ax/service/tests/test_best_point.py @@ -189,7 +189,50 @@ def test_get_trace(self) -> None: ] ) exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2))) - self.assertEqual(get_trace(exp), [2.0, 20.0]) + self.assertEqual(get_trace(exp), [2.0, 2.0, 20.0]) + + def test_get_trace_with_non_completed_trials(self) -> None: + with self.subTest("minimize with abandoned trial"): + exp = get_experiment_with_observations( + observations=[[11], [10], [9], [15], [5]], minimize=True + ) + # Mark trial 2 (value=9) as abandoned + exp.trials[2].mark_abandoned(unsafe=True) + + # Abandoned trial carries forward the last best value + trace = get_trace(exp) + self.assertEqual(len(trace), 5) + # Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10 + # Trial 3: 10 (15 > 10), Trial 4: 5 + self.assertEqual(trace, [11, 10, 10, 10, 5]) + + with self.subTest("maximize with abandoned trial"): + exp = get_experiment_with_observations( + observations=[[1], [3], [2], [5], [4]], minimize=False + ) + # Mark trial 1 (value=3) as abandoned + exp.trials[1].mark_abandoned(unsafe=True) + + # Abandoned 
trial carries forward the last best value + trace = get_trace(exp) + self.assertEqual(len(trace), 5) + # Trial 0: 1, Trial 1 (abandoned): carry forward 1, + # Trial 2: 2, Trial 3: 5, Trial 4: 5 + self.assertEqual(trace, [1, 1, 2, 5, 5]) + + with self.subTest("minimize with failed trial"): + exp = get_experiment_with_observations( + observations=[[11], [10], [9], [15], [5]], minimize=True + ) + # Mark trial 2 (value=9) as failed + exp.trials[2].mark_failed(unsafe=True) + + # Failed trial carries forward the last best value + trace = get_trace(exp) + self.assertEqual(len(trace), 5) + # Trial 0: 11, Trial 1: 10, Trial 2 (failed): carry forward 10 + # Trial 3: 10 (15 > 10), Trial 4: 5 + self.assertEqual(trace, [11, 10, 10, 10, 5]) def test_get_trace_with_include_status_quo(self) -> None: with self.subTest("Multi-objective: status quo dominates in some trials"): diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 5a2cc15de70..be96ef6d81d 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -1219,6 +1219,8 @@ def get_trace( An iteration here refers to a completed or early-stopped (batch) trial. There will be one performance metric in the trace for each iteration. + Trials without data (e.g. abandoned or failed) carry forward the last + best value. Args: experiment: The experiment to get the trace for. @@ -1278,12 +1280,39 @@ def get_trace( # Aggregate by trial, then. compute cumulative best objective = optimization_config.objective maximize = isinstance(objective, MultiObjective) or not objective.minimize - return _aggregate_and_cumulate_trace( + cumulative_value = _aggregate_and_cumulate_trace( df=value_by_arm_pull, by=["trial_index"], maximize=maximize, keep_order=False, # sort by trial index - ).tolist() + ) + + compact_trace = cumulative_value.tolist() + + # Expand trace to include trials without data (e.g. ABANDONED, FAILED) + # with carry-forward values. 
+ data_trial_indices = set(cumulative_value.index) + expanded_trace = [] + compact_idx = 0 + last_best_value = -float("inf") if maximize else float("inf") + + for trial_index in sorted(experiment.trials.keys()): + trial = experiment.trials[trial_index] + if trial_index in data_trial_indices: + # Trial has data in compact trace + if compact_idx < len(compact_trace): + value = compact_trace[compact_idx] + expanded_trace.append(value) + last_best_value = value + compact_idx += 1 + else: + # Should not happen, but handle gracefully + expanded_trace.append(last_best_value) + elif trial.status in (TrialStatus.ABANDONED, TrialStatus.FAILED): + # Trial has no data; carry forward the last best value. + expanded_trace.append(last_best_value) + + return expanded_trace def get_tensor_converter_adapter(