Skip to content

Commit cd36879

Browse files
sdaulton authored and meta-codesync[bot] committed
BONSAI: pin-and-project for constrained pruning (#5180)
Summary:
Pull Request resolved: #5180

Extend BONSAI's irrelevance pruning to handle both equality and inequality constraints via a pin-and-project approach. Previously, BONSAI simply discarded pruned candidates that violated constraints. This was overly conservative (inequality) or completely broken (equality, where almost all single-dimension prunes violate the constraint).

The new approach:
1. Set x_j = target[j] (unchanged).
2. Project the other dimensions onto the feasible set via SLSQP, keeping x_j pinned (and all previously pruned dims pinned).
3. Filter any candidates that remain infeasible after projection.

This is strictly better than discarding: it recovers feasibility when possible by adjusting other dimensions, while infeasible pins (where no adjustment can satisfy the constraints) are still caught.

Key implementation details:
- `_project_and_filter_pruned_candidates`: new function that uses `project_to_feasible_space_via_slsqp` with `fixed_features` to pin the pruned dim and all previously pruned dims.
- Optimization: skip projection for dims not in any constraint's index set (pruning them can't violate anything).
- Handles 2D inter-point constraint indices correctly.
- `_prune_irrelevant_parameters` now accepts `bounds` parameter.

Reviewed By: esantorella
Differential Revision: D100256483
fbshipit-source-id: 5acb90bedcf38950ce8d3a56ee53ffa23d1f5caf
1 parent 75c0ed4 commit cd36879

2 files changed

Lines changed: 286 additions & 9 deletions

File tree

ax/generators/torch/botorch_modular/acquisition.py

