diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index 2b3edced8fa..b698bfdbe24 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -48,6 +48,7 @@ ) from ax.core.search_space import SearchSpace from ax.core.trial import BaseTrial, Trial +from ax.core.trial_status import TrialStatus from ax.core.types import TParameterization, TParamValue from ax.core.utils import get_model_times from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy @@ -161,6 +162,7 @@ def get_benchmark_runner( def get_oracle_experiment_from_params( problem: BenchmarkProblem, dict_of_dict_of_params: Mapping[int, Mapping[str, Mapping[str, TParamValue]]], + trial_statuses: Mapping[int, TrialStatus] | None = None, ) -> Experiment: """ Get a new experiment with the same search space and optimization config @@ -174,6 +176,12 @@ def get_oracle_experiment_from_params( config for generating an experiment. dict_of_dict_of_params: Keys are trial indices, values are Mappings (e.g. dicts) that map arm names to parameterizations. + trial_statuses: Optional mapping from trial indices to their statuses. + If provided, trials in oracle experiments will be set to the + specified status. + This helps preserve the trial status from the original experiment, + especially if we want to take `ABANDONED` trials into account. + If not provided, trials will be set to completed. Example: >>> get_oracle_experiment_from_params( @@ -219,11 +227,33 @@ def get_oracle_experiment_from_params( trial = experiment.trials[trial_index] metadata = runner.run(trial=trial) trial.update_run_metadata(metadata=metadata) - trial.mark_completed() + + # Determine the status for the trial in the oracle experiment. + # Mark ABANDONED and FAILED immediately (they don't require data). + # EARLY_STOPPED requires data, so mark as completed for now and + # defer the status change until after fetch_data(). + if trial_statuses is not None: + status = trial_statuses[trial_index] + else: + status = TrialStatus.COMPLETED + + if status == TrialStatus.ABANDONED: + trial.mark_abandoned() + elif status == TrialStatus.FAILED: + trial.mark_failed() + else: + trial.mark_completed() logger.setLevel(level=original_log_level) experiment.fetch_data() + + # Apply EARLY_STOPPED status after data is available, since + # mark_early_stopped() requires data on the trial. + if trial_statuses is not None: + for trial_index, status in trial_statuses.items(): + if status == TrialStatus.EARLY_STOPPED: + experiment.trials[trial_index].mark_early_stopped(unsafe=True) return experiment @@ -342,14 +372,15 @@ def get_inference_trace( def get_is_feasible_trace( experiment: Experiment, optimization_config: OptimizationConfig -) -> list[float]: +) -> list[bool]: """Get a trace of feasibility for the experiment. For batch trials we return True if any arm in a given batch is feasible. + Trials without data (e.g. abandoned or failed) default to False. """ df = experiment.lookup_data().df.copy() # Let's not modify the original df if len(df) == 0: - return [] + return [False] * len(experiment.trials) # Derelativize the optimization config if needed. optimization_config = derelativize_opt_config( optimization_config=optimization_config, @@ -358,7 +389,11 @@ def get_is_feasible_trace( # Compute feasibility and return feasibility per group df = _prepare_data_for_trace(df=df, optimization_config=optimization_config) trial_grouped = df.groupby("trial_index")["feasible"] - return trial_grouped.any().tolist() + feasibility_by_trial = trial_grouped.any().to_dict() + return [ + feasibility_by_trial.get(trial_index, False) + for trial_index in sorted(experiment.trials.keys()) + ] def get_best_parameters( @@ -455,8 +490,20 @@ def get_benchmark_result_from_experiment_and_gs( for new_trial_index, trials in enumerate(trial_completion_order) } + # Create trial_statuses mapping to preserve trial status in oracle experiment. + # If all trials in a completion group share the same status, use that status; + # otherwise default to COMPLETED. + trial_statuses = {} + for new_trial_index, old_trial_indices in enumerate(trial_completion_order): + statuses = {experiment.trials[idx].status for idx in old_trial_indices} + trial_statuses[new_trial_index] = ( + next(iter(statuses)) if len(statuses) == 1 else TrialStatus.COMPLETED + ) + actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params( - problem=problem, dict_of_dict_of_params=dict_of_dict_of_params + problem=problem, + dict_of_dict_of_params=dict_of_dict_of_params, + trial_statuses=trial_statuses, ) oracle_trace = np.array( get_trace( diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index b10b181a2e7..37e01fdf186 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -70,6 +70,7 @@ ) from ax.core.experiment import Experiment from ax.core.objective import MultiObjective +from ax.core.trial_status import TrialStatus from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy from ax.generation_strategy.external_generation_node import ExternalGenerationNode from ax.generation_strategy.generation_strategy import ( @@ -1014,6 +1015,52 @@ def test_get_oracle_experiment_from_params(self) -> None: problem=problem, dict_of_dict_of_params={0: {}} ) + with self.subTest("trial_statuses"): + trial_statuses = { + 0: TrialStatus.COMPLETED, + 1: TrialStatus.ABANDONED, + } + experiment = get_oracle_experiment_from_params( + problem=problem, + dict_of_dict_of_params={ + 0: {"0": near_opt_params}, + 1: {"1": other_params}, + }, + trial_statuses=trial_statuses, + ) + self.assertEqual(len(experiment.trials), 2) + self.assertTrue(experiment.trials[0].status.is_completed) + self.assertEqual(experiment.trials[1].status, TrialStatus.ABANDONED) + + with self.subTest("trial_statuses with FAILED and EARLY_STOPPED"): + trial_statuses = { + 0: TrialStatus.FAILED, + 1: TrialStatus.EARLY_STOPPED, + } + experiment = get_oracle_experiment_from_params( + problem=problem, + dict_of_dict_of_params={ + 0: {"0": near_opt_params}, + 1: {"1": other_params}, + }, + trial_statuses=trial_statuses, + ) + self.assertEqual(experiment.trials[0].status, TrialStatus.FAILED) + self.assertEqual(experiment.trials[1].status, TrialStatus.EARLY_STOPPED) + + with self.subTest("trial_statuses=None defaults to COMPLETED"): + experiment = get_oracle_experiment_from_params( + problem=problem, + dict_of_dict_of_params={ + 0: {"0": near_opt_params}, + 1: {"1": other_params}, + }, + trial_statuses=None, + ) + self.assertTrue( + all(t.status.is_completed for t in experiment.trials.values()) + ) + def _test_multi_fidelity_or_multi_task( self, fidelity_or_task: Literal["fidelity", "task"] ) -> None: diff --git a/ax/service/tests/test_best_point.py b/ax/service/tests/test_best_point.py index 8972dd292a8..33ab5338512 100644 --- a/ax/service/tests/test_best_point.py +++ b/ax/service/tests/test_best_point.py @@ -189,7 +189,50 @@ def test_get_trace(self) -> None: ] ) exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2))) - self.assertEqual(get_trace(exp), [2.0, 20.0]) + self.assertEqual(get_trace(exp), [2.0, 2.0, 20.0]) + + def test_get_trace_with_non_completed_trials(self) -> None: + with self.subTest("minimize with abandoned trial"): + exp = get_experiment_with_observations( + observations=[[11], [10], [9], [15], [5]], minimize=True + ) + # Mark trial 2 (value=9) as abandoned + exp.trials[2].mark_abandoned(unsafe=True) + + # Abandoned trial carries forward the last best value + trace = get_trace(exp) + self.assertEqual(len(trace), 5) + # Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10 + # Trial 3: 10 (15 > 10), Trial 4: 5 + self.assertEqual(trace, [11, 10, 10, 10, 5]) + + with self.subTest("maximize with abandoned trial"): + exp = get_experiment_with_observations( + observations=[[1], [3], [2], [5], [4]], minimize=False + ) + # Mark trial 1 (value=3) as abandoned + exp.trials[1].mark_abandoned(unsafe=True) + + # Abandoned trial carries forward the last best value + trace = get_trace(exp) + self.assertEqual(len(trace), 5) + # Trial 0: 1, Trial 1 (abandoned): carry forward 1, + # Trial 2: 2, Trial 3: 5, Trial 4: 5 + self.assertEqual(trace, [1, 1, 2, 5, 5]) + + with self.subTest("minimize with failed trial"): + exp = get_experiment_with_observations( + observations=[[11], [10], [9], [15], [5]], minimize=True + ) + # Mark trial 2 (value=9) as failed + exp.trials[2].mark_failed(unsafe=True) + + # Failed trial carries forward the last best value + trace = get_trace(exp) + self.assertEqual(len(trace), 5) + # Trial 0: 11, Trial 1: 10, Trial 2 (failed): carry forward 10 + # Trial 3: 10 (15 > 10), Trial 4: 5 + self.assertEqual(trace, [11, 10, 10, 10, 5]) def test_get_trace_with_include_status_quo(self) -> None: with self.subTest("Multi-objective: status quo dominates in some trials"): diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 5a2cc15de70..be96ef6d81d 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -1219,6 +1219,8 @@ def get_trace( An iteration here refers to a completed or early-stopped (batch) trial. There will be one performance metric in the trace for each iteration. + Trials without data (e.g. abandoned or failed) carry forward the last + best value. Args: experiment: The experiment to get the trace for. @@ -1278,12 +1280,39 @@ def get_trace( # Aggregate by trial, then. compute cumulative best objective = optimization_config.objective maximize = isinstance(objective, MultiObjective) or not objective.minimize - return _aggregate_and_cumulate_trace( + cumulative_value = _aggregate_and_cumulate_trace( df=value_by_arm_pull, by=["trial_index"], maximize=maximize, keep_order=False, # sort by trial index - ).tolist() + ) + + compact_trace = cumulative_value.tolist() + + # Expand trace to include trials without data (e.g. ABANDONED, FAILED) + # with carry-forward values. + data_trial_indices = set(cumulative_value.index) + expanded_trace = [] + compact_idx = 0 + last_best_value = -float("inf") if maximize else float("inf") + + for trial_index in sorted(experiment.trials.keys()): + trial = experiment.trials[trial_index] + if trial_index in data_trial_indices: + # Trial has data in compact trace + if compact_idx < len(compact_trace): + value = compact_trace[compact_idx] + expanded_trace.append(value) + last_best_value = value + compact_idx += 1 + else: + # Should not happen, but handle gracefully + expanded_trace.append(last_best_value) + elif trial.status in (TrialStatus.ABANDONED, TrialStatus.FAILED): + # Trial has no data; carry forward the last best value. + expanded_trace.append(last_best_value) + + return expanded_trace def get_tensor_converter_adapter(