Skip to content

Commit f9a9fd6

Browse files
Jelena Markovic-Voronovfacebook-github-bot
Jelena Markovic-Voronov
authored andcommitted
constraints feasibilty via GPs (#3152)
Summary: Pull Request resolved: #3152 Warn users if their constraints aren't satisfied above the given threshold for any of the arms. The constraints feasibility is computed using the GP model fit and the user provided constraint bounds. Reviewed By: danielcohenlive, Balandat Differential Revision: D66398437 fbshipit-source-id: 4dc59b6fbf296b1a659fcb951e0730a1a8184320
1 parent d709c5d commit f9a9fd6

File tree

4 files changed

+426
-0
lines changed

4 files changed

+426
-0
lines changed

ax/analysis/healthcheck/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from ax.analysis.healthcheck.can_generate_candidates import (
99
CanGenerateCandidatesAnalysis,
1010
)
11+
12+
from ax.analysis.healthcheck.constraints_feasibility import (
13+
ConstraintsFeasibilityAnalysis,
14+
)
1115
from ax.analysis.healthcheck.healthcheck_analysis import (
1216
HealthcheckAnalysis,
1317
HealthcheckAnalysisCard,
@@ -16,6 +20,7 @@
1620
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
1721

1822
__all__ = [
23+
"ConstraintsFeasibilityAnalysis",
1924
"CanGenerateCandidatesAnalysis",
2025
"HealthcheckAnalysis",
2126
"HealthcheckAnalysisCard",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
import json
9+
from typing import Tuple
10+
11+
import pandas as pd
12+
13+
from ax.analysis.analysis import AnalysisCardLevel
14+
15+
from ax.analysis.healthcheck.healthcheck_analysis import (
16+
HealthcheckAnalysis,
17+
HealthcheckAnalysisCard,
18+
HealthcheckStatus,
19+
)
20+
from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm
21+
from ax.analysis.plotly.utils import is_predictive
22+
from ax.core.experiment import Experiment
23+
from ax.core.generation_strategy_interface import GenerationStrategyInterface
24+
from ax.core.optimization_config import OptimizationConfig
25+
from ax.exceptions.core import UserInputError
26+
from ax.modelbridge.base import ModelBridge
27+
from ax.modelbridge.generation_strategy import GenerationStrategy
28+
from ax.modelbridge.transforms.derelativize import Derelativize
29+
from ax.utils.common.typeutils import checked_cast
30+
from pyre_extensions import none_throws
31+
32+
33+
class ConstraintsFeasibilityAnalysis(HealthcheckAnalysis):
34+
"""
35+
Analysis for checking the feasibility of the constraints for the experiment.
36+
A constraint is considered feasible if the probability of constraints violation
37+
is below the threshold for at least one arm.
38+
"""
39+
40+
def compute(
41+
self,
42+
experiment: Experiment | None = None,
43+
generation_strategy: GenerationStrategyInterface | None = None,
44+
prob_threshold: float = 0.90,
45+
) -> HealthcheckAnalysisCard:
46+
r"""
47+
Compute the feasibility of the constraints for the experiment.
48+
49+
Args:
50+
experiment: Ax experiment.
51+
generation_strategy: Ax generation strategy.
52+
prob_threhshold: Threshold for the probability of constraint violation.
53+
Constraints are considered feasible if the probability of constraint
54+
violation is below the threshold for at least one arm.
55+
56+
Returns:
57+
A HealthcheckAnalysisCard object with the information on infeasible metrics,
58+
i.e., metrics for which the constraints are infeasible for all test groups
59+
(arms).
60+
"""
61+
status = HealthcheckStatus.PASS
62+
subtitle = "All constraints are feasible."
63+
title_status = "Success"
64+
level = AnalysisCardLevel.LOW
65+
df = pd.DataFrame({"status": [status]})
66+
67+
if experiment is None:
68+
raise UserInputError(
69+
"ConstraintsFeasibilityAnalysis requires an Experiment."
70+
)
71+
72+
if experiment.optimization_config is None:
73+
raise UserInputError(
74+
"ConstraintsFeasibilityAnalysis requires an Experiment with an "
75+
"optimization config."
76+
)
77+
78+
if (
79+
experiment.optimization_config.outcome_constraints is None
80+
or len(experiment.optimization_config.outcome_constraints) == 0
81+
):
82+
subtitle = "No constraints are specified."
83+
return HealthcheckAnalysisCard(
84+
name="ConstraintsFeasibility",
85+
title=f"Ax Constraints Feasibility {title_status}",
86+
blob=json.dumps({"status": status}),
87+
subtitle=subtitle,
88+
df=df,
89+
level=level,
90+
)
91+
92+
if generation_strategy is None:
93+
raise UserInputError(
94+
"ConstraintsFeasibilityAnalysis requires a GenerationStrategy."
95+
)
96+
generation_strategy = checked_cast(
97+
GenerationStrategy,
98+
generation_strategy,
99+
exception=UserInputError(
100+
"ConstraintsFeasibilityAnalysis requires a GenerationStrategy."
101+
),
102+
)
103+
104+
if generation_strategy.model is None:
105+
generation_strategy._fit_current_model(data=experiment.lookup_data())
106+
107+
model = none_throws(generation_strategy.model)
108+
if not is_predictive(model=model):
109+
raise UserInputError(
110+
"ConstraintsFeasibility requires a GenerationStrategy which is "
111+
"in a state where the current model supports prediction. "
112+
"The current model is {model._model_key} and does not support "
113+
"prediction."
114+
)
115+
optimization_config = checked_cast(
116+
OptimizationConfig, experiment.optimization_config
117+
)
118+
constraints_feasible, df = constraints_feasibility(
119+
optimization_config=optimization_config,
120+
model=model,
121+
prob_threshold=prob_threshold,
122+
)
123+
df["status"] = status
124+
125+
if not constraints_feasible:
126+
status = HealthcheckStatus.WARNING
127+
subtitle = (
128+
"Constraints are infeasible for all test groups (arms) with respect "
129+
f"to the probability threshold {prob_threshold}. "
130+
"We suggest relaxing the constraint bounds for the constraints."
131+
)
132+
title_status = "Warning"
133+
df.loc[
134+
df["overall_probability_constraints_violated"] > prob_threshold,
135+
"status",
136+
] = status
137+
138+
return HealthcheckAnalysisCard(
139+
name="ConstraintsFeasibility",
140+
title=f"Ax Constraints Feasibility {title_status}",
141+
blob=json.dumps({"status": status}),
142+
subtitle=subtitle,
143+
df=df,
144+
level=level,
145+
)
146+
147+
148+
def constraints_feasibility(
149+
optimization_config: OptimizationConfig,
150+
model: ModelBridge,
151+
prob_threshold: float = 0.99,
152+
) -> Tuple[bool, pd.DataFrame]:
153+
r"""
154+
Check the feasibility of the constraints for the experiment.
155+
156+
Args:
157+
optimization_config: Ax optimization config.
158+
model: Ax model to use for predictions.
159+
prob_threshold: Threshold for the probability of constraint violation.
160+
161+
Returns:
162+
A tuple of a boolean indicating whether the constraints are feasible and a
163+
dataframe with information on the probabilities of constraints violation for
164+
each arm.
165+
"""
166+
if (optimization_config.outcome_constraints is None) or (
167+
len(optimization_config.outcome_constraints) == 0
168+
):
169+
raise UserInputError("No constraints are specified.")
170+
171+
derel_optimization_config = optimization_config
172+
outcome_constraints = optimization_config.outcome_constraints
173+
174+
if any(constraint.relative for constraint in outcome_constraints):
175+
derel_optimization_config = Derelativize().transform_optimization_config(
176+
optimization_config=optimization_config,
177+
modelbridge=model,
178+
)
179+
180+
constraint_metric_name = [
181+
constraint.metric.name
182+
for constraint in derel_optimization_config.outcome_constraints
183+
][0]
184+
185+
arm_dict = get_predictions_by_arm(
186+
model=model,
187+
metric_name=constraint_metric_name,
188+
outcome_constraints=derel_optimization_config.outcome_constraints,
189+
)
190+
191+
df = pd.DataFrame(arm_dict)
192+
constraints_feasible = True
193+
if all(
194+
arm_info["overall_probability_constraints_violated"] > prob_threshold
195+
for arm_info in arm_dict
196+
if arm_info["arm_name"] != model.status_quo_name
197+
):
198+
constraints_feasible = False
199+
200+
return constraints_feasible, df

0 commit comments

Comments
 (0)