Skip to content

Commit 544c785

Browse files
Sunny Shen and meta-codesync[bot]
authored and committed
AddExecutionViability transform (facebook#4547)
Summary: Pull Request resolved: facebook#4547 Transform that adds failure-awareness capability to Ax optimization. This transform enables Ax to learn from deterministic trial failures (ABANDONED trials) and avoid sampling similar parameter configurations that are likely to fail. It achieves this by: 1. Adding an "execution_viable" metric to experiment data based on trial status - ABANDONED trials get feasibility value of 0.0 (not viable) - Other trials get feasibility value of 1.0 (viable) 2. Adding the execution_viable constraint to the optimization config - The constraint enforces P(execution_viable) >= threshold - This guides the acquisition function to avoid non-viable regions Differential Revision: D85185246
1 parent 6460a1a commit 544c785

3 files changed

Lines changed: 458 additions & 0 deletions

File tree

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from __future__ import annotations
10+
11+
from logging import Logger
12+
from typing import TYPE_CHECKING
13+
14+
import pandas as pd
15+
from ax.adapter.data_utils import ExperimentData
16+
from ax.adapter.transforms.base import Transform
17+
from ax.core.base_trial import BaseTrial, TrialStatus
18+
from ax.core.metric import Metric
19+
from ax.core.observation import ObservationFeatures
20+
from ax.core.optimization_config import OptimizationConfig
21+
from ax.core.outcome_constraint import OutcomeConstraint
22+
from ax.core.types import ComparisonOp
23+
from ax.utils.common.logger import get_logger
24+
25+
if TYPE_CHECKING:
26+
# import as module to make sphinx-autodoc-typehints happy
27+
from ax import adapter as adapter_module # noqa F401
28+
29+
# Module-level logger; the transform reports its active/inactive state at DEBUG.
logger: Logger = get_logger(__name__)
30+
31+
# Name of the synthetic metric: used both for the ("mean"/"sem") columns added
# to the observation data and for the outcome constraint's metric.
EXECUTION_VIABLE_METRIC_NAME = "execution_viable"
32+
33+
34+
class AddExecutionViability(Transform):
    """Transform that adds failure-awareness capability to Ax optimization.

    This transform enables Ax to learn from deterministic trial failures (ABANDONED
    trials) and avoid sampling similar parameter configurations that are likely to
    fail. It achieves this by:

    1. Adding an "execution_viable" metric to experiment data based on trial status
       - ABANDONED trials get execution_viable value of 0.0 (not viable)
       - Other trials get execution_viable value of 1.0 (viable)

    2. Adding an execution viability constraint to the optimization config
       - The constraint enforces P(execution_viable) >= threshold
       - This guides the acquisition function to avoid regions likely to fail

    The transform only activates after observing a minimum number of ABANDONED trials
    to ensure there is sufficient data to model the failure region. Before reaching
    this threshold, the transform acts as a no-op.

    Config options:
        feasibility_threshold: float (default 0.8)
            Minimum probability of execution viability required for new candidates.
            Used as the bound of the added ``>=`` outcome constraint.
        min_abandoned_trials: int (default 3)
            Minimum number of ABANDONED trials required before the transform activates.
            If fewer than this many ABANDONED trials exist, the transform does nothing.

    Example usage:
        >>> transform = AddExecutionViability(
        ...     config={
        ...         "feasibility_threshold": 0.8,
        ...         "min_abandoned_trials": 3,
        ...     }
        ... )
        >>> # Transform adds execution viability constraint to optimization
        >>> new_opt_config = transform.transform_optimization_config(opt_config)
        >>> # Transform adds execution_viable metric to data
        >>> transformed_data = transform.transform_experiment_data(exp_data)
    """

    @property
    def min_abandoned_trials(self) -> int:
        """Minimum ABANDONED trials required before the transform activates.

        Read from the ``min_abandoned_trials`` config entry. Falls back to 3
        when the entry is missing or is neither an int nor a float.
        """
        raw_value = self.config.get("min_abandoned_trials", 3)
        return int(raw_value) if isinstance(raw_value, (int, float)) else 3

    def _should_activate(
        self, adapter: adapter_module.base.Adapter
    ) -> tuple[bool, int, list[BaseTrial]]:
        """Check if transform should activate based on abandoned trial count.

        Args:
            adapter: Adapter whose experiment is inspected for ABANDONED trials.

        Returns:
            A tuple of (should_activate, abandoned_count, abandoned_trials), where
            ``should_activate`` is True once ``abandoned_count`` reaches
            ``min_abandoned_trials``.
        """
        experiment = adapter._experiment
        abandoned_trials = experiment.trials_by_status.get(TrialStatus.ABANDONED, [])
        abandoned_count = len(abandoned_trials)
        should_activate = abandoned_count >= self.min_abandoned_trials
        return should_activate, abandoned_count, abandoned_trials

    def transform_experiment_data(
        self, experiment_data: ExperimentData
    ) -> ExperimentData:
        """Transform experiment data to add execution viability metrics.

        Only activates after observing at least min_abandoned_trials ABANDONED trials.
        Returns the original data unchanged if this threshold is not met.

        This method handles two types of ABANDONED trials:
        1. ABANDONED trials WITH data: These already exist in
           experiment_data and will get execution_viable = 0 added to their
           existing observations.
        2. ABANDONED trials WITHOUT data: These are missing from
           experiment_data (e.g., trials that failed due to metric errors).
           We add synthetic observations for these with execution_viable = 0 so
           the model can learn about regions likely to fail.

        Note:
            When active, the ``execution_viable`` mean/sem columns are assigned
            directly onto ``experiment_data.observation_data``.

        Args:
            experiment_data: Arm and observation data to augment.

        Returns:
            The input object when the transform is inactive; otherwise a new
            ``ExperimentData`` carrying the viability columns and any synthetic
            rows for ABANDONED trials that had no observations.

        Raises:
            ValueError: If the transform has no adapter attached.
        """
        if self.adapter is None:
            raise ValueError(
                "Adapter must be provided for using feasibility constraints."
            )

        adapter = self.adapter
        should_activate, abandoned_count, abandoned_trials = self._should_activate(
            adapter
        )

        if not should_activate:
            logger.debug(
                f"AddExecutionViability transform inactive: "
                f"only {abandoned_count} ABANDONED trials observed "
                f"(need {self.min_abandoned_trials}). Returning original data."
            )
            return experiment_data

        experiment = adapter._experiment

        obs_data = experiment_data.observation_data
        arm_data = experiment_data.arm_data

        mean_col = ("mean", EXECUTION_VIABLE_METRIC_NAME)
        sem_col = ("sem", EXECUTION_VIABLE_METRIC_NAME)

        # Step 1: Add execution viability metric to existing observations.
        # Look up each distinct trial's status once, then broadcast per row.
        trial_indices = obs_data.index.get_level_values("trial_index")
        trials_in_data = set(trial_indices.unique())
        trial_to_viability = {
            trial_idx: float(
                experiment.trials[trial_idx].status != TrialStatus.ABANDONED
            )
            for trial_idx in trials_in_data
        }
        obs_data[mean_col] = trial_indices.map(trial_to_viability)
        obs_data[sem_col] = float("nan")

        # Step 2: Identify ABANDONED trials that are NOT in the observation data
        abandoned_trials_without_data = [
            trial for trial in abandoned_trials if trial.index not in trials_in_data
        ]

        # Step 3: Add observations for ABANDONED trials without data
        if abandoned_trials_without_data:
            # NaN placeholders for every other metric column already in obs_data.
            # Identical for all synthetic rows, so build it once.
            nan_fill = {
                col: float("nan")
                for col in obs_data.columns
                if col not in (mean_col, sem_col)
            }

            new_rows = []
            new_arm_rows = []

            for trial in abandoned_trials_without_data:
                trial_idx = trial.index
                # Each trial can have multiple arms
                for arm in trial.arms:
                    arm_name = arm.name

                    new_rows.append(
                        {
                            "trial_index": trial_idx,
                            "arm_name": arm_name,
                            mean_col: 0.0,
                            sem_col: float("nan"),
                            **nan_fill,
                        }
                    )

                    # Also add to arm_data
                    arm_row_data = dict(arm.parameters)
                    metadata_raw = trial._get_candidate_metadata(arm.name)
                    # Copy before adding the timestamp so the dict returned by
                    # the trial is never modified by this transform.
                    metadata = dict(metadata_raw) if metadata_raw is not None else {}
                    if (
                        "trial_completion_timestamp" not in metadata
                        and trial._time_completed is not None
                    ):
                        metadata["trial_completion_timestamp"] = (
                            trial._time_completed.timestamp()
                        )
                    arm_row_data["metadata"] = metadata  # pyre-ignore[6]
                    new_arm_rows.append(
                        {"trial_index": trial_idx, "arm_name": arm_name, **arm_row_data}
                    )

            # A trial without arms contributes no rows; skip the concat then.
            if new_rows:
                new_obs_df = pd.DataFrame(new_rows)
                new_obs_df = new_obs_df.set_index(["trial_index", "arm_name"])
                obs_data = pd.concat([obs_data, new_obs_df])

                new_arm_df = pd.DataFrame(new_arm_rows)
                new_arm_df = new_arm_df.set_index(["trial_index", "arm_name"])
                arm_data = pd.concat([arm_data, new_arm_df])

                logger.debug(
                    f"AddExecutionViability: Added synthetic observations for "
                    f"{len(abandoned_trials_without_data)} ABANDONED trials "
                    "without data"
                )

        logger.debug(
            f"AddExecutionViability transform active: "
            f"{abandoned_count} ABANDONED trials observed "
            f"(threshold: {self.min_abandoned_trials})"
        )

        return ExperimentData(
            arm_data=arm_data,
            observation_data=obs_data,
        )

    def transform_optimization_config(
        self,
        optimization_config: OptimizationConfig,
        adapter: adapter_module.base.Adapter | None = None,
        fixed_features: ObservationFeatures | None = None,
    ) -> OptimizationConfig:
        """Transform optimization config to add execution viability constraint.

        Only activates after observing at least min_abandoned_trials ABANDONED trials.
        Returns the original config unchanged if this threshold is not met.

        Args:
            optimization_config: Config to extend with the viability constraint.
            adapter: Adapter to inspect; defaults to the transform's own adapter.
            fixed_features: Accepted for interface compatibility; not used here.

        Returns:
            The input config when inactive; otherwise a clone whose outcome
            constraints additionally require
            ``execution_viable >= feasibility_threshold`` (absolute bound).

        Raises:
            ValueError: If neither ``adapter`` nor ``self.adapter`` is available.

        Side effects:
            When active, appends the viability metric name to
            ``adapter.outcomes`` if it is not already listed.
        """
        if adapter is None:
            adapter = self.adapter
        if adapter is None:
            raise ValueError("Adapter must be provided for using feasibility.")

        should_activate, abandoned_count, _ = self._should_activate(adapter)

        if not should_activate:
            logger.debug(
                f"AddExecutionViability transform inactive: "
                f"only {abandoned_count} ABANDONED trials observed "
                f"(need {self.min_abandoned_trials}). Returning original config."
            )
            return optimization_config

        viability_metric = Metric(
            name=EXECUTION_VIABLE_METRIC_NAME,
            lower_is_better=False,
        )
        viability_constraint = OutcomeConstraint(
            metric=viability_metric,
            op=ComparisonOp.GEQ,
            bound=self.config.get("feasibility_threshold", 0.8),  # pyre-ignore [6]
            relative=False,
        )

        # Build a fresh list so the input config's constraint list is untouched.
        new_outcome_constraints = [
            *optimization_config.outcome_constraints,
            viability_constraint,
        ]

        transformed_opt_config = optimization_config.clone_with_args(
            outcome_constraints=new_outcome_constraints,
        )

        # Add viability metric to outcomes if not already present
        if viability_metric.name not in adapter.outcomes:
            adapter.outcomes.append(viability_metric.name)

        logger.debug(
            f"AddExecutionViability constraint active: "
            f"{abandoned_count} ABANDONED trials observed "
            f"(threshold: {self.min_abandoned_trials})"
        )

        return transformed_opt_config

0 commit comments

Comments
 (0)