Skip to content

Commit ec9ec5f

Browse files
mpolson64facebook-github-bot
authored andcommitted
Make ParameterConstraint store its own inequality string (#4658)
Summary: Move some of the string parsing from ax/api into core Ax to allow a ParameterConstraint to be defined and constructed from its inequality string. This will allow more flexibility in the future, but is primarily motivated by a want to store only the inequality string when saving ParameterConstraints and to eliminiate unnnecessary subclasses SumConstraint and OrderConstraint Differential Revision: D88880607
1 parent 1f75852 commit ec9ec5f

23 files changed

Lines changed: 254 additions & 179 deletions

ax/adapter/tests/test_torch_moo_adapter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
437437
)
438438
fixed_features = ObservationFeatures(parameters={"x1": 0.0})
439439
search_space = exp.search_space.clone()
440-
param_constraints = [
441-
ParameterConstraint(constraint_dict={"x1": 1.0}, bound=10.0)
442-
]
440+
param_constraints = [ParameterConstraint(inequality="x1 <= 10")]
443441
search_space.add_parameter_constraints(param_constraints)
444442
oc = none_throws(exp.optimization_config).clone()
445443
oc.objective._objectives[0].minimize = True

ax/adapter/transforms/tests/test_choice_encode_transform.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ def setUp(self) -> None:
6868
sort_values=False,
6969
),
7070
],
71-
parameter_constraints=[
72-
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5)
73-
],
71+
parameter_constraints=[ParameterConstraint(inequality="-0.5*x + a <= 0.5")],
7472
)
7573
self.t = self.t_class(search_space=self.search_space)
7674
input_params: TParameterization = {

ax/adapter/transforms/tests/test_one_hot_transform.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ def setUp(self) -> None:
4646
is_ordered=True,
4747
),
4848
],
49-
parameter_constraints=[
50-
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5)
51-
],
49+
parameter_constraints=[ParameterConstraint(inequality="-0.5*x + a <= 0.5")],
5250
)
5351
self.t = OneHot(search_space=self.search_space)
5452
self.t2 = OneHot(

ax/adapter/transforms/tests/test_unit_x_transform.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def setUp(self) -> None:
4646
),
4747
],
4848
parameter_constraints=[
49-
ParameterConstraint(constraint_dict={"x": -0.5, "y": 1}, bound=0.5),
50-
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5),
49+
ParameterConstraint(inequality="-0.5*x + y <= 0.5"),
50+
ParameterConstraint(inequality="-0.5*x + a <= 0.5"),
5151
],
5252
)
5353
self.t = UnitX(search_space=self.search_space)
@@ -157,8 +157,8 @@ def test_TransformNewSearchSpace(self) -> None:
157157
),
158158
],
159159
parameter_constraints=[
160-
ParameterConstraint(constraint_dict={"x": -0.5, "y": 1}, bound=0.5),
161-
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5),
160+
ParameterConstraint(inequality="-0.5*x + y <= 0.5"),
161+
ParameterConstraint(inequality="-0.5*x + a <= 0.5"),
162162
],
163163
)
164164
self.t.transform_search_space(new_ss)

ax/adapter/transforms/unit_x.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,14 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
9898
bound -= w * l
9999
else:
100100
constraint_dict[p_name] = w
101+
102+
expr = " + ".join(
103+
f"{coeff} * {param}" for param, coeff in constraint_dict.items()
104+
)
101105
new_constraints.append(
102-
ParameterConstraint(constraint_dict=constraint_dict, bound=bound)
106+
ParameterConstraint(
107+
inequality=f"{expr} <= {bound}",
108+
)
103109
)
104110
search_space.set_parameter_constraints(new_constraints)
105111
return search_space

ax/analysis/healthcheck/tests/test_search_space_analysis.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,7 @@ def test_search_space_boundary_proportions(self) -> None:
9797
),
9898
],
9999
parameter_constraints=[
100-
ParameterConstraint(
101-
constraint_dict={"float_range_1": 1.0, "float_range_2": 1.0},
102-
bound=4.0,
103-
)
100+
ParameterConstraint(inequality="float_range_1 + float_range_2 <= 4")
104101
],
105102
)
106103

ax/api/tests/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_configure_experiment(self) -> None:
111111
],
112112
parameter_constraints=[
113113
ParameterConstraint(
114-
constraint_dict={"int_param": 1, "float_param": -1}, bound=0
114+
inequality="int_param <= float_param",
115115
)
116116
],
117117
),

ax/api/utils/instantiation/from_string.py

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from collections.abc import Sequence
99

1010
from ax.core.map_metric import MapMetric
11-
1211
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
1312
from ax.core.optimization_config import (
1413
MultiObjectiveOptimizationConfig,
@@ -20,14 +19,13 @@
2019
OutcomeConstraint,
2120
ScalarizedOutcomeConstraint,
2221
)
23-
from ax.core.parameter_constraint import ParameterConstraint
2422
from ax.exceptions.core import UserInputError
2523
from ax.utils.common.string_utils import sanitize_name, unsanitize_name
24+
from ax.utils.common.sympy import extract_coefficient_dict_from_inequality
2625
from pyre_extensions import assert_is_instance, none_throws
2726
from sympy.core.add import Add
2827
from sympy.core.expr import Expr
2928
from sympy.core.mul import Mul
30-
from sympy.core.relational import GreaterThan, LessThan
3129
from sympy.core.symbol import Symbol
3230
from sympy.core.sympify import sympify
3331

