Skip to content

Commit 20b2a4e

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Replace SumConstraint with ParameterConstraint (#4660)
Summary: Pull Request resolved: #4660 Same as D88890559 but for SumConstraint Reviewed By: saitcakmak Differential Revision: D88897672
1 parent 88596a6 commit 20b2a4e

14 files changed

Lines changed: 59 additions & 320 deletions

File tree

ax/adapter/tests/test_base_adapter.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ax.core.optimization_config import OptimizationConfig
4242
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
4343
from ax.core.parameter import ParameterType, RangeParameter
44-
from ax.core.parameter_constraint import SumConstraint
44+
from ax.core.parameter_constraint import ParameterConstraint
4545
from ax.core.search_space import SearchSpace
4646
from ax.core.types import TParameterization
4747
from ax.core.utils import get_target_trial_index
@@ -900,15 +900,7 @@ def test_set_model_space(self) -> None:
900900
trial.mark_completed()
901901
# Make search space with a parameter constraint
902902
ss = experiment.search_space.clone()
903-
ss.set_parameter_constraints(
904-
[
905-
SumConstraint(
906-
parameters=list(ss.parameters.values()),
907-
is_upper_bound=True,
908-
bound=30.0,
909-
)
910-
]
911-
)
903+
ss.set_parameter_constraints([ParameterConstraint(inequality="x1 + x2 <= 30")])
912904

913905
# Check that SQ and custom are OOD
914906
m = Adapter(

ax/adapter/tests/test_random_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ax.core.metric import Metric
1919
from ax.core.observation import ObservationFeatures
2020
from ax.core.parameter import ParameterType, RangeParameter
21-
from ax.core.parameter_constraint import ParameterConstraint, SumConstraint
21+
from ax.core.parameter_constraint import ParameterConstraint
2222
from ax.core.search_space import SearchSpace
2323
from ax.exceptions.core import SearchSpaceExhausted
2424
from ax.generators.random.base import RandomGenerator
@@ -40,7 +40,7 @@ def setUp(self) -> None:
4040
self.parameters = [x, y, z]
4141
parameter_constraints: list[ParameterConstraint] = [
4242
ParameterConstraint(inequality="x <= y"),
43-
SumConstraint([x, z], False, 3.5),
43+
ParameterConstraint(inequality="x + z >= 3.5"),
4444
]
4545
self.search_space = SearchSpace(self.parameters, parameter_constraints)
4646
self.experiment = Experiment(search_space=self.search_space)

ax/adapter/transforms/tests/test_int_to_float_transform.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ax.adapter.transforms.int_to_float import IntToFloat
1616
from ax.core.observation import ObservationFeatures
1717
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
18-
from ax.core.parameter_constraint import ParameterConstraint, SumConstraint
18+
from ax.core.parameter_constraint import ParameterConstraint
1919
from ax.core.search_space import SearchSpace
2020
from ax.utils.common.testutils import TestCase
2121
from ax.utils.testing.core_stubs import get_experiment_with_observations
@@ -239,11 +239,7 @@ def test_RoundingWithConstrainedIntRanges(self) -> None:
239239
]
240240
constrained_int_search_space = SearchSpace(
241241
parameters=parameters,
242-
parameter_constraints=[
243-
# pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
244-
# `List[RangeParameter]`.
245-
SumConstraint(parameters=parameters, is_upper_bound=True, bound=5)
246-
],
242+
parameter_constraints=[ParameterConstraint(inequality="x + y <= 5")],
247243
)
248244
t = IntToFloat(search_space=constrained_int_search_space)
249245
self.assertEqual(t.rounding, "randomized")
@@ -288,11 +284,7 @@ def test_RoundingWithImpossiblyConstrainedIntRanges(self) -> None:
288284
]
289285
constrained_int_search_space = SearchSpace(
290286
parameters=parameters,
291-
parameter_constraints=[
292-
# pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
293-
# `List[RangeParameter]`.
294-
SumConstraint(parameters=parameters, is_upper_bound=True, bound=3)
295-
],
287+
parameter_constraints=[ParameterConstraint(inequality="x + y <= 3")],
296288
)
297289
t = IntToFloat(search_space=constrained_int_search_space)
298290
self.assertEqual(t.rounding, "randomized")

