Skip to content

Commit 30ce28a

Browse files
Sunny Shenfacebook-github-bot
authored andcommitted
Allow Oracle Experiment to take ABANDONED trials into account (facebook#4953)
Summary: Include ABANDONED trials in the trace by carrying forward the last best value. This ensures the trace has one value per trial, reflecting that ABANDONED trials consumed resources but didn't improve optimization. Differential Revision: D86833965
1 parent 8b0773e commit 30ce28a

File tree

4 files changed

+166
-3
lines changed

4 files changed

+166
-3
lines changed

ax/benchmark/benchmark.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from ax.core.search_space import SearchSpace
5050
from ax.core.trial import BaseTrial, Trial
51+
from ax.core.trial_status import TrialStatus
5152
from ax.core.types import TParameterization, TParamValue
5253
from ax.core.utils import get_model_times
5354
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
@@ -157,6 +158,7 @@ def get_benchmark_runner(
157158
def get_oracle_experiment_from_params(
158159
problem: BenchmarkProblem,
159160
dict_of_dict_of_params: Mapping[int, Mapping[str, Mapping[str, TParamValue]]],
161+
trial_statuses: Mapping[int, TrialStatus] | None = None,
160162
) -> Experiment:
161163
"""
162164
Get a new experiment with the same search space and optimization config
@@ -170,6 +172,12 @@ def get_oracle_experiment_from_params(
170172
config for generating an experiment.
171173
dict_of_dict_of_params: Keys are trial indices, values are Mappings
172174
(e.g. dicts) that map arm names to parameterizations.
175+
trial_statuses: Optional mapping from trial indices to their statuses.
176+
If provided, trials in oracle experiments will be set to the
177+
specified status.
178+
This helps preserve the trial status from the original experiment,
179+
especially if we want to take `ABANDONED` trials into account.
180+
If not provided, trials will be set to completed.
173181
174182
Example:
175183
>>> get_oracle_experiment_from_params(
@@ -215,11 +223,33 @@ def get_oracle_experiment_from_params(
215223
trial = experiment.trials[trial_index]
216224
metadata = runner.run(trial=trial)
217225
trial.update_run_metadata(metadata=metadata)
218-
trial.mark_completed()
226+
227+
# Determine the status for the trial in the oracle experiment.
228+
# Mark ABANDONED and FAILED immediately (they don't require data).
229+
# EARLY_STOPPED requires data, so mark as completed for now and
230+
# defer the status change until after fetch_data().
231+
if trial_statuses is not None:
232+
status = trial_statuses[trial_index]
233+
else:
234+
status = TrialStatus.COMPLETED
235+
236+
if status == TrialStatus.ABANDONED:
237+
trial.mark_abandoned()
238+
elif status == TrialStatus.FAILED:
239+
trial.mark_failed()
240+
else:
241+
trial.mark_completed()
219242

220243
logger.setLevel(level=original_log_level)
221244

222245
experiment.fetch_data()
246+
247+
# Apply EARLY_STOPPED status after data is available, since
248+
# mark_early_stopped() requires data on the trial.
249+
if trial_statuses is not None:
250+
for trial_index, status in trial_statuses.items():
251+
if status == TrialStatus.EARLY_STOPPED:
252+
experiment.trials[trial_index].mark_early_stopped(unsafe=True)
223253
return experiment
224254

225255

@@ -451,13 +481,22 @@ def get_benchmark_result_from_experiment_and_gs(
451481
for new_trial_index, trials in enumerate(trial_completion_order)
452482
}
453483

484+
# Create trial_statuses mapping to preserve trial status in oracle experiment
485+
trial_statuses = {
486+
trial_index: experiment.trials[trial_index].status
487+
for trial_index in dict_of_dict_of_params.keys()
488+
}
489+
454490
actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params(
455-
problem=problem, dict_of_dict_of_params=dict_of_dict_of_params
491+
problem=problem,
492+
dict_of_dict_of_params=dict_of_dict_of_params,
493+
trial_statuses=trial_statuses,
456494
)
457495
oracle_trace = np.array(
458496
get_trace(
459497
experiment=actual_params_oracle_dummy_experiment,
460498
optimization_config=problem.optimization_config,
499+
include_abandoned=True,
461500
)
462501
)
463502
is_feasible_trace = np.array(

ax/benchmark/tests/test_benchmark.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
get_single_objective_benchmark_problem,
6969
get_soo_surrogate,
7070
)
71+
from ax.core.base_trial import TrialStatus
7172
from ax.core.experiment import Experiment
7273
from ax.core.objective import MultiObjective
7374
from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy
@@ -1014,6 +1015,52 @@ def test_get_oracle_experiment_from_params(self) -> None:
10141015
problem=problem, dict_of_dict_of_params={0: {}}
10151016
)
10161017

1018+
with self.subTest("trial_statuses"):
1019+
trial_statuses = {
1020+
0: TrialStatus.COMPLETED,
1021+
1: TrialStatus.ABANDONED,
1022+
}
1023+
experiment = get_oracle_experiment_from_params(
1024+
problem=problem,
1025+
dict_of_dict_of_params={
1026+
0: {"0": near_opt_params},
1027+
1: {"1": other_params},
1028+
},
1029+
trial_statuses=trial_statuses,
1030+
)
1031+
self.assertEqual(len(experiment.trials), 2)
1032+
self.assertTrue(experiment.trials[0].status.is_completed)
1033+
self.assertEqual(experiment.trials[1].status, TrialStatus.ABANDONED)
1034+
1035+
with self.subTest("trial_statuses with FAILED and EARLY_STOPPED"):
1036+
trial_statuses = {
1037+
0: TrialStatus.FAILED,
1038+
1: TrialStatus.EARLY_STOPPED,
1039+
}
1040+
experiment = get_oracle_experiment_from_params(
1041+
problem=problem,
1042+
dict_of_dict_of_params={
1043+
0: {"0": near_opt_params},
1044+
1: {"1": other_params},
1045+
},
1046+
trial_statuses=trial_statuses,
1047+
)
1048+
self.assertEqual(experiment.trials[0].status, TrialStatus.FAILED)
1049+
self.assertEqual(experiment.trials[1].status, TrialStatus.EARLY_STOPPED)
1050+
1051+
with self.subTest("trial_statuses=None defaults to COMPLETED"):
1052+
experiment = get_oracle_experiment_from_params(
1053+
problem=problem,
1054+
dict_of_dict_of_params={
1055+
0: {"0": near_opt_params},
1056+
1: {"1": other_params},
1057+
},
1058+
trial_statuses=None,
1059+
)
1060+
self.assertTrue(
1061+
all(t.status.is_completed for t in experiment.trials.values())
1062+
)
1063+
10171064
def _test_multi_fidelity_or_multi_task(
10181065
self, fidelity_or_task: Literal["fidelity", "task"]
10191066
) -> None:

ax/service/tests/test_best_point.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,50 @@ def test_get_trace(self) -> None:
183183
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2)))
184184
self.assertEqual(get_trace(exp), [2.0, 20.0])
185185

186+
def test_get_trace_include_abandoned(self) -> None:
187+
with self.subTest("minimize with abandoned trial"):
188+
exp = get_experiment_with_observations(
189+
observations=[[11], [10], [9], [15], [5]], minimize=True
190+
)
191+
# Mark trial 2 (value=9) as abandoned
192+
exp.trials[2].mark_abandoned(unsafe=True)
193+
194+
# Without include_abandoned (default): abandoned trial excluded
195+
trace_default = get_trace(exp)
196+
self.assertEqual(trace_default, [11, 10, 10, 5])
197+
198+
# With include_abandoned=True: abandoned trial carries forward
199+
trace_with_abandoned = get_trace(exp, include_abandoned=True)
200+
self.assertEqual(len(trace_with_abandoned), 5)
201+
# Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10
202+
self.assertEqual(trace_with_abandoned, [11, 10, 10, 10, 5])
203+
204+
with self.subTest("maximize with abandoned trial"):
205+
exp = get_experiment_with_observations(
206+
observations=[[1], [3], [2], [5], [4]], minimize=False
207+
)
208+
# Mark trial 1 (value=3) as abandoned
209+
exp.trials[1].mark_abandoned(unsafe=True)
210+
211+
# Without include_abandoned: only 4 values
212+
trace_default = get_trace(exp)
213+
self.assertEqual(trace_default, [1, 2, 5, 5])
214+
215+
# With include_abandoned: 5 values, carry forward
216+
trace_with_abandoned = get_trace(exp, include_abandoned=True)
217+
self.assertEqual(len(trace_with_abandoned), 5)
218+
# Trial 0: 1, Trial 1 (abandoned): carry forward 1,
219+
# Trial 2: 2, Trial 3: 5, Trial 4: 5
220+
self.assertEqual(trace_with_abandoned, [1, 1, 2, 5, 5])
221+
222+
with self.subTest("include_abandoned=False is default"):
223+
exp = get_experiment_with_observations(
224+
observations=[[11], [10], [9]], minimize=True
225+
)
226+
trace_explicit = get_trace(exp, include_abandoned=False)
227+
trace_default = get_trace(exp)
228+
self.assertEqual(trace_explicit, trace_default)
229+
186230
def test_get_trace_with_include_status_quo(self) -> None:
187231
with self.subTest("Multi-objective: status quo dominates in some trials"):
188232
# Create experiment with multi-objective optimization where status quo

ax/service/utils/best_point.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,7 @@ def get_trace_by_arm_pull_from_data(
10411041
def get_trace(
10421042
experiment: Experiment,
10431043
optimization_config: OptimizationConfig | None = None,
1044+
include_abandoned: bool = False,
10441045
include_status_quo: bool = False,
10451046
) -> list[float]:
10461047
"""Compute the optimization trace at each iteration.
@@ -1069,6 +1070,11 @@ def get_trace(
10691070
include_status_quo: If True, include status quo in the trace computation.
10701071
If False (default), exclude status quo for compatibility with legacy
10711072
behavior.
1073+
include_abandoned: If True, include ABANDONED trials in the trace by
1074+
carrying forward the last best value. This ensures the trace has
1075+
one value per trial, reflecting that ABANDONED trials consumed
1076+
resources but didn't improve optimization. If False (default),
1077+
only COMPLETED and EARLY_STOPPED trials are included.
10721078
10731079
Returns:
10741080
A list of performance values at each iteration.
@@ -1128,7 +1134,34 @@ def get_trace(
11281134
value_by_trial = trial_grouped.min()
11291135
cumulative_value = np.minimum.accumulate(value_by_trial)
11301136

1131-
return cumulative_value.tolist()
1137+
compact_trace = cumulative_value.tolist()
1138+
1139+
# If not including abandoned trials, return early
1140+
if not include_abandoned:
1141+
return compact_trace
1142+
1143+
# Expand trace to include ABANDONED trials with carry-forward values
1144+
expanded_trace = []
1145+
compact_idx = 0
1146+
last_best_value = -float("inf") if maximize else float("inf")
1147+
1148+
for trial_index in sorted(experiment.trials.keys()):
1149+
trial = experiment.trials[trial_index]
1150+
if trial.status in (TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED):
1151+
# Use value from compact trace
1152+
if compact_idx < len(compact_trace):
1153+
value = compact_trace[compact_idx]
1154+
expanded_trace.append(value)
1155+
last_best_value = value
1156+
compact_idx += 1
1157+
else:
1158+
# Should not happen, but handle gracefully
1159+
expanded_trace.append(last_best_value)
1160+
else:
1161+
# ABANDONED or other status: carry forward last best value
1162+
expanded_trace.append(last_best_value)
1163+
1164+
return expanded_trace
11321165

11331166

11341167
def get_tensor_converter_adapter(

0 commit comments

Comments
 (0)