Lines changed: 161 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@
5757
AnalyticExpectedUtilityOfBestOption,
5858
qExpectedUtilityOfBestOption,
5959
)
60-
from botorch.exceptions.errors import BotorchError, InputDataError
60+
from botorch.exceptions.errors import (
61+
BotorchError,
62+
CandidateGenerationError,
63+
InputDataError,
64+
)
6165
from botorch.generation.sampling import SamplingStrategy
6266
from botorch.models.model import Model
6367
from botorch.optim.optimize import (
@@ -72,7 +76,10 @@
7276
optimize_acqf_mixed_alternating,
7377
should_use_mixed_alternating_optimizer,
7478
)
75-
from botorch.optim.parameter_constraints import evaluate_feasibility
79+
from botorch.optim.parameter_constraints import (
80+
evaluate_feasibility,
81+
project_to_feasible_space_via_slsqp,
82+
)
7683
from botorch.utils.constraints import get_outcome_constraint_transforms
7784
from pyre_extensions import assert_is_instance, none_throws
7885
from torch import Tensor
@@ -892,6 +899,7 @@ def optimize(
892899
inequality_constraints=inequality_constraints,
893900
equality_constraints=equality_constraints,
894901
fixed_features=fixed_features,
902+
bounds=bounds,
895903
)
896904
# Validate candidates before returning
897905
validate_candidates(
@@ -1007,6 +1015,7 @@ def _prune_irrelevant_parameters(
10071015
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
10081016
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
10091017
fixed_features: dict[int, float] | None = None,
1018+
bounds: Tensor | None = None,
10101019
) -> tuple[Tensor, Tensor]:
10111020
r"""Prune irrelevant parameters from the candidates using BONSAI.
10121021
@@ -1042,6 +1051,11 @@ def _prune_irrelevant_parameters(
10421051
corresponds to the `l_i`-th feature of that element.
10431052
fixed_features: A map `{feature_index: value}` for features that
10441053
should be fixed to a particular value during generation.
1054+
bounds: A `2 x d`-dim tensor of lower and upper parameter bounds.
1055+
Required when `inequality_constraints` or `equality_constraints`
1056+
are provided: pruned candidates are projected onto the feasible
1057+
set via SLSQP, and the projection needs the parameter bounds to
1058+
define the feasible region. Unused when no constraints are set.
10451059
10461060
Returns:
10471061
A two-element tuple containing an `q x d`-dim tensor of generated
@@ -1085,12 +1099,14 @@ def _prune_irrelevant_parameters(
10851099
# dense AF val
10861100
final_af_val = dense_af_val
10871101
# If the current incremental AF value is zero, then we skip pruning
1102+
has_constraints = bool(inequality_constraints or equality_constraints)
10881103
if dense_incremental_af_val > 0.0:
10891104
remaining_indices = set(range(candidates.shape[-1])) - excluded_indices
10901105
# remove features that are already set to target_point
10911106
remaining_indices -= set(
10921107
(candidates[i] == target_point).nonzero().view(-1).tolist()
10931108
)
1109+
initial_remaining = set(remaining_indices)
10941110
# len(remaining_indices) - 1 is used here so that we do not prune
10951111
# every dimension
10961112
for _ in range(len(remaining_indices) - 1):
@@ -1107,13 +1123,23 @@ def _prune_irrelevant_parameters(
11071123
indices=indices,
11081124
targets=target_point[indices],
11091125
)
1110-
# remove candidates that violate constraints after pruning
1111-
pruned_candidates, indices = _remove_infeasible_candidates(
1112-
candidates=pruned_candidates,
1113-
indices=indices,
1114-
inequality_constraints=inequality_constraints,
1115-
equality_constraints=equality_constraints,
1116-
)
1126+
# Project pruned candidates onto the feasible set
1127+
# (pinning the pruned dim and previously pruned dims),
1128+
# then filter any that remain infeasible.
1129+
if has_constraints:
1130+
previously_pruned = initial_remaining - remaining_indices
1131+
pruned_candidates, indices = (
1132+
_project_and_filter_pruned_candidates(
1133+
candidates=pruned_candidates,
1134+
indices=indices,
1135+
target_point=target_point,
1136+
pruned_dims=previously_pruned,
1137+
bounds=none_throws(bounds),
1138+
inequality_constraints=inequality_constraints,
1139+
equality_constraints=equality_constraints,
1140+
fixed_features=fixed_features,
1141+
)
1142+
)
11171143
if pruned_candidates.shape[0] == 0:
11181144
# no feasible points, continue to
11191145
# next candidate
@@ -1253,3 +1279,129 @@ def _remove_infeasible_candidates(
12531279
candidates = candidates[is_feasible]
12541280
indices = indices[is_feasible]
12551281
return candidates, indices
1282+
1283+
1284+
def _project_and_filter_pruned_candidates(
1285+
candidates: Tensor,
1286+
indices: Tensor,
1287+
target_point: Tensor,
1288+
pruned_dims: set[int],
1289+
bounds: Tensor,
1290+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
1291+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
1292+
fixed_features: dict[int, float] | None = None,
1293+
) -> tuple[Tensor, Tensor]:
1294+
r"""Project pruned candidates onto the feasible set, then filter infeasible.
1295+
1296+
Helper for ``Acquisition._prune_irrelevant_parameters`` (BONSAI). It is
1297+
only meaningful in the context of that greedy-pruning loop and is not
1298+
intended for standalone use.
1299+
1300+
Background: BONSAI pruning evaluates a candidate dimension-by-dimension
1301+
by setting one dimension at a time to its target-point value. Each row
1302+
of ``candidates`` is one such trial -- the dense candidate with the
1303+
dimension at ``indices[i]`` swapped to ``target_point[indices[i]]``.
1304+
Under linear constraints, swapping a single dimension to the target
1305+
typically violates the constraints; rather than discarding the trial
1306+
(the prior behavior), we adjust the *other* free dimensions to recover
1307+
feasibility while keeping the swapped dimension and all previously
1308+
pruned dimensions pinned. Trials whose pins make the constraint system
1309+
infeasible -- and the rare case where projection succeeds but the
1310+
result still violates constraints -- are filtered out via the mask
1311+
returned to the caller.
1312+
1313+
Args:
1314+
candidates: A ``b x 1 x d``-dim tensor of pruned candidates (one row
1315+
per single-dimension prune attempt for the current BONSAI
1316+
iteration).
1317+
indices: A ``b``-dim tensor indicating which dimension was pruned
1318+
in each batch element.
1319+
target_point: A ``d``-dim tensor of target values for pruning.
1320+
pruned_dims: Set of dimension indices already pruned in prior
1321+
greedy iterations (to be kept pinned during projection).
1322+
bounds: A ``2 x d``-dim tensor of lower and upper bounds.
1323+
inequality_constraints: Inequality constraints in BoTorch format.
1324+
equality_constraints: Equality constraints in BoTorch format.
1325+
fixed_features: A map ``{feature_index: value}`` from the caller.
1326+
These dimensions are excluded from pruning at the outer loop and
1327+
must also be pinned during projection so SLSQP cannot adjust
1328+
them while satisfying the constraints. Without this, fixed
1329+
features could be silently altered.
1330+
1331+
Returns:
1332+
A two-element tuple of filtered ``(candidates, indices)``.
1333+
"""
1334+
# Pre-compute which dims participate in any constraint, and check whether
1335+
# any constraint is inter-point (2D index tensor). Inter-point constraints
1336+
# apply across the q-batch, but each row here is a single-candidate prune
1337+
# attempt -- ``project_to_feasible_space_via_slsqp`` cannot evaluate
1338+
# inter-point constraints on a 1 x d input. Fall back to the original
1339+
# filter-only behavior in that case.
1340+
constrained_dims: set[int] = set()
1341+
has_interpoint_constraint = False
1342+
for constraints in (inequality_constraints, equality_constraints):
1343+
if constraints is not None:
1344+
for c_indices, _, _ in constraints:
1345+
if c_indices.dim() == 1:
1346+
constrained_dims.update(c_indices.tolist())
1347+
else:
1348+
constrained_dims.update(c_indices[:, -1].tolist())
1349+
has_interpoint_constraint = True
1350+
if has_interpoint_constraint:
1351+
return _remove_infeasible_candidates(
1352+
candidates=candidates,
1353+
indices=indices,
1354+
inequality_constraints=inequality_constraints,
1355+
equality_constraints=equality_constraints,
1356+
)
1357+
1358+
# Build fixed_features for previously pruned dims and the caller's
1359+
# fixed_features (both shared across all candidates in this iteration).
1360+
prev_fixed: dict[int, float] = {k: target_point[k].item() for k in pruned_dims}
1361+
if fixed_features is not None:
1362+
prev_fixed.update(fixed_features)
1363+
1364+
feasible_mask = torch.ones(candidates.shape[0], dtype=torch.bool)
1365+
result = candidates.clone()
1366+
1367+
for i in range(candidates.shape[0]):
1368+
j: int = int(indices[i].item())
1369+
# If the pruned dim doesn't participate in any constraint,
1370+
# pruning it can't violate anything — skip projection.
1371+
if j not in constrained_dims:
1372+
continue
1373+
# Pin the currently pruned dim, all previously pruned dims, and the
1374+
# caller's fixed features.
1375+
fixed: dict[int, float | Tensor] = {
1376+
j: float(target_point[j].item()),
1377+
**prev_fixed,
1378+
}
1379+
try:
1380+
projected = project_to_feasible_space_via_slsqp(
1381+
X=candidates[i], # 1 x d
1382+
bounds=bounds,
1383+
inequality_constraints=inequality_constraints,
1384+
equality_constraints=equality_constraints,
1385+
fixed_features=fixed,
1386+
)
1387+
result[i] = projected
1388+
except CandidateGenerationError:
1389+
# Pin makes the system infeasible — mark for removal.
1390+
# The post-projection feasibility check below is the safety net
1391+
# for any candidates that project but still violate constraints.
1392+
feasible_mask[i] = False
1393+
1394+
# Final safety-net feasibility check after projection.
1395+
if feasible_mask.any():
1396+
is_feasible = evaluate_feasibility(
1397+
X=result[feasible_mask],
1398+
inequality_constraints=inequality_constraints,
1399+
equality_constraints=equality_constraints,
1400+
)
1401+
# Map back to the full mask.
1402+
feasible_subset_indices = feasible_mask.nonzero(as_tuple=True)[0]
1403+
for idx, feas in zip(feasible_subset_indices, is_feasible):
1404+
if not feas:
1405+
feasible_mask[idx] = False
1406+
1407+
return result[feasible_mask], indices[feasible_mask]

ax/generators/torch/tests/test_acquisition.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,6 +1831,7 @@ def test_prune_irrelevant_parameters_with_inequality_constraints(self) -> None:
18311831
candidates=candidates,
18321832
search_space_digest=search_space_digest,
18331833
inequality_constraints=inequality_constraints,
1834+
bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
18341835
)
18351836
self.assertTrue(torch.equal(pruned_candidates, torch.tensor([[0.2, 0.8]])))
18361837
self.assertTrue(torch.equal(pruned_values, torch.tensor([0.91])))
@@ -1848,6 +1849,7 @@ def test_prune_irrelevant_parameters_with_inequality_constraints(self) -> None:
18481849
inequality_constraints=[
18491850
(torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 1.5)
18501851
],
1852+
bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
18511853
)
18521854
# No pruning: setting either dim to 0.2 gives sum=1.0 < 1.5 (infeasible)
18531855
self.assertTrue(torch.equal(pruned_candidates, torch.tensor([[0.8, 0.8]])))
@@ -2055,13 +2057,136 @@ def test_prune_irrelevant_parameters_with_constraints_exact_values(self) -> None
20552057
1.0,
20562058
)
20572059
],
2060+
bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
20582061
)
20592062

20602063
# Only dimension 0 should be pruned
20612064
expected_candidate = torch.tensor([[0.1, 1.0]])
20622065
self.assertTrue(torch.equal(pruned_candidates, expected_candidate))
20632066
self.assertTrue(torch.equal(pruned_values, torch.tensor([1.0])))
20642067

2068+
def test_prune_irrelevant_parameters_with_equality_constraints(self) -> None:
    """Pruning under the equality constraint x1 + x2 + x3 = 1.

    Runs ``_prune_irrelevant_parameters`` on a feasible candidate with a
    mocked ``evaluate`` and checks that every returned candidate still
    sums to one.
    """
    digest = SearchSpaceDigest(
        feature_names=["x1", "x2", "x3"],
        bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
    )
    one_third = 1.0 / 3
    opt_config = dataclasses.replace(
        self.torch_opt_config,
        pruning_target_point=torch.tensor([one_third, one_third, one_third]),
    )
    acq = Acquisition(
        surrogate=self.surrogate,
        search_space_digest=digest,
        torch_opt_config=opt_config,
        botorch_acqf_class=DummyAcquisitionFunction,
    )
    # Replace the acquisition function and its (re)instantiation with mocks.
    acqf_stub = Mock()
    acqf_stub._log = False
    acq.acqf = acqf_stub
    acq._instantiate_acquisition = Mock()

    # Canned acquisition values, consumed one per `evaluate` call; the labels
    # are the roles given by the test author.
    acq.evaluate = Mock(
        side_effect=[
            torch.tensor([0.0]),  # baseline af val
            torch.tensor([1.0]),  # dense af val
            torch.tensor([0.95, 0.90]),  # pruned af vals
            torch.tensor([0.93]),  # second round pruned af val
        ]
    )

    # The starting candidate satisfies the constraint: 0.5 + 0.3 + 0.2 = 1.
    start = torch.tensor([[0.5, 0.3, 0.2]])
    # x1 + x2 + x3 = 1 in BoTorch (indices, coefficients, rhs) format.
    sum_to_one = (torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 1.0)
    unit_cube = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])

    pruned_candidates, pruned_values = acq._prune_irrelevant_parameters(
        candidates=start,
        search_space_digest=digest,
        equality_constraints=[sum_to_one],
        bounds=unit_cube,
    )

    # Dimensionality is unchanged and each returned row is on the hyperplane.
    self.assertEqual(pruned_candidates.shape[-1], 3)
    for row in range(pruned_candidates.shape[0]):
        row_total = pruned_candidates[row].sum().item()
        self.assertAlmostEqual(row_total, 1.0, places=4)
2127+
def test_prune_irrelevant_parameters_fixed_features_pinned_in_projection(
    self,
) -> None:
    """Fixed features must survive the constraint projection untouched.

    With x1 + x2 + x3 = 1 active and x1 supplied via ``fixed_features``,
    the pruned result must keep x1 at 0.6 and still satisfy the constraint.
    """
    digest = SearchSpaceDigest(
        feature_names=["x1", "x2", "x3"],
        bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
    )
    one_third = 1.0 / 3
    opt_config = dataclasses.replace(
        self.torch_opt_config,
        pruning_target_point=torch.tensor([one_third, one_third, one_third]),
    )
    acq = Acquisition(
        surrogate=self.surrogate,
        search_space_digest=digest,
        torch_opt_config=opt_config,
        botorch_acqf_class=DummyAcquisitionFunction,
    )
    # Replace the acquisition function and its (re)instantiation with mocks.
    acqf_stub = Mock()
    acqf_stub._log = False
    acq.acqf = acqf_stub
    acq._instantiate_acquisition = Mock()

    # Canned acquisition values, consumed one per `evaluate` call; the labels
    # are the roles given by the test author.
    acq.evaluate = Mock(
        side_effect=[
            torch.tensor([0.0]),  # baseline af val
            torch.tensor([1.0]),  # dense af val
            torch.tensor([0.95, 0.90]),  # pruned af vals
            torch.tensor([0.93]),  # second-round pruned af val
        ]
    )

    # Feasible start (0.6 + 0.3 + 0.1 = 1) with x1 held at its current value.
    start = torch.tensor([[0.6, 0.3, 0.1]])
    pinned = {0: 0.6}
    # x1 + x2 + x3 = 1 in BoTorch (indices, coefficients, rhs) format.
    sum_to_one = (torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 1.0)
    unit_cube = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])

    pruned_candidates, _ = acq._prune_irrelevant_parameters(
        candidates=start,
        search_space_digest=digest,
        equality_constraints=[sum_to_one],
        bounds=unit_cube,
        fixed_features=pinned,
    )

    # Setting x2 or x3 to 1/3 breaks the sum; restoring it must not move x1.
    self.assertEqual(pruned_candidates.shape[-1], 3)
    first_row = pruned_candidates[0]
    self.assertAlmostEqual(pruned_candidates[0, 0].item(), 0.6, places=6)
    self.assertAlmostEqual(first_row.sum().item(), 1.0, places=4)
20652190
def test_prune_irrelevant_parameters_with_task_and_fidelity_features(self) -> None:
20662191
# Test pruning with both task and fidelity features that should be excluded
20672192
# from pruning

0 commit comments

Comments
 (0)