ax/core/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
ParameterType,
3232
RangeParameter,
3333
)
34-
from ax.core.parameter_constraint import ParameterConstraint, SumConstraint
34+
from ax.core.parameter_constraint import ParameterConstraint
3535
from ax.core.runner import Runner
3636
from ax.core.search_space import SearchSpace
3737
from ax.core.trial import Trial
@@ -61,6 +61,5 @@
6161
"Runner",
6262
"SearchSpace",
6363
"SimpleExperiment",
64-
"SumConstraint",
6564
"Trial",
6665
]

ax/core/parameter_constraint.py

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from typing import Sequence
1212

1313
from ax.core.parameter import Parameter, RangeParameter
14-
from ax.core.types import ComparisonOp
1514
from ax.exceptions.core import UserInputError
1615
from ax.utils.common.base import SortableBase
1716
from ax.utils.common.string_utils import unsanitize_name
@@ -114,92 +113,6 @@ def _unique_id(self) -> str:
114113
return str(self)
115114

116115

117-
class SumConstraint(ParameterConstraint):
118-
"""Constraint on the sum of parameters being greater or less than a bound."""
119-
120-
def __init__(
121-
self, parameters: list[Parameter], is_upper_bound: bool, bound: float
122-
) -> None:
123-
"""Initialize SumConstraint
124-
125-
Args:
126-
parameters: List of parameters whose sum to constrain on.
127-
is_upper_bound: Whether the bound is an upper or lower bound on the sum.
128-
bound: The bound on the sum.
129-
"""
130-
validate_constraint_parameters(parameters)
131-
132-
self._parameters = parameters
133-
self._is_upper_bound: bool = is_upper_bound
134-
self._parameter_names: list[str] = [parameter.name for parameter in parameters]
135-
self._bound: float = self._inequality_weight * bound
136-
self._constraint_dict: dict[str, float] = {
137-
name: self._inequality_weight for name in self._parameter_names
138-
}
139-
140-
@property
141-
def parameters(self) -> list[Parameter]:
142-
"""Parameters."""
143-
return self._parameters
144-
145-
@property
146-
def constraint_dict(self) -> dict[str, float]:
147-
"""Weights on parameters for linear constraint representation."""
148-
return self._constraint_dict
149-
150-
@property
151-
def op(self) -> ComparisonOp:
152-
"""Whether the sum is constrained by a <= or >= inequality."""
153-
return ComparisonOp.LEQ if self._is_upper_bound else ComparisonOp.GEQ
154-
155-
@property
156-
def is_upper_bound(self) -> bool:
157-
"""Whether the bound is an upper or lower bound on the sum."""
158-
return self._is_upper_bound
159-
160-
def clone(self) -> SumConstraint:
161-
"""Clone.
162-
163-
To use the same constraint, we need to reconstruct the original bound.
164-
We do this by re-applying the original bound weighting.
165-
"""
166-
return SumConstraint(
167-
parameters=[p.clone() for p in self._parameters],
168-
is_upper_bound=self._is_upper_bound,
169-
bound=self._inequality_weight * self._bound,
170-
)
171-
172-
def clone_with_transformed_parameters(
173-
self, transformed_parameters: dict[str, Parameter]
174-
) -> SumConstraint:
175-
"""Clone, but replace parameters with transformed versions."""
176-
return SumConstraint(
177-
parameters=[transformed_parameters[p.name] for p in self._parameters],
178-
is_upper_bound=self._is_upper_bound,
179-
bound=self._inequality_weight * self._bound,
180-
)
181-
182-
@property
183-
def _inequality_weight(self) -> float:
184-
"""Multiplier of all terms in the inequality.
185-
186-
If the constraint is an upper bound, it is v1 + v2 ... v_n <= b
187-
If the constraint is an lower bound, it is -v1 + -v2 ... -v_n <= -b
188-
This property returns 1 or -1 depending on the scenario
189-
"""
190-
return 1.0 if self._is_upper_bound else -1.0
191-
192-
def __repr__(self) -> str:
193-
symbol = ">=" if self.op == ComparisonOp.GEQ else "<="
194-
return (
195-
"SumConstraint("
196-
+ " + ".join(self._parameter_names)
197-
+ " {} {})".format(
198-
symbol, self._bound if self.op == ComparisonOp.LEQ else -self._bound
199-
)
200-
)
201-
202-
203116
def validate_constraint_parameters(parameters: Sequence[Parameter]) -> None:
204117
"""Basic validation of parameters used in a constraint.
205118

ax/core/search_space.py

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
)
2929
from ax.core.parameter_constraint import (
3030
ParameterConstraint,
31-
SumConstraint,
3231
validate_constraint_parameters,
3332
)
3433
from ax.core.types import TParameterization
@@ -190,22 +189,15 @@ def set_parameter_constraints(
190189
# Validate that all parameters in constraints are in search
191190
# space already.
192191
self._validate_parameter_constraints(parameter_constraints)
193-
# Set the parameter on the constraint to be the parameter by
194-
# the matching name among the search space's parameters, so we
195-
# are not keeping two copies of the same parameter.
196-
for constraint in parameter_constraints:
197-
if isinstance(constraint, SumConstraint):
198-
for idx, parameter in enumerate(constraint.parameters):
199-
constraint.parameters[idx] = self.parameters[parameter.name]
192+
self._parameter_constraints: list[ParameterConstraint] = parameter_constraints
200193

194+
for constraint in self.parameter_constraints:
201195
validate_constraint_parameters(
202196
parameters=[
203197
self._parameters[name] for name in constraint.constraint_dict.keys()
204198
]
205199
)
206200

207-
self._parameter_constraints: list[ParameterConstraint] = parameter_constraints
208-
209201
def add_parameters(
210202
self,
211203
parameters: Sequence[Parameter],
@@ -533,29 +525,17 @@ def _validate_parameter_constraints(
533525
self, parameter_constraints: list[ParameterConstraint]
534526
) -> None:
535527
for constraint in parameter_constraints:
536-
if isinstance(constraint, SumConstraint):
537-
for parameter in constraint.parameters:
538-
if parameter.name not in self._parameters.keys():
539-
raise ValueError(
540-
f"`{parameter.name}` does not exist in search space."
541-
)
542-
if parameter != self._parameters[parameter.name]:
543-
raise ValueError(
544-
f"Parameter constraint's definition of '{parameter.name}' "
545-
"does not match the SearchSpace's definition"
546-
)
547-
else:
548-
for parameter_name in constraint.constraint_dict.keys():
549-
p = self._parameters.get(parameter_name)
550-
if p is None:
551-
raise ValueError(
552-
f"`{parameter_name}` does not exist in search space."
553-
)
554-
elif isinstance(p, DerivedParameter):
555-
raise ValueError(
556-
"Parameter constraints cannot be used with derived "
557-
"parameters."
558-
)
528+
for parameter_name in constraint.constraint_dict.keys():
529+
p = self._parameters.get(parameter_name)
530+
if p is None:
531+
raise ValueError(
532+
f"`{parameter_name}` does not exist in search space."
533+
)
534+
elif isinstance(p, DerivedParameter):
535+
raise ValueError(
536+
"Parameter constraints cannot be used with derived "
537+
"parameters."
538+
)
559539

560540
def _validate_hierarchical_structure(self) -> None:
561541
"""Validate the structure of this hierarchical search space, ensuring that all

