Skip to content

Commit b304ab3

Browse files
Sunny Shenmeta-codesync[bot]
authored andcommitted
Allow Oracle Experiment to take ABANDONED trials into account (facebook#4953)
Summary: Pull Request resolved: facebook#4953 Include ABANDONED trials in the trace by carrying forward the last best value. This ensures the trace has one value per trial, reflecting that ABANDONED trials consumed resources but didn't improve optimization. Differential Revision: D86833965
1 parent 9f50a93 commit b304ab3

File tree

4 files changed

+157
-8
lines changed

4 files changed

+157
-8
lines changed

ax/benchmark/benchmark.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from ax.core.search_space import SearchSpace
5050
from ax.core.trial import BaseTrial, Trial
51+
from ax.core.trial_status import TrialStatus
5152
from ax.core.types import TParameterization, TParamValue
5253
from ax.core.utils import get_model_times
5354
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
@@ -161,6 +162,7 @@ def get_benchmark_runner(
161162
def get_oracle_experiment_from_params(
162163
problem: BenchmarkProblem,
163164
dict_of_dict_of_params: Mapping[int, Mapping[str, Mapping[str, TParamValue]]],
165+
trial_statuses: Mapping[int, TrialStatus] | None = None,
164166
) -> Experiment:
165167
"""
166168
Get a new experiment with the same search space and optimization config
@@ -174,6 +176,12 @@ def get_oracle_experiment_from_params(
174176
config for generating an experiment.
175177
dict_of_dict_of_params: Keys are trial indices, values are Mappings
176178
(e.g. dicts) that map arm names to parameterizations.
179+
trial_statuses: Optional mapping from trial indices to their statuses.
180+
If provided, trials in oracle experiments will be set to the
181+
specified status.
182+
This helps preserve the trial status from the original experiment,
183+
especially if we want to take `ABANDONED` trials into account.
184+
If not provided, trials will be set to completed.
177185
178186
Example:
179187
>>> get_oracle_experiment_from_params(
@@ -219,11 +227,33 @@ def get_oracle_experiment_from_params(
219227
trial = experiment.trials[trial_index]
220228
metadata = runner.run(trial=trial)
221229
trial.update_run_metadata(metadata=metadata)
222-
trial.mark_completed()
230+
231+
# Determine the status for the trial in the oracle experiment.
232+
# Mark ABANDONED and FAILED immediately (they don't require data).
233+
# EARLY_STOPPED requires data, so mark as completed for now and
234+
# defer the status change until after fetch_data().
235+
if trial_statuses is not None:
236+
status = trial_statuses[trial_index]
237+
else:
238+
status = TrialStatus.COMPLETED
239+
240+
if status == TrialStatus.ABANDONED:
241+
trial.mark_abandoned()
242+
elif status == TrialStatus.FAILED:
243+
trial.mark_failed()
244+
else:
245+
trial.mark_completed()
223246

224247
logger.setLevel(level=original_log_level)
225248

226249
experiment.fetch_data()
250+
251+
# Apply EARLY_STOPPED status after data is available, since
252+
# mark_early_stopped() requires data on the trial.
253+
if trial_statuses is not None:
254+
for trial_index, status in trial_statuses.items():
255+
if status == TrialStatus.EARLY_STOPPED:
256+
experiment.trials[trial_index].mark_early_stopped(unsafe=True)
227257
return experiment
228258

229259

@@ -342,14 +372,15 @@ def get_inference_trace(
342372

343373
def get_is_feasible_trace(
344374
experiment: Experiment, optimization_config: OptimizationConfig
345-
) -> list[float]:
375+
) -> list[bool]:
346376
"""Get a trace of feasibility for the experiment.
347377
348378
For batch trials we return True if any arm in a given batch is feasible.
379+
Trials without data (e.g. abandoned or failed) default to False.
349380
"""
350381
df = experiment.lookup_data().df.copy() # Let's not modify the original df
351382
if len(df) == 0:
352-
return []
383+
return [False] * len(experiment.trials)
353384
# Derelativize the optimization config if needed.
354385
optimization_config = derelativize_opt_config(
355386
optimization_config=optimization_config,
@@ -358,7 +389,11 @@ def get_is_feasible_trace(
358389
# Compute feasibility and return feasibility per group
359390
df = _prepare_data_for_trace(df=df, optimization_config=optimization_config)
360391
trial_grouped = df.groupby("trial_index")["feasible"]
361-
return trial_grouped.any().tolist()
392+
feasibility_by_trial = trial_grouped.any().to_dict()
393+
return [
394+
feasibility_by_trial.get(trial_index, False)
395+
for trial_index in sorted(experiment.trials.keys())
396+
]
362397

363398

364399
def get_best_parameters(
@@ -455,8 +490,20 @@ def get_benchmark_result_from_experiment_and_gs(
455490
for new_trial_index, trials in enumerate(trial_completion_order)
456491
}
457492

493+
# Create trial_statuses mapping to preserve trial status in oracle experiment
494+
trial_statuses = {
495+
new_trial_index: (
496+
experiment.trials[next(iter(old_trial_indices))].status
497+
if len(old_trial_indices) == 1
498+
else TrialStatus.COMPLETED
499+
)
500+
for new_trial_index, old_trial_indices in enumerate(trial_completion_order)
501+
}
502+
458503
actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params(
459-
problem=problem, dict_of_dict_of_params=dict_of_dict_of_params
504+
problem=problem,
505+
dict_of_dict_of_params=dict_of_dict_of_params,
506+
trial_statuses=trial_statuses,
460507
)
461508
oracle_trace = np.array(
462509
get_trace(

ax/benchmark/tests/test_benchmark.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
)
7171
from ax.core.experiment import Experiment
7272
from ax.core.objective import MultiObjective
73+
from ax.core.trial_status import TrialStatus
7374
from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy
7475
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
7576
from ax.generation_strategy.generation_strategy import (
@@ -1014,6 +1015,52 @@ def test_get_oracle_experiment_from_params(self) -> None:
10141015
problem=problem, dict_of_dict_of_params={0: {}}
10151016
)
10161017

1018+
with self.subTest("trial_statuses"):
1019+
trial_statuses = {
1020+
0: TrialStatus.COMPLETED,
1021+
1: TrialStatus.ABANDONED,
1022+
}
1023+
experiment = get_oracle_experiment_from_params(
1024+
problem=problem,
1025+
dict_of_dict_of_params={
1026+
0: {"0": near_opt_params},
1027+
1: {"1": other_params},
1028+
},
1029+
trial_statuses=trial_statuses,
1030+
)
1031+
self.assertEqual(len(experiment.trials), 2)
1032+
self.assertTrue(experiment.trials[0].status.is_completed)
1033+
self.assertEqual(experiment.trials[1].status, TrialStatus.ABANDONED)
1034+
1035+
with self.subTest("trial_statuses with FAILED and EARLY_STOPPED"):
1036+
trial_statuses = {
1037+
0: TrialStatus.FAILED,
1038+
1: TrialStatus.EARLY_STOPPED,
1039+
}
1040+
experiment = get_oracle_experiment_from_params(
1041+
problem=problem,
1042+
dict_of_dict_of_params={
1043+
0: {"0": near_opt_params},
1044+
1: {"1": other_params},
1045+
},
1046+
trial_statuses=trial_statuses,
1047+
)
1048+
self.assertEqual(experiment.trials[0].status, TrialStatus.FAILED)
1049+
self.assertEqual(experiment.trials[1].status, TrialStatus.EARLY_STOPPED)
1050+
1051+
with self.subTest("trial_statuses=None defaults to COMPLETED"):
1052+
experiment = get_oracle_experiment_from_params(
1053+
problem=problem,
1054+
dict_of_dict_of_params={
1055+
0: {"0": near_opt_params},
1056+
1: {"1": other_params},
1057+
},
1058+
trial_statuses=None,
1059+
)
1060+
self.assertTrue(
1061+
all(t.status.is_completed for t in experiment.trials.values())
1062+
)
1063+
10171064
def _test_multi_fidelity_or_multi_task(
10181065
self, fidelity_or_task: Literal["fidelity", "task"]
10191066
) -> None:

ax/service/tests/test_best_point.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,36 @@ def test_get_trace(self) -> None:
189189
]
190190
)
191191
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2)))
192-
self.assertEqual(get_trace(exp), [2.0, 20.0])
192+
self.assertEqual(get_trace(exp), [2.0, 2.0, 20.0])
193+
194+
def test_get_trace_with_abandoned_trials(self) -> None:
195+
with self.subTest("minimize with abandoned trial"):
196+
exp = get_experiment_with_observations(
197+
observations=[[11], [10], [9], [15], [5]], minimize=True
198+
)
199+
# Mark trial 2 (value=9) as abandoned
200+
exp.trials[2].mark_abandoned(unsafe=True)
201+
202+
# Abandoned trial carries forward the last best value
203+
trace = get_trace(exp)
204+
self.assertEqual(len(trace), 5)
205+
# Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10
206+
# Trial 3: 10 (15 > 10), Trial 4: 5
207+
self.assertEqual(trace, [11, 10, 10, 10, 5])
208+
209+
with self.subTest("maximize with abandoned trial"):
210+
exp = get_experiment_with_observations(
211+
observations=[[1], [3], [2], [5], [4]], minimize=False
212+
)
213+
# Mark trial 1 (value=3) as abandoned
214+
exp.trials[1].mark_abandoned(unsafe=True)
215+
216+
# Abandoned trial carries forward the last best value
217+
trace = get_trace(exp)
218+
self.assertEqual(len(trace), 5)
219+
# Trial 0: 1, Trial 1 (abandoned): carry forward 1,
220+
# Trial 2: 2, Trial 3: 5, Trial 4: 5
221+
self.assertEqual(trace, [1, 1, 2, 5, 5])
193222

194223
def test_get_trace_with_include_status_quo(self) -> None:
195224
with self.subTest("Multi-objective: status quo dominates in some trials"):

ax/service/utils/best_point.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,7 @@ def get_trace(
12191219
12201220
An iteration here refers to a completed or early-stopped (batch) trial.
12211221
There will be one performance metric in the trace for each iteration.
1222+
Abandoned trials carry forward the last best value.
12221223
12231224
Args:
12241225
experiment: The experiment to get the trace for.
@@ -1278,12 +1279,37 @@ def get_trace(
12781279
# Aggregate by trial, then. compute cumulative best
12791280
objective = optimization_config.objective
12801281
maximize = isinstance(objective, MultiObjective) or not objective.minimize
1281-
return _aggregate_and_cumulate_trace(
1282+
cumulative_value = _aggregate_and_cumulate_trace(
12821283
df=value_by_arm_pull,
12831284
by=["trial_index"],
12841285
maximize=maximize,
12851286
keep_order=False, # sort by trial index
1286-
).tolist()
1287+
)
1288+
1289+
compact_trace = cumulative_value.tolist()
1290+
1291+
# Expand trace to include ABANDONED trials with carry-forward values.
1292+
data_trial_indices = set(cumulative_value.index)
1293+
expanded_trace = []
1294+
compact_idx = 0
1295+
last_best_value = -float("inf") if maximize else float("inf")
1296+
1297+
for trial_index in sorted(experiment.trials.keys()):
1298+
trial = experiment.trials[trial_index]
1299+
if trial_index in data_trial_indices:
1300+
# Trial has data in compact trace
1301+
if compact_idx < len(compact_trace):
1302+
value = compact_trace[compact_idx]
1303+
expanded_trace.append(value)
1304+
last_best_value = value
1305+
compact_idx += 1
1306+
else:
1307+
# Should not happen, but handle gracefully
1308+
expanded_trace.append(last_best_value)
1309+
elif trial.status == TrialStatus.ABANDONED:
1310+
expanded_trace.append(last_best_value)
1311+
1312+
return expanded_trace
12871313

12881314

12891315
def get_tensor_converter_adapter(

0 commit comments

Comments
 (0)