Skip to content

Commit 2794870

Browse files
Sunny Shen authored and facebook-github-bot committed
AddExecutionViability transform (#4547)
Summary: Transform that adds failure-awareness capability to Ax optimization. This transform enables Ax to learn from deterministic trial failures (ABANDONED trials) and avoid sampling similar parameter configurations that are likely to fail. It achieves this by: (1) Adding an "execution_viable" metric to experiment data based on trial status: ABANDONED trials get a feasibility value of 0.0 (not viable); other trials get a feasibility value of 1.0 (viable). (2) Adding the execution_viable constraint to the optimization config: the constraint enforces P(execution_viable) >= threshold, which guides the acquisition function to avoid non-viable regions. Reviewed By: saitcakmak. Differential Revision: D85185246
1 parent 34f46d1 commit 2794870

3 files changed

Lines changed: 459 additions & 0 deletions

File tree

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from __future__ import annotations
10+
11+
from logging import Logger
12+
from typing import TYPE_CHECKING
13+
14+
import pandas as pd
15+
from ax.adapter.data_utils import ExperimentData
16+
from ax.adapter.transforms.base import Transform
17+
from ax.core.base_trial import BaseTrial, TrialStatus
18+
from ax.core.metric import Metric
19+
from ax.core.observation import ObservationFeatures
20+
from ax.core.optimization_config import OptimizationConfig
21+
from ax.core.outcome_constraint import OutcomeConstraint
22+
from ax.core.types import ComparisonOp
23+
from ax.utils.common.constants import Keys
24+
from ax.utils.common.logger import get_logger
25+
26+
if TYPE_CHECKING:
27+
# import as module to make sphinx-autodoc-typehints happy
28+
from ax import adapter as adapter_module # noqa F401
29+
30+
# Module-level logger for this file, created via Ax's ``get_logger`` helper.
logger: Logger = get_logger(__name__)
31+
32+
# Name of the synthetic metric that AddExecutionViability writes into the
# observation data and references in the outcome constraint it adds.
EXECUTION_VIABLE_METRIC_NAME = "execution_viable"
33+
34+
35+
class AddExecutionViability(Transform):
    """Transform that adds failure-awareness capability to Ax optimization.

    This transform enables Ax to learn from deterministic trial failures (ABANDONED
    trials) and avoid sampling similar parameter configurations that are likely to
    fail. It achieves this by:

    1. Adding an "execution_viable" metric to experiment data based on trial status
        - ABANDONED trials get execution_viable value of 0.0 (not viable)
        - Other trials get execution_viable value of 1.0 (viable)

    2. Adding an execution viability constraint to the optimization config
        - The constraint enforces P(execution_viable) >= threshold
        - This guides the acquisition function to avoid regions likely to fail

    The transform only activates after observing a minimum number of ABANDONED trials
    to ensure there is sufficient data to model the failure region. Before reaching
    this threshold, the transform acts as a no-op.

    Config options:
        feasibility_threshold: float (default 0.8)
            Minimum probability of execution viability required for new candidates.
            Values that are not usable numbers fall back to the default.
        min_abandoned_trials: int (default 3)
            Minimum number of ABANDONED trials required before the transform activates.
            If fewer than this many ABANDONED trials exist, the transform does nothing.
            Values that are not usable numbers fall back to the default.

    Example usage:
        >>> transform = AddExecutionViability(
        ...     config={
        ...         "feasibility_threshold": 0.8,
        ...         "min_abandoned_trials": 3,
        ...     }
        ... )
        >>> # Transform adds execution viability constraint to optimization
        >>> new_opt_config = transform.transform_optimization_config(opt_config)
        >>> # Transform adds execution_viable metric to data
        >>> transformed_data = transform.transform_experiment_data(exp_data)
    """

    # Fallbacks used when a config entry is missing or is not a usable number.
    _DEFAULT_MIN_ABANDONED_TRIALS: int = 3
    _DEFAULT_FEASIBILITY_THRESHOLD: float = 0.8

    @property
    def min_abandoned_trials(self) -> int:
        """Minimum ABANDONED trials required before the transform activates.

        Read from ``config["min_abandoned_trials"]``. Falls back to the default
        (3) when the entry is missing, is not an int/float, is a bool, or is a
        non-finite float that cannot be converted to an int.
        """
        default = self._DEFAULT_MIN_ABANDONED_TRIALS
        raw_value = self.config.get("min_abandoned_trials", default)
        # bool is a subclass of int; without this check ``True`` would silently
        # become a threshold of 1.
        if isinstance(raw_value, bool) or not isinstance(raw_value, (int, float)):
            return default
        try:
            return int(raw_value)
        except (ValueError, OverflowError):
            # int(nan) raises ValueError and int(+/-inf) raises OverflowError.
            return default

    @property
    def feasibility_threshold(self) -> float:
        """Lower bound used for the execution viability outcome constraint.

        Read from ``config["feasibility_threshold"]``. Falls back to the default
        (0.8) when the entry is missing, is not an int/float, is a bool, or is
        NaN.
        """
        default = self._DEFAULT_FEASIBILITY_THRESHOLD
        raw_value = self.config.get("feasibility_threshold", default)
        if isinstance(raw_value, bool) or not isinstance(raw_value, (int, float)):
            return default
        threshold = float(raw_value)
        # NaN is the only float that is not equal to itself; a NaN bound would
        # make the constraint meaningless.
        if threshold != threshold:
            return default
        return threshold

    def _should_activate(
        self, adapter: adapter_module.base.Adapter
    ) -> tuple[bool, int, list[BaseTrial]]:
        """Check if transform should activate based on abandoned trial count.

        Args:
            adapter: Adapter whose ``_experiment`` is inspected for ABANDONED
                trials.

        Returns:
            A tuple of (should_activate, abandoned_count, abandoned_trials)
        """
        experiment = adapter._experiment
        abandoned_trials = experiment.trials_by_status.get(TrialStatus.ABANDONED, [])
        abandoned_count = len(abandoned_trials)
        should_activate = abandoned_count >= self.min_abandoned_trials
        return should_activate, abandoned_count, abandoned_trials

    @staticmethod
    def _candidate_metadata_with_timestamp(
        trial: BaseTrial, arm_name: str
    ) -> dict[str, object]:
        """Build the metadata entry for a synthetic arm row.

        Starts from a copy of the arm's candidate metadata (or an empty dict if
        there is none) and adds the trial completion timestamp when it is not
        already present and the trial has a completion time.
        """
        metadata_raw = trial._get_candidate_metadata(arm_name)
        # Copy so that adding the timestamp below never writes into the dict
        # object handed back by the trial.
        metadata: dict[str, object] = (
            dict(metadata_raw) if metadata_raw is not None else {}
        )
        if (
            Keys.TRIAL_COMPLETION_TIMESTAMP not in metadata
            and trial._time_completed is not None
        ):
            metadata[Keys.TRIAL_COMPLETION_TIMESTAMP] = (
                trial._time_completed.timestamp()
            )
        return metadata

    def transform_experiment_data(
        self, experiment_data: ExperimentData
    ) -> ExperimentData:
        """Transform experiment data to add execution viability metrics.

        Only activates after observing at least min_abandoned_trials ABANDONED trials.
        Returns the original data unchanged if this threshold is not met.

        This method handles two types of ABANDONED trials:
        1. ABANDONED trials WITH data: These already exist in
            experiment_data and will get execution_viable = 0 added to their
            existing observations.
        2. ABANDONED trials WITHOUT data: These are missing from
            experiment_data (e.g., trials that failed due to metric errors).
            We add synthetic observations for these with execution_viable = 0 so
            the model can learn about regions likely to fail.

        When the transform is active, the observation frame is copied before
        rows and columns are added, so the input ``experiment_data`` is not
        modified.

        Args:
            experiment_data: Experiment data with ``arm_data`` and
                ``observation_data`` frames.

        Returns:
            The input object itself when the transform is inactive; otherwise a
            new ``ExperimentData`` carrying the execution_viable metric.

        Raises:
            ValueError: If the transform has no adapter.
        """
        adapter = self.adapter
        if adapter is None:
            raise ValueError(
                "Adapter must be provided for AddExecutionViability transform."
            )

        should_activate, abandoned_count, abandoned_trials = self._should_activate(
            adapter
        )

        if not should_activate:
            logger.debug(
                f"AddExecutionViability transform inactive: "
                f"only {abandoned_count} ABANDONED trials observed "
                f"(need {self.min_abandoned_trials}). Returning original data."
            )
            return experiment_data

        # Work on a copy: rows and columns are added below, and doing that on
        # the caller's frame would leave the input with observation rows that
        # have no matching arm rows (arm_data is extended via concat, which
        # does not touch the input).
        obs_data = experiment_data.observation_data.copy()
        arm_data = experiment_data.arm_data

        # Add rows for abandoned trials without data
        trials_in_data = set(obs_data.index.get_level_values("trial_index"))
        has_step = "step" in obs_data.index.names
        new_arm_entries: list[dict[str, object]] = []
        for trial in abandoned_trials:
            if trial.index in trials_in_data:
                continue
            for arm in trial.arms:
                idx = (
                    (trial.index, arm.name, 1.0)
                    if has_step
                    else (trial.index, arm.name)
                )
                # All existing metric columns are unknown for a synthetic row.
                obs_data.loc[idx, :] = float("nan")
                new_arm_entries.append(
                    {
                        "trial_index": trial.index,
                        "arm_name": arm.name,
                        **arm.parameters,
                        "metadata": self._candidate_metadata_with_timestamp(
                            trial=trial, arm_name=arm.name
                        ),
                    }
                )

        if new_arm_entries:
            new_arm_df = pd.DataFrame(new_arm_entries).set_index(
                ["trial_index", "arm_name"]
            )
            arm_data = pd.concat([arm_data, new_arm_df])
            logger.debug(
                f"AddExecutionViability: Added synthetic observations for "
                f"{len(new_arm_entries)} arms from ABANDONED trials without data"
            )

        # Assign viability for ALL rows in one shot
        trial_indices = obs_data.index.get_level_values("trial_index")
        abandoned_set = {t.index for t in abandoned_trials}
        obs_data[("mean", EXECUTION_VIABLE_METRIC_NAME)] = [
            0.0 if idx in abandoned_set else 1.0 for idx in trial_indices
        ]
        obs_data[("sem", EXECUTION_VIABLE_METRIC_NAME)] = float("nan")

        logger.debug(
            f"AddExecutionViability transform active: "
            f"{abandoned_count} ABANDONED trials observed "
            f"(threshold: {self.min_abandoned_trials})"
        )

        return ExperimentData(
            arm_data=arm_data,
            observation_data=obs_data,
        )

    def transform_optimization_config(
        self,
        optimization_config: OptimizationConfig,
        adapter: adapter_module.base.Adapter | None = None,
        fixed_features: ObservationFeatures | None = None,
    ) -> OptimizationConfig:
        """Transform optimization config to add execution viability constraint.

        Only activates after observing at least min_abandoned_trials ABANDONED trials.
        Returns the original config unchanged if this threshold is not met.

        Args:
            optimization_config: Config to extend with the viability constraint.
            adapter: Adapter to inspect; falls back to ``self.adapter``.
            fixed_features: Not used by this transform.

        Returns:
            The input config when inactive; otherwise a clone with an extra
            ``execution_viable >= feasibility_threshold`` outcome constraint.

        Raises:
            ValueError: If neither ``adapter`` nor ``self.adapter`` is available.
        """
        adapter = adapter or self.adapter
        if adapter is None:
            raise ValueError(
                "Adapter must be provided for AddExecutionViability transform."
            )

        should_activate, abandoned_count, _ = self._should_activate(adapter)

        if not should_activate:
            logger.debug(
                f"AddExecutionViability transform inactive: "
                f"only {abandoned_count} ABANDONED trials observed "
                f"(need {self.min_abandoned_trials}). Returning original config."
            )
            return optimization_config

        # Proceed with adding execution viability constraint
        viability_metric = Metric(
            name=EXECUTION_VIABLE_METRIC_NAME,
            lower_is_better=False,
        )
        viability_constraint = OutcomeConstraint(
            metric=viability_metric,
            op=ComparisonOp.GEQ,
            bound=self.feasibility_threshold,
            relative=False,
        )

        # Build a new list so the constraints of the input config are untouched.
        new_outcome_constraints = [
            *optimization_config.outcome_constraints,
            viability_constraint,
        ]
        transformed_opt_config = optimization_config.clone_with_args(
            outcome_constraints=new_outcome_constraints,
        )

        # Add viability metric to outcomes if not already present
        if viability_metric.name not in adapter.outcomes:
            adapter.outcomes.append(viability_metric.name)

        logger.debug(
            f"AddExecutionViability constraint active: "
            f"{abandoned_count} ABANDONED trials observed "
            f"(threshold: {self.min_abandoned_trials})"
        )

        return transformed_opt_config

0 commit comments

Comments
 (0)