Skip to content

Commit 29dcd19

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Replace OrderConstraint with ParameterConstraint (#4659)
Summary: Pull Request resolved: #4659 Any OrderConstraint(a, b) can be rewritten ParameterConstraint("a <= b") which is both more clear and removes polymorphism ahead of storage rework Reviewed By: saitcakmak Differential Revision: D88890559
1 parent 001cbc7 commit 29dcd19

18 files changed

Lines changed: 63 additions & 286 deletions

ax/adapter/tests/test_random_adapter.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@
1818
from ax.core.metric import Metric
1919
from ax.core.observation import ObservationFeatures
2020
from ax.core.parameter import ParameterType, RangeParameter
21-
from ax.core.parameter_constraint import (
22-
OrderConstraint,
23-
ParameterConstraint,
24-
SumConstraint,
25-
)
21+
from ax.core.parameter_constraint import ParameterConstraint, SumConstraint
2622
from ax.core.search_space import SearchSpace
2723
from ax.exceptions.core import SearchSpaceExhausted
2824
from ax.generators.random.base import RandomGenerator
@@ -43,7 +39,7 @@ def setUp(self) -> None:
4339
z = RangeParameter("z", ParameterType.FLOAT, lower=0, upper=5)
4440
self.parameters = [x, y, z]
4541
parameter_constraints: list[ParameterConstraint] = [
46-
OrderConstraint(x, y),
42+
ParameterConstraint(inequality="x <= y"),
4743
SumConstraint([x, z], False, 3.5),
4844
]
4945
self.search_space = SearchSpace(self.parameters, parameter_constraints)

ax/adapter/transforms/rounding.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import numpy as np
1414
import numpy.typing as npt
15-
from ax.core.parameter_constraint import OrderConstraint
1615
from ax.core.search_space import SearchSpace
1716
from ax.core.types import TParameterization
1817

@@ -56,13 +55,8 @@ def strict_onehot_round(x: npt.NDArray) -> npt.NDArray:
5655
def contains_constrained_integer(
5756
search_space: SearchSpace, transform_parameters: set[str]
5857
) -> bool:
59-
"""Check if any integer parameters are present in parameter_constraints.
60-
61-
Order constraints are ignored since strict rounding preserves ordering.
62-
"""
58+
"""Check if any integer parameters are present in parameter_constraints."""
6359
for constraint in search_space.parameter_constraints:
64-
if isinstance(constraint, OrderConstraint):
65-
continue
6660
constraint_params = set(constraint.constraint_dict.keys())
6761
if constraint_params.intersection(transform_parameters):
6862
return True

ax/adapter/transforms/tests/test_fixed_to_tunable.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
ParameterType,
1919
RangeParameter,
2020
)
21-
from ax.core.parameter_constraint import OrderConstraint
21+
from ax.core.parameter_constraint import ParameterConstraint
2222
from ax.core.search_space import SearchSpace
2323
from ax.utils.common.testutils import TestCase
2424
from ax.utils.testing.core_stubs import get_experiment_with_observations
@@ -120,11 +120,7 @@ def test_transform_search_space_with_constraints(self) -> None:
120120
]
121121
search_space_with_constraints = SearchSpace(
122122
parameters=parameters,
123-
parameter_constraints=[
124-
OrderConstraint(
125-
lower_parameter=parameters[0], upper_parameter=parameters[1]
126-
)
127-
],
123+
parameter_constraints=[ParameterConstraint(inequality="x <= y")],
128124
)
129125

130126
# Joint space with range parameter for 'y'

ax/adapter/transforms/tests/test_int_to_float_transform.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ax.adapter.transforms.int_to_float import IntToFloat
1616
from ax.core.observation import ObservationFeatures
1717
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
18-
from ax.core.parameter_constraint import OrderConstraint, SumConstraint
18+
from ax.core.parameter_constraint import ParameterConstraint, SumConstraint
1919
from ax.core.search_space import SearchSpace
2020
from ax.utils.common.testutils import TestCase
2121
from ax.utils.testing.core_stubs import get_experiment_with_observations
@@ -36,11 +36,7 @@ def setUp(self) -> None:
3636
]
3737
self.search_space = SearchSpace(
3838
parameters=parameters,
39-
parameter_constraints=[
40-
OrderConstraint(
41-
lower_parameter=parameters[0], upper_parameter=parameters[1]
42-
)
43-
],
39+
parameter_constraints=[ParameterConstraint(inequality="x <= a")],
4440
)
4541
self.t = IntToFloat(search_space=self.search_space)
4642
self.t2 = IntToFloat(

ax/core/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,7 @@
3131
ParameterType,
3232
RangeParameter,
3333
)
34-
from ax.core.parameter_constraint import (
35-
OrderConstraint,
36-
ParameterConstraint,
37-
SumConstraint,
38-
)
34+
from ax.core.parameter_constraint import ParameterConstraint, SumConstraint
3935
from ax.core.runner import Runner
4036
from ax.core.search_space import SearchSpace
4137
from ax.core.trial import Trial
@@ -57,7 +53,6 @@
5753
"ObjectiveThreshold",
5854
"ObservationFeatures",
5955
"OptimizationConfig",
60-
"OrderConstraint",
6156
"OutcomeConstraint",
6257
"Parameter",
6358
"ParameterConstraint",

