Skip to content

Commit 1c8c06e

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Support equality constraints in random generator polytope sampling (#5178)
Summary: Pull Request resolved: #5178 Thread equality constraints from the search space through the random adapter to the `HitAndRunPolytopeSampler`. When rejection sampling falls back to polytope sampling, equality constraints from both `fixed_features` and parameter constraints are combined into a single `(C, c)` matrix. - Update `RandomAdapter._gen` to extract equality constraints via `extract_equality_constraints`. - Add `equality_constraints` parameter to `RandomGenerator.gen()`. - Add `_combine_equality_constraints` method that merges fixed-feature-based and parameter-based equality constraints for the polytope sampler. Reviewed By: bletham Differential Revision: D100256488 fbshipit-source-id: c74cc6a82b7e49b2526f9b95cb6c9cf0d10859e3
1 parent 58cff45 commit 1c8c06e

7 files changed

Lines changed: 246 additions & 1 deletion

File tree

ax/adapter/random.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import numpy as np
1313
from ax.adapter.adapter_utils import (
14+
extract_equality_constraints,
1415
extract_inequality_constraints,
1516
extract_search_space_digest,
1617
get_fixed_features,
@@ -95,6 +96,9 @@ def _gen(
9596
linear_constraints = extract_inequality_constraints(
9697
search_space.parameter_constraints, self.parameters
9798
)
99+
equality_constraints_np = extract_equality_constraints(
100+
search_space.parameter_constraints, self.parameters
101+
)
98102
# Extract generated points.
99103
# For normal generators these are used to deduplicate against.
100104
# For in-sample generators (LILO labeling) they are the selection
@@ -177,6 +181,7 @@ def _gen(
177181
n=n,
178182
search_space_digest=search_space_digest,
179183
linear_constraints=linear_constraints,
184+
equality_constraints=equality_constraints_np,
180185
fixed_features=fixed_features_dict,
181186
model_gen_options=model_gen_options,
182187
rounding_func=transform_callback(self.parameters, self.transforms),

ax/adapter/tests/test_random_adapter.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,65 @@ def test_gen_w_constraints(self) -> None:
109109
self.assertEqual(obsf[1].parameters, {"x": 3.0, "y": 4.0, "z": 3.0})
110110
self.assertTrue(np.array_equal(gen_results.weights, np.array([1.0, 2.0])))
111111

112+
def test_gen_w_equality_constraints(self) -> None:
113+
# Verify that equality constraints from the search space are extracted
114+
# and passed through to the generator's gen() call.
115+
x = RangeParameter("x", ParameterType.FLOAT, lower=0, upper=1)
116+
y = RangeParameter("y", ParameterType.FLOAT, lower=0, upper=1)
117+
z = RangeParameter("z", ParameterType.FLOAT, lower=0, upper=1)
118+
parameter_constraints = [
119+
ParameterConstraint(equality="x + y == 0.5"),
120+
]
121+
search_space = SearchSpace([x, y, z], parameter_constraints)
122+
experiment = Experiment(search_space=search_space)
123+
adapter = RandomAdapter(experiment=experiment, generator=RandomGenerator())
124+
with mock.patch.object(
125+
adapter.generator,
126+
"gen",
127+
return_value=(
128+
np.array([[0.2, 0.3, 0.4]]),
129+
np.array([1.0]),
130+
),
131+
) as mock_gen:
132+
adapter._gen(
133+
n=1,
134+
search_space=search_space,
135+
pending_observations={},
136+
fixed_features=ObservationFeatures({}),
137+
optimization_config=None,
138+
model_gen_options=self.model_gen_options,
139+
)
140+
gen_args = mock_gen.mock_calls[0][2]
141+
eq_constraints = gen_args["equality_constraints"]
142+
self.assertIsNotNone(eq_constraints)
143+
A, b = eq_constraints
144+
# x + y = 0.5 => A = [[1, 1, 0]], b = [[0.5]]
145+
self.assertTrue(np.array_equal(A, np.array([[1.0, 1.0, 0.0]])))
146+
self.assertTrue(np.array_equal(b, np.array([[0.5]])))
147+
148+
def test_gen_no_equality_constraints(self) -> None:
149+
# Verify that equality_constraints is None when there are no equality
150+
# constraints on the search space.
151+
adapter = RandomAdapter(experiment=self.experiment, generator=RandomGenerator())
152+
with mock.patch.object(
153+
adapter.generator,
154+
"gen",
155+
return_value=(
156+
np.array([[0.5, 1.5, 2.5]]),
157+
np.array([1.0]),
158+
),
159+
) as mock_gen:
160+
adapter._gen(
161+
n=1,
162+
search_space=self.search_space,
163+
pending_observations={},
164+
fixed_features=ObservationFeatures({}),
165+
optimization_config=None,
166+
model_gen_options=self.model_gen_options,
167+
)
168+
gen_args = mock_gen.mock_calls[0][2]
169+
self.assertIsNone(gen_args["equality_constraints"])
170+
112171
def test_gen_simple(self) -> None:
113172
# Test with no constraints, no fixed feature, no pending observations
114173
search_space = SearchSpace(self.parameters[:2])

ax/generators/random/base.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def gen(
113113
n: int,
114114
search_space_digest: SearchSpaceDigest,
115115
linear_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
116+
equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
116117
fixed_features: dict[int, float] | None = None,
117118
model_gen_options: TConfig | None = None,
118119
rounding_func: Callable[[npt.NDArray], npt.NDArray] | None = None,
@@ -127,6 +128,9 @@ def gen(
127128
linear_constraints: A tuple of (A, b). For k linear constraints on
128129
d-dimensional x, A is (k x d) and b is (k x 1) such that
129130
A x <= b.
131+
equality_constraints: A tuple of (A, b). For k equality constraints
132+
on d-dimensional x, A is (k x d) and b is (k x 1) such that
133+
A x = b.
130134
fixed_features: A map {feature_index: value} for features that
131135
should be fixed to a particular value during generation.
132136
model_gen_options: A config dictionary that is passed along to the
@@ -205,9 +209,10 @@ def gen(
205209
inequality_constraints=self._convert_inequality_constraints(
206210
linear_constraints,
207211
),
208-
equality_constraints=self._convert_equality_constraints(
212+
equality_constraints=self._combine_equality_constraints(
209213
d=len(search_space_digest.bounds),
210214
fixed_features=fixed_features,
215+
equality_constraints=equality_constraints,
211216
),
212217
bounds=self._convert_bounds(bounds=search_space_digest.bounds),
213218
interior_point=interior_point,
@@ -353,6 +358,43 @@ def _convert_equality_constraints(
353358
constraint_matrix[index, fixed_indices[index]] = 1.0
354359
return constraint_matrix, fixed_vals
355360

361+
def _combine_equality_constraints(
362+
self,
363+
d: int,
364+
fixed_features: dict[int, float] | None,
365+
equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
366+
) -> tuple[Tensor, Tensor] | None:
367+
"""Combine fixed-feature equality constraints with parameter equality
368+
constraints into a single (C, c) matrix for the polytope sampler.
369+
370+
Args:
371+
d: Dimension of samples.
372+
fixed_features: A map {feature_index: value} for features that
373+
should be fixed to a particular value during generation.
374+
equality_constraints: A tuple of (A, b) NumPy arrays from
375+
``extract_equality_constraints``.
376+
377+
Returns:
378+
Optional 2-element tuple containing C and c such that Cx = c.
379+
"""
380+
fixed_eq = self._convert_equality_constraints(
381+
d=d, fixed_features=fixed_features
382+
)
383+
param_eq = None
384+
if equality_constraints is not None:
385+
A = torch.as_tensor(equality_constraints[0], dtype=torch.double)
386+
b = torch.as_tensor(equality_constraints[1], dtype=torch.double).squeeze(-1)
387+
param_eq = (A, b)
388+
389+
if fixed_eq is None and param_eq is None:
390+
return None
391+
if fixed_eq is not None and param_eq is not None:
392+
return (
393+
torch.cat([fixed_eq[0], param_eq[0]], dim=0),
394+
torch.cat([fixed_eq[1], param_eq[1]], dim=0),
395+
)
396+
return fixed_eq if fixed_eq is not None else param_eq
397+
356398
def _convert_bounds(self, bounds: list[tuple[float, float]]) -> Tensor | None:
357399
"""Helper method to convert bounds list used by the rejectionsampler to the
358400
tensor format required for the polytope sampler.

ax/generators/random/in_sample.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def gen(
3737
n: int,
3838
search_space_digest: SearchSpaceDigest,
3939
linear_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
40+
equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
4041
fixed_features: dict[int, float] | None = None,
4142
model_gen_options: TConfig | None = None,
4243
rounding_func: Callable[[npt.NDArray], npt.NDArray] | None = None,
@@ -49,6 +50,7 @@ def gen(
4950
search_space_digest: A ``SearchSpaceDigest`` object containing
5051
metadata on the features in the datasets.
5152
linear_constraints: Not used. Accepted for interface compatibility.
53+
equality_constraints: Not used. Accepted for interface compatibility.
5254
fixed_features: Not used. Accepted for interface compatibility.
5355
model_gen_options: Not used. Accepted for interface compatibility.
5456
rounding_func: Not used. Accepted for interface compatibility.

ax/generators/random/sobol.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def gen(
8181
n: int,
8282
search_space_digest: SearchSpaceDigest,
8383
linear_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
84+
equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
8485
fixed_features: dict[int, float] | None = None,
8586
model_gen_options: TConfig | None = None,
8687
rounding_func: Callable[[npt.NDArray], npt.NDArray] | None = None,
@@ -95,6 +96,9 @@ def gen(
9596
linear_constraints: A tuple of (A, b). For k linear constraints on
9697
d-dimensional x, A is (k x d) and b is (k x 1) such that
9798
A x <= b.
99+
equality_constraints: A tuple of (A, b). For k equality constraints
100+
on d-dimensional x, A is (k x d) and b is (k x 1) such that
101+
A x = b.
98102
fixed_features: A map {feature_index: value} for features that
99103
should be fixed to a particular value during generation.
100104
rounding_func: A function that rounds an optimization result
@@ -117,6 +121,7 @@ def gen(
117121
n=n,
118122
search_space_digest=search_space_digest,
119123
linear_constraints=linear_constraints,
124+
equality_constraints=equality_constraints,
120125
fixed_features=fixed_features,
121126
model_gen_options=model_gen_options,
122127
rounding_func=rounding_func,

ax/generators/tests/test_random.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,62 @@ def test_ConvertEqualityConstraints(self) -> None:
6161
self.assertEqual(C_comparison.any(), True)
6262
self.assertEqual(self.random_model._convert_equality_constraints(d, None), None)
6363

64+
def test_CombineEqualityConstraints(self) -> None:
65+
d = 4
66+
# Both None: returns None.
67+
self.assertIsNone(
68+
self.random_model._combine_equality_constraints(
69+
d=d, fixed_features=None, equality_constraints=None
70+
)
71+
)
72+
73+
# Only fixed_features: returns the fixed-feature constraints.
74+
fixed_features = {1: 0.5, 3: 0.7}
75+
C, c = none_throws(
76+
self.random_model._combine_equality_constraints(
77+
d=d, fixed_features=fixed_features, equality_constraints=None
78+
)
79+
)
80+
self.assertEqual(C.shape, (2, d))
81+
self.assertEqual(c.shape, (2,))
82+
self.assertEqual(C[0, 1].item(), 1.0)
83+
self.assertEqual(C[1, 3].item(), 1.0)
84+
self.assertAlmostEqual(c[0].item(), 0.5)
85+
self.assertAlmostEqual(c[1].item(), 0.7)
86+
87+
# Only equality_constraints: returns the parameter constraints.
88+
A_np = np.array([[1.0, 1.0, 0.0, 0.0]])
89+
b_np = np.array([[2.0]])
90+
C, c = none_throws(
91+
self.random_model._combine_equality_constraints(
92+
d=d,
93+
fixed_features=None,
94+
equality_constraints=(A_np, b_np),
95+
)
96+
)
97+
self.assertEqual(C.shape, (1, d))
98+
self.assertEqual(c.shape, (1,))
99+
self.assertTrue(torch.equal(C, torch.tensor([[1.0, 1.0, 0.0, 0.0]])))
100+
self.assertAlmostEqual(c[0].item(), 2.0)
101+
102+
# Both present: concatenates fixed-feature and parameter constraints.
103+
C, c = none_throws(
104+
self.random_model._combine_equality_constraints(
105+
d=d,
106+
fixed_features=fixed_features,
107+
equality_constraints=(A_np, b_np),
108+
)
109+
)
110+
# 2 from fixed_features + 1 from equality_constraints = 3 rows.
111+
self.assertEqual(C.shape, (3, d))
112+
self.assertEqual(c.shape, (3,))
113+
# First two rows are from fixed_features (sorted by key: 1, 3).
114+
self.assertEqual(C[0, 1].item(), 1.0)
115+
self.assertEqual(C[1, 3].item(), 1.0)
116+
# Third row is from the parameter equality constraint.
117+
self.assertTrue(torch.equal(C[2], torch.tensor([1.0, 1.0, 0.0, 0.0])))
118+
self.assertAlmostEqual(c[2].item(), 2.0)
119+
64120
def test_ConvertInequalityConstraints(self) -> None:
65121
A = np.array([[1, 2], [3, 4]])
66122
b = np.array([[5], [6]])

ax/generators/tests/test_sobol.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from unittest import mock
1010

1111
import numpy as np
12+
import torch
1213
from ax.core.search_space import SearchSpaceDigest
1314
from ax.exceptions.core import SearchSpaceExhausted
1415
from ax.generators.random.sobol import SobolGenerator
@@ -244,6 +245,81 @@ def rounding_function(x: npt.NDArray) -> npt.NDArray:
244245

245246
rounding_func.assert_called()
246247

248+
def test_SobolGeneratorFallbackToPolytopeSamplerWithEqualityConstraints(
249+
self,
250+
) -> None:
251+
# Ten parameters with sum <= 1 (forces polytope fallback) plus
252+
# an equality constraint x0 + x1 = 0.4 to verify equality constraints
253+
# are threaded through to the polytope sampler.
254+
generator = SobolGenerator(seed=0, fallback_to_sample_polytope=True)
255+
ssd = self._create_ssd(n_tunable=10, n_fixed=0)
256+
A_ineq = np.ones((1, 10))
257+
b_ineq = np.array([1]).reshape((1, 1))
258+
# Equality constraint: x0 + x1 = 0.4
259+
A_eq = np.zeros((1, 10))
260+
A_eq[0, 0] = 1.0
261+
A_eq[0, 1] = 1.0
262+
b_eq = np.array([[0.4]])
263+
generated_points, _ = generator.gen(
264+
n=3,
265+
search_space_digest=ssd,
266+
linear_constraints=(A_ineq, b_ineq),
267+
equality_constraints=(A_eq, b_eq),
268+
rounding_func=lambda x: x,
269+
)
270+
self.assertEqual(np.shape(generated_points), (3, 10))
271+
# Inequality constraint satisfied.
272+
self.assertTrue(np.all(generated_points @ A_ineq.T <= b_ineq + 1e-6))
273+
# Equality constraint satisfied: x0 + x1 ≈ 0.4.
274+
eq_vals = generated_points @ A_eq.T
275+
np.testing.assert_allclose(eq_vals, 0.4, atol=1e-6)
276+
277+
def test_SobolGeneratorFallbackToPolytopeSamplerWithEqAndFixedFeatures(
278+
self,
279+
) -> None:
280+
# Verify that equality constraints and fixed features are combined
281+
# correctly and passed to the polytope sampler.
282+
generator = SobolGenerator(seed=0, fallback_to_sample_polytope=True)
283+
ssd = self._create_ssd(n_tunable=10, n_fixed=1)
284+
A_ineq = np.insert(np.ones((1, 10)), 10, 0, axis=1)
285+
b_ineq = np.array([1]).reshape((1, 1))
286+
# Equality constraint: x0 + x1 = 0.3.
287+
A_eq = np.zeros((1, 11))
288+
A_eq[0, 0] = 1.0
289+
A_eq[0, 1] = 1.0
290+
b_eq = np.array([[0.3]])
291+
292+
with mock.patch(
293+
"ax.generators.random.base.HitAndRunPolytopeSampler"
294+
) as MockSampler:
295+
mock_instance = MockSampler.return_value
296+
mock_instance.draw.return_value = torch.rand(1, 11, dtype=torch.double)
297+
try:
298+
generator.gen(
299+
n=1,
300+
search_space_digest=ssd,
301+
linear_constraints=(A_ineq, b_ineq),
302+
equality_constraints=(A_eq, b_eq),
303+
fixed_features={10: 1},
304+
rounding_func=lambda x: x,
305+
)
306+
except Exception:
307+
pass
308+
if MockSampler.called:
309+
eq_arg = MockSampler.call_args.kwargs["equality_constraints"]
310+
self.assertIsNotNone(eq_arg)
311+
C, c = eq_arg
312+
# 1 from fixed features + 1 from parameter equality = 2 rows.
313+
self.assertEqual(C.shape[0], 2)
314+
self.assertEqual(C.shape[1], 11)
315+
# Fixed feature: x[10] = 1 (last row from fixed, first in sorted).
316+
self.assertEqual(C[0, 10].item(), 1.0)
317+
self.assertAlmostEqual(c[0].item(), 1.0)
318+
# Parameter equality: x[0] + x[1] = 0.3.
319+
self.assertEqual(C[1, 0].item(), 1.0)
320+
self.assertEqual(C[1, 1].item(), 1.0)
321+
self.assertAlmostEqual(c[1].item(), 0.3)
322+
247323
def test_SobolGeneratorFallbackToPolytopeSamplerWithFixedParam(self) -> None:
248324
# Ten parameters with sum less than 1. In this example, the rejection
249325
# sampler gives a search space exhausted error. Testing fallback to

0 commit comments

Comments
 (0)