Skip to content

Commit 8c8dc6c

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Handle equality constraints in UnitX transform (#5179)
Summary: Pull Request resolved: #5179 Update the UnitX transform to preserve equality constraint type when transforming constraints to unit space. When constructing new constraints, check `c.is_equality` and use `ParameterConstraint(equality=...)` for equality constraints. The math is identical for both types (rescale coefficients by `(u - l)` and adjust bound by `w * l`). Other transforms (OneHot, IntToFloat, RemoveFixed, etc.) use `pc.clone()` which automatically preserves `is_equality` — no changes needed. Reviewed By: esantorella Differential Revision: D100256484 fbshipit-source-id: 5195d354a269bfa790080a92ee47c32e405f1e4a
1 parent 1c8c06e commit 8c8dc6c

3 files changed

Lines changed: 55 additions & 18 deletions

File tree

ax/adapter/transforms/tests/test_unit_x_transform.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,40 @@ def test_TransformSearchSpace(self) -> None:
133133
self.search_space_with_target.parameters["x"].target_value, 1.0
134134
)
135135

136+
def test_TransformSearchSpaceEqualityConstraints(self) -> None:
137+
# Verify that equality constraints are preserved and correctly
138+
# rescaled during UnitX transform_search_space.
139+
ss = SearchSpace(
140+
parameters=[
141+
RangeParameter(
142+
"x", lower=1, upper=3, parameter_type=ParameterType.FLOAT
143+
),
144+
RangeParameter(
145+
"y", lower=1, upper=2, parameter_type=ParameterType.FLOAT
146+
),
147+
],
148+
parameter_constraints=[
149+
ParameterConstraint(equality="-0.5*x + y == 0.5"),
150+
ParameterConstraint(inequality="-0.5*x + y <= 0.5"),
151+
],
152+
)
153+
t = UnitX(search_space=ss)
154+
ss = t.transform_search_space(ss)
155+
156+
# Both constraints have the same math; only the type differs.
157+
eq_c = ss.parameter_constraints[0]
158+
ineq_c = ss.parameter_constraints[1]
159+
160+
# Equality constraint preserved.
161+
self.assertTrue(eq_c.is_equality)
162+
self.assertEqual(eq_c.constraint_dict, {"x": -1.0, "y": 1.0})
163+
self.assertEqual(eq_c.bound, 0.0)
164+
165+
# Inequality constraint preserved.
166+
self.assertFalse(ineq_c.is_equality)
167+
self.assertEqual(ineq_c.constraint_dict, {"x": -1.0, "y": 1.0})
168+
self.assertEqual(ineq_c.bound, 0.0)
169+
136170
def test_transform_search_space_clears_digits(self) -> None:
137171
"""Test that digits is cleared during transform to avoid rounding
138172
in unit space. Regression test for a bug where digits=-3 (round to

ax/adapter/transforms/unit_x.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,14 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
103103
expr = " + ".join(
104104
f"{coeff} * {param}" for param, coeff in constraint_dict.items()
105105
)
106-
new_constraints.append(
107-
ParameterConstraint(
108-
inequality=f"{expr} <= {bound}",
106+
if c.is_equality:
107+
new_constraints.append(
108+
ParameterConstraint(equality=f"{expr} == {bound}")
109+
)
110+
else:
111+
new_constraints.append(
112+
ParameterConstraint(inequality=f"{expr} <= {bound}")
109113
)
110-
)
111114
search_space.set_parameter_constraints(new_constraints)
112115
return search_space
113116

ax/generators/tests/test_sobol.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -305,20 +305,20 @@ def test_SobolGeneratorFallbackToPolytopeSamplerWithEqAndFixedFeatures(
305305
)
306306
except Exception:
307307
pass
308-
if MockSampler.called:
309-
eq_arg = MockSampler.call_args.kwargs["equality_constraints"]
310-
self.assertIsNotNone(eq_arg)
311-
C, c = eq_arg
312-
# 1 from fixed features + 1 from parameter equality = 2 rows.
313-
self.assertEqual(C.shape[0], 2)
314-
self.assertEqual(C.shape[1], 11)
315-
# Fixed feature: x[10] = 1 (last row from fixed, first in sorted).
316-
self.assertEqual(C[0, 10].item(), 1.0)
317-
self.assertAlmostEqual(c[0].item(), 1.0)
318-
# Parameter equality: x[0] + x[1] = 0.3.
319-
self.assertEqual(C[1, 0].item(), 1.0)
320-
self.assertEqual(C[1, 1].item(), 1.0)
321-
self.assertAlmostEqual(c[1].item(), 0.3)
308+
self.assertTrue(MockSampler.called)
309+
eq_arg = MockSampler.call_args.kwargs["equality_constraints"]
310+
self.assertIsNotNone(eq_arg)
311+
C, c = eq_arg
312+
# 1 from fixed features + 1 from parameter equality = 2 rows.
313+
self.assertEqual(C.shape[0], 2)
314+
self.assertEqual(C.shape[1], 11)
315+
# Fixed feature: x[10] = 1 (last row from fixed, first in sorted).
316+
self.assertEqual(C[0, 10].item(), 1.0)
317+
self.assertAlmostEqual(c[0].item(), 1.0)
318+
# Parameter equality: x[0] + x[1] = 0.3.
319+
self.assertEqual(C[1, 0].item(), 1.0)
320+
self.assertEqual(C[1, 1].item(), 1.0)
321+
self.assertAlmostEqual(c[1].item(), 0.3)
322322

323323
def test_SobolGeneratorFallbackToPolytopeSamplerWithFixedParam(self) -> None:
324324
# Ten parameters with sum less than 1. In this example, the rejection

0 commit comments

Comments
 (0)