5757 AnalyticExpectedUtilityOfBestOption ,
5858 qExpectedUtilityOfBestOption ,
5959)
60- from botorch .exceptions .errors import BotorchError , InputDataError
60+ from botorch .exceptions .errors import (
61+ BotorchError ,
62+ CandidateGenerationError ,
63+ InputDataError ,
64+ )
6165from botorch .generation .sampling import SamplingStrategy
6266from botorch .models .model import Model
6367from botorch .optim .optimize import (
7276 optimize_acqf_mixed_alternating ,
7377 should_use_mixed_alternating_optimizer ,
7478)
75- from botorch .optim .parameter_constraints import evaluate_feasibility
79+ from botorch .optim .parameter_constraints import (
80+ evaluate_feasibility ,
81+ project_to_feasible_space_via_slsqp ,
82+ )
7683from botorch .utils .constraints import get_outcome_constraint_transforms
7784from pyre_extensions import assert_is_instance , none_throws
7885from torch import Tensor
@@ -892,6 +899,7 @@ def optimize(
892899 inequality_constraints = inequality_constraints ,
893900 equality_constraints = equality_constraints ,
894901 fixed_features = fixed_features ,
902+ bounds = bounds ,
895903 )
896904 # Validate candidates before returning
897905 validate_candidates (
@@ -1007,6 +1015,7 @@ def _prune_irrelevant_parameters(
10071015 inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
10081016 equality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
10091017 fixed_features : dict [int , float ] | None = None ,
1018+ bounds : Tensor | None = None ,
10101019 ) -> tuple [Tensor , Tensor ]:
10111020 r"""Prune irrelevant parameters from the candidates using BONSAI.
10121021
@@ -1042,6 +1051,11 @@ def _prune_irrelevant_parameters(
10421051 corresponds to the `l_i`-th feature of that element.
10431052 fixed_features: A map `{feature_index: value}` for features that
10441053 should be fixed to a particular value during generation.
1054+ bounds: A `2 x d`-dim tensor of lower and upper parameter bounds.
1055+ Required when `inequality_constraints` or `equality_constraints`
1056+ are provided: pruned candidates are projected onto the feasible
1057+ set via SLSQP, and the projection needs the parameter bounds to
1058+ define the feasible region. Unused when no constraints are set.
10451059
10461060 Returns:
10471061 A two-element tuple containing an `q x d`-dim tensor of generated
@@ -1085,12 +1099,14 @@ def _prune_irrelevant_parameters(
10851099 # dense AF val
10861100 final_af_val = dense_af_val
10871101 # If the current incremental AF value is zero, then we skip pruning
1102+ has_constraints = bool (inequality_constraints or equality_constraints )
10881103 if dense_incremental_af_val > 0.0 :
10891104 remaining_indices = set (range (candidates .shape [- 1 ])) - excluded_indices
10901105 # remove features that are already set to target_point
10911106 remaining_indices -= set (
10921107 (candidates [i ] == target_point ).nonzero ().view (- 1 ).tolist ()
10931108 )
1109+ initial_remaining = set (remaining_indices )
10941110 # len(remaining_indices) - 1 is used here so that we do not prune
10951111 # every dimension
10961112 for _ in range (len (remaining_indices ) - 1 ):
@@ -1107,13 +1123,23 @@ def _prune_irrelevant_parameters(
11071123 indices = indices ,
11081124 targets = target_point [indices ],
11091125 )
1110- # remove candidates that violate constraints after pruning
1111- pruned_candidates , indices = _remove_infeasible_candidates (
1112- candidates = pruned_candidates ,
1113- indices = indices ,
1114- inequality_constraints = inequality_constraints ,
1115- equality_constraints = equality_constraints ,
1116- )
1126+ # Project pruned candidates onto the feasible set
1127+ # (pinning the pruned dim and previously pruned dims),
1128+ # then filter any that remain infeasible.
1129+ if has_constraints :
1130+ previously_pruned = initial_remaining - remaining_indices
1131+ pruned_candidates , indices = (
1132+ _project_and_filter_pruned_candidates (
1133+ candidates = pruned_candidates ,
1134+ indices = indices ,
1135+ target_point = target_point ,
1136+ pruned_dims = previously_pruned ,
1137+ bounds = none_throws (bounds ),
1138+ inequality_constraints = inequality_constraints ,
1139+ equality_constraints = equality_constraints ,
1140+ fixed_features = fixed_features ,
1141+ )
1142+ )
11171143 if pruned_candidates .shape [0 ] == 0 :
11181144 # no feasible points, continue to
11191145 # next candidate
@@ -1253,3 +1279,129 @@ def _remove_infeasible_candidates(
12531279 candidates = candidates [is_feasible ]
12541280 indices = indices [is_feasible ]
12551281 return candidates , indices
1282+
1283+
1284+ def _project_and_filter_pruned_candidates (
1285+ candidates : Tensor ,
1286+ indices : Tensor ,
1287+ target_point : Tensor ,
1288+ pruned_dims : set [int ],
1289+ bounds : Tensor ,
1290+ inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
1291+ equality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
1292+ fixed_features : dict [int , float ] | None = None ,
1293+ ) -> tuple [Tensor , Tensor ]:
1294+ r"""Project pruned candidates onto the feasible set, then filter infeasible.
1295+
1296+ Helper for ``Acquisition._prune_irrelevant_parameters`` (BONSAI). It is
1297+ only meaningful in the context of that greedy-pruning loop and is not
1298+ intended for standalone use.
1299+
1300+ Background: BONSAI pruning evaluates a candidate dimension-by-dimension
1301+ by setting one dimension at a time to its target-point value. Each row
1302+ of ``candidates`` is one such trial -- the dense candidate with the
1303+ dimension at ``indices[i]`` swapped to ``target_point[indices[i]]``.
1304+ Under linear constraints, swapping a single dimension to the target
1305+ typically violates the constraints; rather than discarding the trial
1306+ (the prior behavior), we adjust the *other* free dimensions to recover
1307+ feasibility while keeping the swapped dimension and all previously
1308+ pruned dimensions pinned. Trials whose pins make the constraint system
1309+ infeasible -- and the rare case where projection succeeds but the
1310+ result still violates constraints -- are filtered out via the mask
1311+ returned to the caller.
1312+
1313+ Args:
1314+ candidates: A ``b x 1 x d``-dim tensor of pruned candidates (one row
1315+ per single-dimension prune attempt for the current BONSAI
1316+ iteration).
1317+ indices: A ``b``-dim tensor indicating which dimension was pruned
1318+ in each batch element.
1319+ target_point: A ``d``-dim tensor of target values for pruning.
1320+ pruned_dims: Set of dimension indices already pruned in prior
1321+ greedy iterations (to be kept pinned during projection).
1322+ bounds: A ``2 x d``-dim tensor of lower and upper bounds.
1323+ inequality_constraints: Inequality constraints in BoTorch format.
1324+ equality_constraints: Equality constraints in BoTorch format.
1325+ fixed_features: A map ``{feature_index: value}`` from the caller.
1326+ These dimensions are excluded from pruning at the outer loop and
1327+ must also be pinned during projection so SLSQP cannot adjust
1328+ them while satisfying the constraints. Without this, fixed
1329+ features could be silently altered.
1330+
1331+ Returns:
1332+ A two-element tuple of filtered ``(candidates, indices)``.
1333+ """
1334+ # Pre-compute which dims participate in any constraint, and check whether
1335+ # any constraint is inter-point (2D index tensor). Inter-point constraints
1336+ # apply across the q-batch, but each row here is a single-candidate prune
1337+ # attempt -- ``project_to_feasible_space_via_slsqp`` cannot evaluate
1338+ # inter-point constraints on a 1 x d input. Fall back to the original
1339+ # filter-only behavior in that case.
1340+ constrained_dims : set [int ] = set ()
1341+ has_interpoint_constraint = False
1342+ for constraints in (inequality_constraints , equality_constraints ):
1343+ if constraints is not None :
1344+ for c_indices , _ , _ in constraints :
1345+ if c_indices .dim () == 1 :
1346+ constrained_dims .update (c_indices .tolist ())
1347+ else :
1348+ constrained_dims .update (c_indices [:, - 1 ].tolist ())
1349+ has_interpoint_constraint = True
1350+ if has_interpoint_constraint :
1351+ return _remove_infeasible_candidates (
1352+ candidates = candidates ,
1353+ indices = indices ,
1354+ inequality_constraints = inequality_constraints ,
1355+ equality_constraints = equality_constraints ,
1356+ )
1357+
1358+ # Build fixed_features for previously pruned dims and the caller's
1359+ # fixed_features (both shared across all candidates in this iteration).
1360+ prev_fixed : dict [int , float ] = {k : target_point [k ].item () for k in pruned_dims }
1361+ if fixed_features is not None :
1362+ prev_fixed .update (fixed_features )
1363+
1364+ feasible_mask = torch .ones (candidates .shape [0 ], dtype = torch .bool )
1365+ result = candidates .clone ()
1366+
1367+ for i in range (candidates .shape [0 ]):
1368+ j : int = int (indices [i ].item ())
1369+ # If the pruned dim doesn't participate in any constraint,
1370+ # pruning it can't violate anything — skip projection.
1371+ if j not in constrained_dims :
1372+ continue
1373+ # Pin the currently pruned dim, all previously pruned dims, and the
1374+ # caller's fixed features.
1375+ fixed : dict [int , float | Tensor ] = {
1376+ j : float (target_point [j ].item ()),
1377+ ** prev_fixed ,
1378+ }
1379+ try :
1380+ projected = project_to_feasible_space_via_slsqp (
1381+ X = candidates [i ], # 1 x d
1382+ bounds = bounds ,
1383+ inequality_constraints = inequality_constraints ,
1384+ equality_constraints = equality_constraints ,
1385+ fixed_features = fixed ,
1386+ )
1387+ result [i ] = projected
1388+ except CandidateGenerationError :
1389+ # Pin makes the system infeasible — mark for removal.
1390+ # The post-projection feasibility check below is the safety net
1391+ # for any candidates that project but still violate constraints.
1392+ feasible_mask [i ] = False
1393+
1394+ # Final safety-net feasibility check after projection.
1395+ if feasible_mask .any ():
1396+ is_feasible = evaluate_feasibility (
1397+ X = result [feasible_mask ],
1398+ inequality_constraints = inequality_constraints ,
1399+ equality_constraints = equality_constraints ,
1400+ )
1401+ # Map back to the full mask.
1402+ feasible_subset_indices = feasible_mask .nonzero (as_tuple = True )[0 ]
1403+ for idx , feas in zip (feasible_subset_indices , is_feasible ):
1404+ if not feas :
1405+ feasible_mask [idx ] = False
1406+
1407+ return result [feasible_mask ], indices [feasible_mask ]
0 commit comments