ax/core/parameter_constraint.py

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -114,70 +114,6 @@ def _unique_id(self) -> str:
114114
return str(self)
115115

116116

117-
class OrderConstraint(ParameterConstraint):
118-
"""Constraint object for specifying one parameter to be smaller than another."""
119-
120-
_bound: float
121-
122-
def __init__(self, lower_parameter: Parameter, upper_parameter: Parameter) -> None:
123-
"""Initialize OrderConstraint
124-
125-
Args:
126-
lower_parameter: Parameter that should have the lower value.
127-
upper_parameter: Parameter that should have the higher value.
128-
129-
Note:
130-
The constraint p1 <= p2 can be expressed in matrix notation as
131-
[1, -1] * [p1, p2]^T <= 0.
132-
"""
133-
validate_constraint_parameters([lower_parameter, upper_parameter])
134-
135-
self._lower_parameter = lower_parameter
136-
self._upper_parameter = upper_parameter
137-
self._bound = 0.0
138-
139-
@property
140-
def lower_parameter(self) -> Parameter:
141-
"""Parameter with lower value."""
142-
return self._lower_parameter
143-
144-
@property
145-
def upper_parameter(self) -> Parameter:
146-
"""Parameter with higher value."""
147-
return self._upper_parameter
148-
149-
@property
150-
def parameters(self) -> list[Parameter]:
151-
"""Parameters."""
152-
return [self.lower_parameter, self.upper_parameter]
153-
154-
@property
155-
def constraint_dict(self) -> dict[str, float]:
156-
"""Weights on parameters for linear constraint representation."""
157-
return {self.lower_parameter.name: 1.0, self.upper_parameter.name: -1.0}
158-
159-
def clone(self) -> OrderConstraint:
160-
"""Clone."""
161-
return OrderConstraint(
162-
lower_parameter=self.lower_parameter.clone(),
163-
upper_parameter=self._upper_parameter.clone(),
164-
)
165-
166-
def clone_with_transformed_parameters(
167-
self, transformed_parameters: dict[str, Parameter]
168-
) -> OrderConstraint:
169-
"""Clone, but replace parameters with transformed versions."""
170-
return OrderConstraint(
171-
lower_parameter=transformed_parameters[self.lower_parameter.name],
172-
upper_parameter=transformed_parameters[self._upper_parameter.name],
173-
)
174-
175-
def __repr__(self) -> str:
176-
return "OrderConstraint({} <= {})".format(
177-
self.lower_parameter.name, self.upper_parameter.name
178-
)
179-
180-
181117
class SumConstraint(ParameterConstraint):
182118
"""Constraint on the sum of parameters being greater or less than a bound."""
183119

@@ -281,6 +217,7 @@ def validate_constraint_parameters(parameters: Sequence[Parameter]) -> None:
281217
if not isinstance(parameter, RangeParameter):
282218
raise ValueError(
283219
"All parameters in a parameter constraint must be RangeParameters."
220+
f"Found {parameter}"
284221
)
285222

286223
# Log parameters require a non-linear transformation, and Ax

