Skip to content

Commit 2ec71c2

Browse files
Sunny Shenfacebook-github-bot
authored andcommitted
Allow Oracle Experiment to take ABANDONED trials into account (facebook#4953)
Summary: Include ABANDONED trials in the trace by carrying forward the last best value. This ensures the trace has one value per trial, reflecting that ABANDONED trials consumed resources but didn't improve optimization. Differential Revision: D86833965
1 parent d7e7620 commit 2ec71c2

File tree

4 files changed

+154
-7
lines changed

4 files changed

+154
-7
lines changed

ax/benchmark/benchmark.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from ax.core.search_space import SearchSpace
5050
from ax.core.trial import BaseTrial, Trial
51+
from ax.core.trial_status import TrialStatus
5152
from ax.core.types import TParameterization, TParamValue
5253
from ax.core.utils import get_model_times
5354
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
@@ -157,6 +158,7 @@ def get_benchmark_runner(
157158
def get_oracle_experiment_from_params(
158159
problem: BenchmarkProblem,
159160
dict_of_dict_of_params: Mapping[int, Mapping[str, Mapping[str, TParamValue]]],
161+
trial_statuses: Mapping[int, TrialStatus] | None = None,
160162
) -> Experiment:
161163
"""
162164
Get a new experiment with the same search space and optimization config
@@ -170,6 +172,12 @@ def get_oracle_experiment_from_params(
170172
config for generating an experiment.
171173
dict_of_dict_of_params: Keys are trial indices, values are Mappings
172174
(e.g. dicts) that map arm names to parameterizations.
175+
trial_statuses: Optional mapping from trial indices to their statuses.
176+
If provided, trials in oracle experiments will be set to the
177+
specified status.
178+
This helps preserve the trial status from the original experiment,
179+
especially if we want to take `ABANDONED` trials into account.
180+
If not provided, trials will be set to completed.
173181
174182
Example:
175183
>>> get_oracle_experiment_from_params(
@@ -215,11 +223,33 @@ def get_oracle_experiment_from_params(
215223
trial = experiment.trials[trial_index]
216224
metadata = runner.run(trial=trial)
217225
trial.update_run_metadata(metadata=metadata)
218-
trial.mark_completed()
226+
227+
# Determine the status for the trial in the oracle experiment.
228+
# Mark ABANDONED and FAILED immediately (they don't require data).
229+
# EARLY_STOPPED requires data, so mark as completed for now and
230+
# defer the status change until after fetch_data().
231+
if trial_statuses is not None:
232+
status = trial_statuses[trial_index]
233+
else:
234+
status = TrialStatus.COMPLETED
235+
236+
if status == TrialStatus.ABANDONED:
237+
trial.mark_abandoned()
238+
elif status == TrialStatus.FAILED:
239+
trial.mark_failed()
240+
else:
241+
trial.mark_completed()
219242

220243
logger.setLevel(level=original_log_level)
221244

222245
experiment.fetch_data()
246+
247+
# Apply EARLY_STOPPED status after data is available, since
248+
# mark_early_stopped() requires data on the trial.
249+
if trial_statuses is not None:
250+
for trial_index, status in trial_statuses.items():
251+
if status == TrialStatus.EARLY_STOPPED:
252+
experiment.trials[trial_index].mark_early_stopped(unsafe=True)
223253
return experiment
224254

225255

@@ -338,14 +368,15 @@ def get_inference_trace(
338368

339369
def get_is_feasible_trace(
340370
experiment: Experiment, optimization_config: OptimizationConfig
341-
) -> list[float]:
371+
) -> list[bool]:
342372
"""Get a trace of feasibility for the experiment.
343373
344374
For batch trials we return True if any arm in a given batch is feasible.
375+
Trials without data (e.g. abandoned or failed) default to False.
345376
"""
346377
df = experiment.lookup_data().df.copy() # Let's not modify the original df
347378
if len(df) == 0:
348-
return []
379+
return [False] * len(experiment.trials)
349380
# Derelativize the optimization config if needed.
350381
optimization_config = derelativize_opt_config(
351382
optimization_config=optimization_config,
@@ -354,7 +385,11 @@ def get_is_feasible_trace(
354385
# Compute feasibility and return feasibility per group
355386
df = _prepare_data_for_trace(df=df, optimization_config=optimization_config)
356387
trial_grouped = df.groupby("trial_index")["feasible"]
357-
return trial_grouped.any().tolist()
388+
feasibility_by_trial = trial_grouped.any().to_dict()
389+
return [
390+
feasibility_by_trial.get(trial_index, False)
391+
for trial_index in sorted(experiment.trials.keys())
392+
]
358393

359394

360395
def get_best_parameters(
@@ -451,8 +486,20 @@ def get_benchmark_result_from_experiment_and_gs(
451486
for new_trial_index, trials in enumerate(trial_completion_order)
452487
}
453488

489+
# Create trial_statuses mapping to preserve trial status in oracle experiment
490+
trial_statuses = {
491+
new_trial_index: (
492+
experiment.trials[next(iter(old_trial_indices))].status
493+
if len(old_trial_indices) == 1
494+
else TrialStatus.COMPLETED
495+
)
496+
for new_trial_index, old_trial_indices in enumerate(trial_completion_order)
497+
}
498+
454499
actual_params_oracle_dummy_experiment = get_oracle_experiment_from_params(
455-
problem=problem, dict_of_dict_of_params=dict_of_dict_of_params
500+
problem=problem,
501+
dict_of_dict_of_params=dict_of_dict_of_params,
502+
trial_statuses=trial_statuses,
456503
)
457504
oracle_trace = np.array(
458505
get_trace(

ax/benchmark/tests/test_benchmark.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
)
7171
from ax.core.experiment import Experiment
7272
from ax.core.objective import MultiObjective
73+
from ax.core.trial_status import TrialStatus
7374
from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy
7475
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
7576
from ax.generation_strategy.generation_strategy import (
@@ -1014,6 +1015,52 @@ def test_get_oracle_experiment_from_params(self) -> None:
10141015
problem=problem, dict_of_dict_of_params={0: {}}
10151016
)
10161017

1018+
with self.subTest("trial_statuses"):
1019+
trial_statuses = {
1020+
0: TrialStatus.COMPLETED,
1021+
1: TrialStatus.ABANDONED,
1022+
}
1023+
experiment = get_oracle_experiment_from_params(
1024+
problem=problem,
1025+
dict_of_dict_of_params={
1026+
0: {"0": near_opt_params},
1027+
1: {"1": other_params},
1028+
},
1029+
trial_statuses=trial_statuses,
1030+
)
1031+
self.assertEqual(len(experiment.trials), 2)
1032+
self.assertTrue(experiment.trials[0].status.is_completed)
1033+
self.assertEqual(experiment.trials[1].status, TrialStatus.ABANDONED)
1034+
1035+
with self.subTest("trial_statuses with FAILED and EARLY_STOPPED"):
1036+
trial_statuses = {
1037+
0: TrialStatus.FAILED,
1038+
1: TrialStatus.EARLY_STOPPED,
1039+
}
1040+
experiment = get_oracle_experiment_from_params(
1041+
problem=problem,
1042+
dict_of_dict_of_params={
1043+
0: {"0": near_opt_params},
1044+
1: {"1": other_params},
1045+
},
1046+
trial_statuses=trial_statuses,
1047+
)
1048+
self.assertEqual(experiment.trials[0].status, TrialStatus.FAILED)
1049+
self.assertEqual(experiment.trials[1].status, TrialStatus.EARLY_STOPPED)
1050+
1051+
with self.subTest("trial_statuses=None defaults to COMPLETED"):
1052+
experiment = get_oracle_experiment_from_params(
1053+
problem=problem,
1054+
dict_of_dict_of_params={
1055+
0: {"0": near_opt_params},
1056+
1: {"1": other_params},
1057+
},
1058+
trial_statuses=None,
1059+
)
1060+
self.assertTrue(
1061+
all(t.status.is_completed for t in experiment.trials.values())
1062+
)
1063+
10171064
def _test_multi_fidelity_or_multi_task(
10181065
self, fidelity_or_task: Literal["fidelity", "task"]
10191066
) -> None:

ax/service/tests/test_best_point.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,36 @@ def test_get_trace(self) -> None:
181181
]
182182
)
183183
exp.attach_data(Data(df=pd.DataFrame.from_records(df_dict2)))
184-
self.assertEqual(get_trace(exp), [2.0, 20.0])
184+
self.assertEqual(get_trace(exp), [2.0, 2.0, 20.0])
185+
186+
def test_get_trace_with_abandoned_trials(self) -> None:
187+
with self.subTest("minimize with abandoned trial"):
188+
exp = get_experiment_with_observations(
189+
observations=[[11], [10], [9], [15], [5]], minimize=True
190+
)
191+
# Mark trial 2 (value=9) as abandoned
192+
exp.trials[2].mark_abandoned(unsafe=True)
193+
194+
# Abandoned trial carries forward the last best value
195+
trace = get_trace(exp)
196+
self.assertEqual(len(trace), 5)
197+
# Trial 0: 11, Trial 1: 10, Trial 2 (abandoned): carry forward 10
198+
# Trial 3: 10 (15 > 10), Trial 4: 5
199+
self.assertEqual(trace, [11, 10, 10, 10, 5])
200+
201+
with self.subTest("maximize with abandoned trial"):
202+
exp = get_experiment_with_observations(
203+
observations=[[1], [3], [2], [5], [4]], minimize=False
204+
)
205+
# Mark trial 1 (value=3) as abandoned
206+
exp.trials[1].mark_abandoned(unsafe=True)
207+
208+
# Abandoned trial carries forward the last best value
209+
trace = get_trace(exp)
210+
self.assertEqual(len(trace), 5)
211+
# Trial 0: 1, Trial 1 (abandoned): carry forward 1,
212+
# Trial 2: 2, Trial 3: 5, Trial 4: 5
213+
self.assertEqual(trace, [1, 1, 2, 5, 5])
185214

186215
def test_get_trace_with_include_status_quo(self) -> None:
187216
with self.subTest("Multi-objective: status quo dominates in some trials"):

ax/service/utils/best_point.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,7 @@ def get_trace(
10611061
10621062
An iteration here refers to a completed or early-stopped (batch) trial.
10631063
There will be one performance metric in the trace for each iteration.
1064+
Abandoned trials carry forward the last best value.
10641065
10651066
Args:
10661067
experiment: The experiment to get the trace for.
@@ -1128,7 +1129,30 @@ def get_trace(
11281129
value_by_trial = trial_grouped.min()
11291130
cumulative_value = np.minimum.accumulate(value_by_trial)
11301131

1131-
return cumulative_value.tolist()
1132+
compact_trace = cumulative_value.tolist()
1133+
1134+
# Expand trace to include ABANDONED trials with carry-forward values.
1135+
data_trial_indices = set(value_by_trial.index)
1136+
expanded_trace = []
1137+
compact_idx = 0
1138+
last_best_value = -float("inf") if maximize else float("inf")
1139+
1140+
for trial_index in sorted(experiment.trials.keys()):
1141+
trial = experiment.trials[trial_index]
1142+
if trial_index in data_trial_indices:
1143+
# Trial has data in compact trace
1144+
if compact_idx < len(compact_trace):
1145+
value = compact_trace[compact_idx]
1146+
expanded_trace.append(value)
1147+
last_best_value = value
1148+
compact_idx += 1
1149+
else:
1150+
# Should not happen, but handle gracefully
1151+
expanded_trace.append(last_best_value)
1152+
elif trial.status == TrialStatus.ABANDONED:
1153+
expanded_trace.append(last_best_value)
1154+
1155+
return expanded_trace
11321156

11331157

11341158
def get_tensor_converter_adapter(

0 commit comments

Comments
 (0)