Skip to content

Commit 47defa1

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Enforce equality constraints in random generator rejection sampling (#5182)
Summary: Pull Request resolved: #5182 Random generators (Sobol, etc.) were not respecting equality constraints during candidate generation. Two fixes: 1. When equality constraints are present, skip rejection sampling entirely and go straight to polytope sampling. Unconstrained random samples have probability zero of satisfying continuous equality constraints, so rejection sampling would always exhaust max_draws. 2. Add `equality_constraints` parameter to `rejection_sample` and `check_param_constraints` so that post-rounding feasibility checks also validate equality constraints (important when the polytope sampler fallback uses rejection_sample for deduplication). Reviewed By: esantorella Differential Revision: D100256485 fbshipit-source-id: 540e45ad2fd4852ecff7dc7500930758d4fed979
1 parent cd36879 commit 47defa1

2 files changed

Lines changed: 62 additions & 8 deletions

File tree

ax/generators/random/base.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ def gen(
154154
max_draws = model_gen_options.get("max_rs_draws", DEFAULT_MAX_RS_DRAWS)
155155
max_draws = int(assert_is_instance_of_tuple(max_draws, (int, float)))
156156
try:
157+
# With equality constraints, unconstrained sampling has probability
158+
# zero of producing feasible points, so skip straight to polytope
159+
# sampling.
160+
if equality_constraints is not None:
161+
raise SearchSpaceExhausted(
162+
"Equality constraints require polytope sampling."
163+
)
157164
# Always rejection sample, but this only rejects if there are
158165
# constraints or actual duplicates and deduplicate is specified.
159166
# If rejection sampling fails, fall back to polytope sampling.
@@ -184,11 +191,15 @@ def gen(
184191
num_generated = (
185192
len(generated_points) if generated_points is not None else 0
186193
)
187-
interior_point = ( # A feasible point of shape `d x 1`.
188-
torch.from_numpy(generated_points[-1].reshape((-1, 1))).double()
189-
if generated_points is not None
190-
else None
191-
)
194+
# Use a previously generated point as the interior point
195+
# hint, but only if it's likely feasible. When equality
196+
# constraints are present, previous points (generated
197+
# without those constraints) won't satisfy them.
198+
interior_point: torch.Tensor | None = None
199+
if generated_points is not None and equality_constraints is None:
200+
interior_point = torch.from_numpy(
201+
generated_points[-1].reshape((-1, 1))
202+
).double()
192203
kwargs = {"n_burnin": 100, "n_thinning": 20}
193204
kwargs.update(self.polytope_sampler_kwargs)
194205
polytope_sampler: HitAndRunPolytopeSampler = HitAndRunPolytopeSampler(
@@ -229,6 +240,7 @@ def gen_polytope_sampler(
229240
fixed_features=fixed_features,
230241
rounding_func=rounding_func,
231242
existing_points=generated_points,
243+
equality_constraints=equality_constraints,
232244
)
233245
else:
234246
raise e

ax/generators/utils.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def rejection_sample(
6868
fixed_features: dict[int, float] | None = None,
6969
rounding_func: Callable[[npt.NDArray], npt.NDArray] | None = None,
7070
existing_points: npt.NDArray | None = None,
71+
equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
7172
) -> tuple[npt.NDArray, int]:
7273
"""Rejection sample in parameter space.
7374
@@ -96,6 +97,9 @@ def rejection_sample(
9697
existing_points: A set of previously generated points to use
9798
for deduplication. These should be provided in the parameter
9899
space model operates in.
100+
equality_constraints: A tuple of (A, b). For k equality constraints
101+
on d-dimensional x, A is (k x d) and b is (k x 1) such that
102+
A x = b.
99103
100104
Returns:
101105
2-element tuple containing the generated points and the number of
@@ -124,9 +128,26 @@ def rejection_sample(
124128
)[0]
125129

126130
# Check parameter constraints, always in raw transformed space.
131+
has_constraints = (
132+
linear_constraints is not None or equality_constraints is not None
133+
)
127134
if linear_constraints is not None:
128135
all_constraints_satisfied, _ = check_param_constraints(
129-
linear_constraints=linear_constraints, point=point
136+
linear_constraints=linear_constraints,
137+
point=point,
138+
equality_constraints=equality_constraints,
139+
)
140+
elif equality_constraints is not None:
141+
# No inequality constraints but have equality constraints.
142+
# Use a dummy (0, d) inequality matrix so check_param_constraints works.
143+
dummy_ineq = (
144+
np.zeros((0, len(point))),
145+
np.zeros((0, 1)),
146+
)
147+
all_constraints_satisfied, _ = check_param_constraints(
148+
linear_constraints=dummy_ineq,
149+
point=point,
150+
equality_constraints=equality_constraints,
130151
)
131152
else:
132153
all_constraints_satisfied = True
@@ -140,9 +161,15 @@ def rejection_sample(
140161
# Re-check constraints after rounding for discrete parameters
141162
# (e.g. numerical choice parameters) because rounding can push values
142163
# in a direction that violates sum constraints.
143-
if linear_constraints is not None:
164+
if has_constraints:
165+
ineq = linear_constraints or (
166+
np.zeros((0, len(point))),
167+
np.zeros((0, 1)),
168+
)
144169
all_constraints_satisfied, _ = check_param_constraints(
145-
linear_constraints=linear_constraints, point=point
170+
linear_constraints=ineq,
171+
point=point,
172+
equality_constraints=equality_constraints,
146173
)
147174
if not all_constraints_satisfied:
148175
attempted_draws += 1
@@ -228,6 +255,7 @@ def add_fixed_features(
228255
def check_param_constraints(
229256
linear_constraints: tuple[npt.NDArray, npt.NDArray],
230257
point: npt.NDArray,
258+
equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
231259
) -> tuple[bool, npt.NDArray]:
232260
"""Check if a point satisfies parameter constraints.
233261
@@ -236,6 +264,9 @@ def check_param_constraints(
236264
d-dimensional x, A is (k x d) and b is (k x 1) such that
237265
A x <= b.
238266
point: A candidate point in d-dimensional space, as a (1 x d) matrix.
267+
equality_constraints: A tuple of (A, b). For k equality constraints on
268+
d-dimensional x, A is (k x d) and b is (k x 1) such that
269+
A x = b.
239270
240271
Returns:
241272
2-element tuple containing
@@ -246,6 +277,17 @@ def check_param_constraints(
246277
constraints_satisfied = (
247278
linear_constraints[0] @ np.expand_dims(point, axis=1) <= linear_constraints[1]
248279
)
280+
if equality_constraints is not None:
281+
eq_satisfied = (
282+
np.abs(
283+
equality_constraints[0] @ np.expand_dims(point, axis=1)
284+
- equality_constraints[1]
285+
)
286+
<= 1e-8
287+
)
288+
constraints_satisfied = np.concatenate(
289+
[constraints_satisfied, eq_satisfied], axis=0
290+
)
249291
if np.all(constraints_satisfied):
250292
return True, np.array([])
251293
else:

0 commit comments

Comments
 (0)