Skip to content

Commit 5c125ba

Browse files
Sunny Shen and facebook-github-bot
authored and committed
Allow Oracle Experiment to take ABANDONED trials into account (facebook#4953)
Summary: Include ABANDONED trials in the trace by carrying forward the last best value. This ensures the trace has one value per trial, reflecting that ABANDONED trials consumed resources but didn't improve optimization.

Reviewed By: saitcakmak

Differential Revision: D86833965
1 parent 1054802 commit 5c125ba

4 files changed

Lines changed: 174 additions & 8 deletions

File tree

ax/benchmark/benchmark.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from ax.core.search_space import SearchSpace
5050
from ax.core.trial import BaseTrial, Trial
51+
from ax.core.trial_status import TrialStatus
5152
from ax.core.types import TParameterization, TParamValue
5253
from ax.core.utils import get_model_times
5354
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
@@ -161,6 +162,7 @@ def get_benchmark_runner(
161162
def get_oracle_experiment_from_params(
162163
problem: BenchmarkProblem,
163164
dict_of_dict_of_params: Mapping[int, Mapping[str, Mapping[str, TParamValue]]],
165+
trial_statuses: Mapping[int, TrialStatus] | None = None,
164166
) -> Experiment:
165167
"""
166168
Get a new experiment with the same search space and optimization config
@@ -174,6 +176,12 @@ def get_oracle_experiment_from_params(
174176
config for generating an experiment.
175177
dict_of_dict_of_params: Keys are trial indices, values are Mappings
176178
(e.g. dicts) that map arm names to parameterizations.
179+
trial_statuses: Optional mapping from trial indices to their statuses.
180+
If provided, trials in oracle experiments will be set to the
181+
specified status.
182+
This helps preserve the trial status from the original experiment,
183+
especially if we want to take `ABANDONED` trials into account.
184+
If not provided, trials will be set to completed.
177185
178186
Example:
179187
>>> get_oracle_experiment_from_params(
@@ -219,11 +227,33 @@ def get_oracle_experiment_from_params(
219227
trial = experiment.trials[trial_index]
220228
metadata = runner.run(trial=trial)
221229
trial.update_run_metadata(metadata=metadata)
222-
trial.mark_completed()
230+
231+
# Determine the status for the trial in the oracle experiment.
232+
# Mark ABANDONED and FAILED immediately (they don't require data).
233+
# EARLY_STOPPED requires data, so mark as completed for now and
234+
# defer the status change until after fetch_data().
235+
if trial_statuses is not None:
236+
status = trial_statuses[trial_index]
237+
else:
238+
status = TrialStatus.COMPLETED
239+
240+
if status == TrialStatus.ABANDONED:
241+
trial.mark_abandoned()
242+
elif status == TrialStatus.FAILED:
243+
trial.mark_failed()
244+
else:
245+
trial.mark_completed()
223246

224247
logger.setLevel(level=original_log_level)
225248

226249
experiment.fetch_data()
250+
251+
# Apply EARLY_STOPPED status after data is available, since
252+
# mark_early_stopped() requires data on the trial.
253+
if trial_statuses is not None:
254+
for trial_index, status in trial_statuses.items():
255+
if status == TrialStatus.EARLY_STOPPED:
256+
experiment.trials[trial_index].mark_early_stopped(unsafe=True)
227257
return experiment
228258

229259

@@ -342,14 +372,15 @@ def get_inference_trace(
342372

343373
def get_is_feasible_trace(
344374
experiment: Experiment, optimization_config: OptimizationConfig
345-
) -> list[float]:
375+
) -> list[bool]:
346376
"""Get a trace of feasibility for the experiment.
347377
348378
For batch trials we return True if any arm in a given batch is feasible.
379+
Trials without data (e.g. abandoned or failed) default to False.
349380
"""
350381
df = experiment.lookup_data().df.copy() # Let's not modify the original df
351382
if len(df) == 0:
352-
return []
383+
return [False] * len(experiment.trials)
353384
# Derelativize the optimization config if needed.
354385
optimization_config = derelativize_opt_config(
355386
optimization_config=optimization_config,
@@ -358,7 +389,11 @@ def get_is_feasible_trace(
358389
# Compute feasibility and return feasibility per group
359390
df = _prepare_data_for_trace(df=df, optimization_config=optimization_config)
360391
trial_grouped = df.groupby("trial_index")["feasible"]
361-
return trial_grouped.any().tolist()
392+
feasibility_by_trial = trial_grouped.any().to_dict()
393+
return [
394+
feasibility_by_trial.get(trial_index, False)
395+
for trial_index in sorted(experiment.trials.keys())
396+
]
362397

363398

364399
def get_best_parameters(
@@ -455,8 +490,20 @@ def get_benchmark_result_from_experiment_and_gs(
455490
for new_trial_index, trials in enumerate(trial_completion_order)
456491
}
457492

493+
# Create trial_statuses mapping to preserve trial status in oracle experiment.
494+
# If all trials in a completion group share the same status, use that status;
495+
# otherwise default to COMPLETED.
496+
trial_statuses = {}
497+
for new_trial_index, old_trial_indices in enumerate(trial_completion_order):
498+
statuses = {experiment.trials[idx].status for idx in old_trial_indices}
499+
trial_statuses[new_trial_index] = (
500+
next(iter(statuses)) if len(statuses) == 1 else TrialStatus.COMPLETED
501+
)
502+
458503
actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params(
459-
problem=problem, dict_of_dict_of_params=dict_of_dict_of_params
504+
problem=problem,
505+
dict_of_dict_of_params=dict_of_dict_of_params,
506+
trial_statuses=trial_statuses,
460507
)
461508
oracle_trace = np.array(
462509
get_trace(

ax/benchmark/tests/test_benchmark.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
)
7171
from ax.core.experiment import Experiment
7272
from ax.core.objective import MultiObjective
73+
from ax.core.trial_status import TrialStatus
7374
from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy
7475
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
7576
from ax.generation_strategy.generation_strategy import (
@@ -1014,6 +1015,52 @@ def test_get_oracle_experiment_from_params(self) -> None:
10141015
problem=problem, dict_of_dict_of_params={0: {}}
10151016
)
10161017

1018+
with self.subTest("trial_statuses"):
1019+
trial_statuses = {
1020+
0: TrialStatus.COMPLETED,
1021+
1: TrialStatus.ABANDONED,
1022+
}
1023+
experiment = get_oracle_experiment_from_params(
1024+
problem=problem,
1025+
dict_of_dict_of_params={
1026+
0: {"0": near_opt_params},
1027+
1: {"1": other_params},
1028+
},
1029+
trial_statuses=trial_statuses,
1030+
)
1031+
self.assertEqual(len(experiment.trials), 2)
1032+
self.assertTrue(experiment.trials[0].status.is_completed)
1033+
self.assertEqual(experiment.trials[1].status, TrialStatus.ABANDONED)
1034+
1035+
with self.subTest("trial_statuses with FAILED and EARLY_STOPPED"):
1036+
trial_statuses = {
1037+
0: TrialStatus.FAILED,
1038+
1: TrialStatus.EARLY_STOPPED,
1039+
}
1040+
experiment = get_oracle_experiment_from_params(
1041+
problem=problem,
1042+
dict_of_dict_of_params={
1043+
0: {"0": near_opt_params},
1044+
1: {"1": other_params},
1045+
},
1046+
trial_statuses=trial_statuses,
1047+
)
1048+
self.assertEqual(experiment.trials[0].status, TrialStatus.FAILED)
1049+
self.assertEqual(experiment.trials[1].status, TrialStatus.EARLY_STOPPED)
1050+
1051+
with self.subTest("trial_statuses=None defaults to COMPLETED"):
1052+
experiment = get_oracle_experiment_from_params(
1053+
problem=problem,
1054+
dict_of_dict_of_params={
1055+
0: {"0": near_opt_params},
1056+
1: {"1": other_params},
1057+
},
1058+
trial_statuses=None,
1059+
)
1060+
self.assertTrue(
1061+
all(t.status.is_completed for t in experiment.trials.values())
1062+
)
1063+
10171064
def _test_multi_fidelity_or_multi_task(
10181065
self, fidelity_or_task: Literal["fidelity", "task"]
10191066
) -> None:

ax/service/tests/test_best_point.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,50 @@ def test_get_trace(self) -> None:
189189
]
190190
)
191191
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2)))
192-
self.assertEqual(get_trace(exp), [2.0, 20.0])
192+
self.assertEqual(get_trace(exp), [2.0, 2.0, 20.0])
193+
194+
def test_get_trace_with_non_completed_trials(self) -> None:
195+
with self.subTest("minimize with abandoned trial"):
196+
exp = get_experiment_with_observations(
197+
observations=[[11], [10], [9], [15], [5]], minimize=True
198+
)
199+
# Mark trial 2 (value=9) as abandoned
200+
exp.trials[2].mark_abandoned(unsafe=True)
201+
202+
# Abandoned trial carries forward the last best value
203+
trace = get_trace(exp)
204+
self.assertEqual(len(trace), 5)
205+
# Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10
206+
# Trial 3: 10 (15 > 10), Trial 4: 5
207+
self.assertEqual(trace, [11, 10, 10, 10, 5])
208+
209+
with self.subTest("maximize with abandoned trial"):
210+
exp = get_experiment_with_observations(
211+
observations=[[1], [3], [2], [5], [4]], minimize=False
212+
)
213+
# Mark trial 1 (value=3) as abandoned
214+
exp.trials[1].mark_abandoned(unsafe=True)
215+
216+
# Abandoned trial carries forward the last best value
217+
trace = get_trace(exp)
218+
self.assertEqual(len(trace), 5)
219+
# Trial 0: 1, Trial 1 (abandoned): carry forward 1,
220+
# Trial 2: 2, Trial 3: 5, Trial 4: 5
221+
self.assertEqual(trace, [1, 1, 2, 5, 5])
222+
223+
with self.subTest("minimize with failed trial"):
224+
exp = get_experiment_with_observations(
225+
observations=[[11], [10], [9], [15], [5]], minimize=True
226+
)
227+
# Mark trial 2 (value=9) as failed
228+
exp.trials[2].mark_failed(unsafe=True)
229+
230+
# Failed trial carries forward the last best value
231+
trace = get_trace(exp)
232+
self.assertEqual(len(trace), 5)
233+
# Trial 0: 11, Trial 1: 10, Trial 2 (failed): carry forward 10
234+
# Trial 3: 10 (15 > 10), Trial 4: 5
235+
self.assertEqual(trace, [11, 10, 10, 10, 5])
193236

194237
def test_get_trace_with_include_status_quo(self) -> None:
195238
with self.subTest("Multi-objective: status quo dominates in some trials"):

ax/service/utils/best_point.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,8 @@ def get_trace(
12191219
12201220
An iteration here refers to a completed or early-stopped (batch) trial.
12211221
There will be one performance metric in the trace for each iteration.
1222+
Trials without data (e.g. abandoned or failed) carry forward the last
1223+
best value.
12221224
12231225
Args:
12241226
experiment: The experiment to get the trace for.
@@ -1278,12 +1280,39 @@ def get_trace(
12781280
# Aggregate by trial, then compute cumulative best
12791281
objective = optimization_config.objective
12801282
maximize = isinstance(objective, MultiObjective) or not objective.minimize
1281-
return _aggregate_and_cumulate_trace(
1283+
cumulative_value = _aggregate_and_cumulate_trace(
12821284
df=value_by_arm_pull,
12831285
by=["trial_index"],
12841286
maximize=maximize,
12851287
keep_order=False, # sort by trial index
1286-
).tolist()
1288+
)
1289+
1290+
compact_trace = cumulative_value.tolist()
1291+
1292+
# Expand trace to include trials without data (e.g. ABANDONED, FAILED)
1293+
# with carry-forward values.
1294+
data_trial_indices = set(cumulative_value.index)
1295+
expanded_trace = []
1296+
compact_idx = 0
1297+
last_best_value = -float("inf") if maximize else float("inf")
1298+
1299+
for trial_index in sorted(experiment.trials.keys()):
1300+
trial = experiment.trials[trial_index]
1301+
if trial_index in data_trial_indices:
1302+
# Trial has data in compact trace
1303+
if compact_idx < len(compact_trace):
1304+
value = compact_trace[compact_idx]
1305+
expanded_trace.append(value)
1306+
last_best_value = value
1307+
compact_idx += 1
1308+
else:
1309+
# Should not happen, but handle gracefully
1310+
expanded_trace.append(last_best_value)
1311+
elif trial.status in (TrialStatus.ABANDONED, TrialStatus.FAILED):
1312+
# Trial has no data; carry forward the last best value.
1313+
expanded_trace.append(last_best_value)
1314+
1315+
return expanded_trace
12871316

12881317

12891318
def get_tensor_converter_adapter(

0 commit comments

Comments
 (0)