ax/core/tests/test_parameter_constraint.py

Lines changed: 1 addition & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,7 @@
66

77
# pyre-strict
88

9-
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
10-
from ax.core.parameter_constraint import (
11-
ComparisonOp,
12-
ParameterConstraint,
13-
SumConstraint,
14-
)
9+
from ax.core.parameter_constraint import ParameterConstraint
1510
from ax.exceptions.core import UserInputError
1611
from ax.utils.common.testutils import TestCase
1712

@@ -127,68 +122,3 @@ def test_Sortable(self) -> None:
127122
inequality="2 * x - 3 * y <= 6.0",
128123
)
129124
self.assertTrue(constraint1 < constraint2)
130-
131-
132-
class SumConstraintTest(TestCase):
133-
def setUp(self) -> None:
134-
super().setUp()
135-
self.x = RangeParameter("x", ParameterType.INT, lower=-5, upper=5)
136-
self.y = RangeParameter("y", ParameterType.INT, lower=-5, upper=5)
137-
self.constraint1 = SumConstraint(
138-
parameters=[self.x, self.y], is_upper_bound=True, bound=5
139-
)
140-
self.constraint2 = SumConstraint(
141-
parameters=[self.x, self.y], is_upper_bound=False, bound=-5
142-
)
143-
144-
self.constraint_repr1 = "SumConstraint(x + y <= 5.0)"
145-
self.constraint_repr2 = "SumConstraint(x + y >= -5.0)"
146-
147-
def test_BadConstruct(self) -> None:
148-
with self.assertRaises(ValueError):
149-
SumConstraint(parameters=[self.x, self.x], is_upper_bound=False, bound=-5.0)
150-
z = ChoiceParameter("z", ParameterType.STRING, ["a", "b", "c"])
151-
with self.assertRaises(ValueError):
152-
# pyre-fixme[16]: `SumConstraintTest` has no attribute `constraint`.
153-
self.constraint = SumConstraint(
154-
parameters=[self.x, z], is_upper_bound=False, bound=-5.0
155-
)
156-
157-
def test_Properties(self) -> None:
158-
self.assertEqual(self.constraint1.op, ComparisonOp.LEQ)
159-
self.assertTrue(self.constraint1._is_upper_bound)
160-
161-
self.assertEqual(self.constraint2.op, ComparisonOp.GEQ)
162-
self.assertFalse(self.constraint2._is_upper_bound)
163-
164-
def test_Repr(self) -> None:
165-
self.assertEqual(str(self.constraint1), self.constraint_repr1)
166-
self.assertEqual(str(self.constraint2), self.constraint_repr2)
167-
168-
def test_Validate(self) -> None:
169-
self.assertTrue(self.constraint1.check({"x": 1, "y": 4}))
170-
self.assertTrue(self.constraint1.check({"x": 4, "y": 1}))
171-
self.assertFalse(self.constraint1.check({"x": 1, "y": 5}))
172-
173-
self.assertTrue(self.constraint2.check({"x": -4, "y": -1}))
174-
self.assertTrue(self.constraint2.check({"x": -1, "y": -4}))
175-
self.assertFalse(self.constraint2.check({"x": -5, "y": -1}))
176-
177-
def test_Clone(self) -> None:
178-
constraint_clone = self.constraint1.clone()
179-
self.assertEqual(self.constraint1.bound, constraint_clone.bound)
180-
181-
constraint_clone._bound = 7.0
182-
self.assertNotEqual(self.constraint1.bound, constraint_clone.bound)
183-
184-
constraint_clone_2 = self.constraint2.clone()
185-
self.assertEqual(self.constraint2.bound, constraint_clone_2.bound)
186-
187-
def test_CloneWithTransformedParameters(self) -> None:
188-
constraint_clone = self.constraint1.clone_with_transformed_parameters(
189-
transformed_parameters={p.name: p for p in self.constraint1.parameters}
190-
)
191-
self.assertEqual(self.constraint1.bound, constraint_clone.bound)
192-
193-
constraint_clone._bound = 7.0
194-
self.assertNotEqual(self.constraint1.bound, constraint_clone.bound)

ax/core/tests/test_search_space.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
ParameterType,
2121
RangeParameter,
2222
)
23-
from ax.core.parameter_constraint import ParameterConstraint, SumConstraint
23+
from ax.core.parameter_constraint import ParameterConstraint
2424
from ax.core.search_space import SearchSpace, SearchSpaceDigest
2525
from ax.core.types import TParameterization
2626
from ax.exceptions.core import UserInputError
@@ -150,9 +150,7 @@ def test_Repr(self) -> None:
150150
self.assertEqual(str(self.ss1), self.ss1_repr)
151151

152152
def test_Setter(self) -> None:
153-
new_c = SumConstraint(
154-
parameters=[self.a, self.b], is_upper_bound=True, bound=10
155-
)
153+
new_c = ParameterConstraint(inequality="a + b <= 10")
156154
self.ss2.add_parameter_constraints([new_c])
157155
self.assertEqual(len(self.ss2.parameter_constraints), 2)
158156

0 commit comments

Comments
 (0)