ax/core/search_space.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
TParamValue,
2828
)
2929
from ax.core.parameter_constraint import (
30-
OrderConstraint,
3130
ParameterConstraint,
3231
SumConstraint,
32+
validate_constraint_parameters,
3333
)
3434
from ax.core.types import TParameterization
3535
from ax.exceptions.core import AxWarning, UnsupportedError, UserInputError
@@ -194,17 +194,16 @@ def set_parameter_constraints(
194194
# the matching name among the search space's parameters, so we
195195
# are not keeping two copies of the same parameter.
196196
for constraint in parameter_constraints:
197-
if isinstance(constraint, OrderConstraint):
198-
constraint._lower_parameter = self.parameters[
199-
constraint._lower_parameter.name
200-
]
201-
constraint._upper_parameter = self.parameters[
202-
constraint._upper_parameter.name
203-
]
204-
elif isinstance(constraint, SumConstraint):
197+
if isinstance(constraint, SumConstraint):
205198
for idx, parameter in enumerate(constraint.parameters):
206199
constraint.parameters[idx] = self.parameters[parameter.name]
207200

201+
validate_constraint_parameters(
202+
parameters=[
203+
self._parameters[name] for name in constraint.constraint_dict.keys()
204+
]
205+
)
206+
208207
self._parameter_constraints: list[ParameterConstraint] = parameter_constraints
209208

210209
def add_parameters(
@@ -534,9 +533,7 @@ def _validate_parameter_constraints(
534533
self, parameter_constraints: list[ParameterConstraint]
535534
) -> None:
536535
for constraint in parameter_constraints:
537-
if isinstance(constraint, OrderConstraint) or isinstance(
538-
constraint, SumConstraint
539-
):
536+
if isinstance(constraint, SumConstraint):
540537
for parameter in constraint.parameters:
541538
if parameter.name not in self._parameters.keys():
542539
raise ValueError(

ax/core/tests/test_parameter_constraint.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,9 @@
66

77
# pyre-strict
88

9-
from ax.core.parameter import (
10-
ChoiceParameter,
11-
FixedParameter,
12-
ParameterType,
13-
RangeParameter,
14-
)
9+
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
1510
from ax.core.parameter_constraint import (
1611
ComparisonOp,
17-
OrderConstraint,
1812
ParameterConstraint,
1913
SumConstraint,
2014
)
@@ -135,62 +129,6 @@ def test_Sortable(self) -> None:
135129
self.assertTrue(constraint1 < constraint2)
136130

137131

138-
class OrderConstraintTest(TestCase):
139-
def setUp(self) -> None:
140-
super().setUp()
141-
self.x = RangeParameter("x", ParameterType.INT, lower=0, upper=1)
142-
self.y = RangeParameter("y", ParameterType.INT, lower=0, upper=1)
143-
self.constraint = OrderConstraint(
144-
lower_parameter=self.x, upper_parameter=self.y
145-
)
146-
self.constraint_repr = "OrderConstraint(x <= y)"
147-
148-
def test_Properties(self) -> None:
149-
self.assertEqual(self.constraint.lower_parameter.name, "x")
150-
self.assertEqual(self.constraint.upper_parameter.name, "y")
151-
152-
def test_Repr(self) -> None:
153-
self.assertEqual(str(self.constraint), self.constraint_repr)
154-
155-
def test_Validate(self) -> None:
156-
self.assertTrue(self.constraint.check({"x": 0, "y": 1}))
157-
self.assertTrue(self.constraint.check({"x": 1, "y": 1}))
158-
self.assertFalse(self.constraint.check({"x": 1, "y": 0}))
159-
160-
def test_Clone(self) -> None:
161-
constraint_clone = self.constraint.clone()
162-
self.assertEqual(
163-
self.constraint.lower_parameter, constraint_clone.lower_parameter
164-
)
165-
166-
constraint_clone._lower_parameter = self.y
167-
self.assertNotEqual(
168-
self.constraint.lower_parameter, constraint_clone.lower_parameter
169-
)
170-
171-
def test_CloneWithTransformedParameters(self) -> None:
172-
constraint_clone = self.constraint.clone_with_transformed_parameters(
173-
transformed_parameters={p.name: p for p in self.constraint.parameters}
174-
)
175-
self.assertEqual(
176-
self.constraint.lower_parameter, constraint_clone.lower_parameter
177-
)
178-
179-
constraint_clone._lower_parameter = self.y
180-
self.assertNotEqual(
181-
self.constraint.lower_parameter, constraint_clone.lower_parameter
182-
)
183-
184-
def test_InvalidSetup(self) -> None:
185-
z = FixedParameter("z", ParameterType.INT, 0)
186-
with self.assertRaises(ValueError):
187-
self.constraint = OrderConstraint(lower_parameter=self.x, upper_parameter=z)
188-
189-
z = ChoiceParameter("z", ParameterType.STRING, ["a", "b", "c"])
190-
with self.assertRaises(ValueError):
191-
self.constraint = OrderConstraint(lower_parameter=self.x, upper_parameter=z)
192-
193-
194132
class SumConstraintTest(TestCase):
195133
def setUp(self) -> None:
196134
super().setUp()

0 commit comments

Comments
 (0)