@@ -95,35 +93,6 @@ def optimization_config_from_string(
9593
)
9694

9795

98-
def parse_parameter_constraint(constraint_str: str) -> ParameterConstraint:
99-
"""
100-
Parse a parameter constraint string into a ParameterConstraint object using SymPy.
101-
Currently only supports linear constraints of the form "a * x + b * y >= k" or
102-
"a * x + b * y <= k".
103-
"""
104-
coefficient_dict = _extract_coefficient_dict_from_inequality(
105-
inequality_str=constraint_str
106-
)
107-
108-
# Iterate through the coefficients to extract the parameter names and weights and
109-
# the bound
110-
constraint_dict = {}
111-
bound = 0
112-
for term, coefficient in coefficient_dict.items():
113-
if term.is_symbol:
114-
constraint_dict[unsanitize_name(term.name)] = coefficient
115-
elif term.is_number:
116-
# Invert because we are "moving" the bound to the right hand side
117-
bound = -1 * coefficient
118-
else:
119-
raise UserInputError(
120-
"Only linear inequality parameter constraints are supported, found "
121-
f"{constraint_str}"
122-
)
123-
124-
return ParameterConstraint(constraint_dict=constraint_dict, bound=bound)
125-
126-
12796
def parse_objective(objective_str: str) -> Objective:
12897
"""
12998
Parse an objective string into an Objective object using SymPy.
@@ -154,7 +123,7 @@ def parse_outcome_constraint(constraint_str: str) -> OutcomeConstraint:
154123
multiply your bound by "baseline". For example "qps >= 0.95 * baseline" will
155124
constrain such that the QPS is at least 95% of the baseline arm's QPS.
156125
"""
157-
coefficient_dict = _extract_coefficient_dict_from_inequality(
126+
coefficient_dict = extract_coefficient_dict_from_inequality(
158127
inequality_str=constraint_str
159128
)
160129

@@ -248,31 +217,3 @@ def _create_single_objective(expression: Expr) -> Objective:
248217
)
249218

250219
raise UserInputError(f"Only linear objectives are supported, found {expression}.")
251-
252-
253-
def _extract_coefficient_dict_from_inequality(
254-
inequality_str: str,
255-
) -> dict[Symbol, float]:
256-
"""
257-
Use SymPy to parse a string into an inequality, invert if necessary to enforce a
258-
less-than relationship, move all terms to the left side, and return the
259-
coefficients as a dictionary. This is useful for parsing parameter and outcome
260-
constraints.
261-
"""
262-
# Parse the constraint string into a SymPy inequality
263-
inequality = sympify(sanitize_name(inequality_str))
264-
265-
# Check the SymPy object is a valid inequality
266-
if not isinstance(inequality, GreaterThan | LessThan):
267-
raise UserInputError(f"Expected an inequality, found {inequality_str}")
268-
269-
# Move all terms to the left side of the inequality and invert if necessary to
270-
# enforce a less-than relationship
271-
if isinstance(inequality, LessThan):
272-
expression = inequality.lhs - inequality.rhs
273-
else:
274-
expression = inequality.rhs - inequality.lhs
275-
276-
return {
277-
key: float(value) for key, value in expression.as_coefficients_dict().items()
278-
}

ax/api/utils/instantiation/from_struct.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
# pyre-strict
77

88
from ax.api.utils.instantiation.from_config import parameter_from_config
9-
from ax.api.utils.instantiation.from_string import parse_parameter_constraint
109
from ax.api.utils.structs import ExperimentStruct
1110
from ax.core.evaluations_to_data import DataType
1211
from ax.core.experiment import Experiment
13-
from ax.core.parameter_constraint import validate_constraint_parameters
12+
from ax.core.parameter_constraint import (
13+
ParameterConstraint,
14+
validate_constraint_parameters,
15+
)
1416
from ax.core.search_space import SearchSpace
1517

1618

@@ -22,8 +24,8 @@ def experiment_from_struct(struct: ExperimentStruct) -> Experiment:
2224
]
2325

2426
constraints = [
25-
parse_parameter_constraint(constraint_str=constraint_str)
26-
for constraint_str in struct.parameter_constraints
27+
ParameterConstraint(inequality=inequality)
28+
for inequality in struct.parameter_constraints
2729
]
2830

2931
# Ensure that all ParameterConstraints are valid and acting on existing parameters

ax/api/utils/instantiation/tests/test_from_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def test_experiment_from_config(self) -> None:
278278
],
279279
parameter_constraints=[
280280
ParameterConstraint(
281-
constraint_dict={"int_param": 1, "float_param": -1}, bound=0
281+
inequality="int_param <= float_param",
282282
)
283283
],
284284
),
@@ -340,7 +340,7 @@ def test_experiment_from_config(self) -> None:
340340
],
341341
parameter_constraints=[
342342
ParameterConstraint(
343-
constraint_dict={"int_param": 1, "float_param": -1}, bound=0
343+
inequality="int_param <= float_param",
344344
)
345345
],
346346
),

0 commit comments

Comments
 (0)