Skip to content

Commit ad68879

Browse files
David Erikssonmeta-codesync[bot]
authored andcommitted
Fix bug in optimize_acqf_mixed_alternating that may produce candidates that have invalid values (meta-pytorch#3212)
Summary: Pull Request resolved: meta-pytorch#3212 When using parameter constraints placed on discrete parameters, `optimize_acqf_mixed_alternating` may produce candidates that have invalid values due to a weird interaction between `project_to_feasible_space_via_slsqp` and the `post_processing_func`. For example, `project_to_feasible_space_via_slsqp` may end up moving a discrete parameter during the continuous step which causes it to later be rounded to an invalid value by the `post_processing_func`. The solution is to do two things: 1. Fix all discrete parameters during the continuous step so they aren't modified by `project_to_feasible_space_via_slsqp`. 2. Modify `_setup_continuous_relaxation` to not apply continuous relaxation to discrete parameters that are part of a parameter constraint (independently of their cardinality). Reviewed By: saitcakmak, ltiao Differential Revision: D94963154 fbshipit-source-id: 30b952f93a804736a90d287ececc6c3c57e8ba89
1 parent 208971f commit ad68879

5 files changed

Lines changed: 254 additions & 10 deletions

File tree

botorch/optim/optimize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
535535
bounds=opt_inputs.bounds,
536536
equality_constraints=equality_constraints,
537537
inequality_constraints=inequality_constraints,
538+
fixed_features=opt_inputs.fixed_features,
538539
)
539540
if opt_inputs.post_processing_func is not None:
540541
projected_candidates = opt_inputs.post_processing_func(projected_candidates)

botorch/optim/optimize_mixed.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,32 @@ def _setup_continuous_relaxation(
121121
discrete_dims: dict[int, list[float]],
122122
max_discrete_values: int,
123123
post_processing_func: Callable[[Tensor], Tensor] | None,
124+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
125+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
124126
) -> tuple[list[int], Callable[[Tensor], Tensor] | None]:
125127
r"""Update ``discrete_dims`` and ``post_processing_func`` to use
126128
continuous relaxation for discrete dimensions that have more than
127129
``max_discrete_values`` values. These dimensions are removed from
128130
``discrete_dims`` and ``post_processing_func`` is updated to round
129131
them to the nearest integer.
132+
133+
Dimensions that participate in constraints are NOT relaxed, as rounding
134+
after projection could violate those constraints.
130135
"""
131136

137+
# Identify dimensions involved in constraints
138+
constrained_dims: set[int] = set()
139+
for constraints in [inequality_constraints, equality_constraints]:
140+
if constraints is not None:
141+
for indices, _, _ in constraints:
142+
constrained_dims.update(indices.tolist())
143+
132144
dims_to_relax, dims_to_keep = {}, {}
133145
for index, values in discrete_dims.items():
134-
if len(values) > max_discrete_values:
146+
# Don't relax dimensions that participate in constraints
147+
if index in constrained_dims:
148+
dims_to_keep[index] = values
149+
elif len(values) > max_discrete_values:
135150
dims_to_relax[index] = values
136151
else:
137152
dims_to_keep[index] = values
@@ -839,8 +854,7 @@ def continuous_step(
839854
This function utilizes ``acq_function``, ``bounds``, ``options``,
840855
``fixed_features`` and constraints from ``opt_inputs``.
841856
``opt_inputs.return_best_only`` should be ``False``.
842-
discrete_dims: A dictionary mapping indices of discrete dimensions
843-
to a list of allowed values for that dimension.
857+
discrete_dims: A tensor of indices corresponding to discrete dimensions.
844858
cat_dims: A tensor of indices corresponding to categorical parameters.
845859
current_x: Starting point. A tensor of shape ``b x d``.
846860
@@ -1032,6 +1046,8 @@ def optimize_acqf_mixed_alternating(
10321046
options.get("max_discrete_values", MAX_DISCRETE_VALUES), int
10331047
),
10341048
post_processing_func=post_processing_func,
1049+
inequality_constraints=inequality_constraints,
1050+
equality_constraints=equality_constraints,
10351051
)
10361052

10371053
opt_inputs = OptimizeAcqfInputs(

botorch/optim/parameter_constraints.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy.typing as npt
1818
import torch
1919
from botorch.exceptions.errors import CandidateGenerationError, UnsupportedError
20-
from botorch.optim.utils import columnwise_clamp
20+
from botorch.optim.utils import columnwise_clamp, fix_features as apply_fix_features
2121
from scipy.optimize import Bounds, minimize
2222
from torch import Tensor
2323

@@ -724,6 +724,7 @@ def project_to_feasible_space_via_slsqp(
724724
bounds: Tensor,
725725
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
726726
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
727+
fixed_features: dict[int, float | Tensor] | None = None,
727728
) -> Tensor:
728729
"""Project X onto the feasible space by solving a quadratic program.
729730
@@ -743,15 +744,38 @@ def project_to_feasible_space_via_slsqp(
743744
``coefficients`` should be torch tensors. See the docstring of
744745
``make_scipy_linear_constraints`` for an example.
745746
equality_constraints: A list of tuples (indices, coefficients, rhs).
747+
fixed_features: A dictionary mapping feature indices to their fixed values.
748+
These dimensions will not be modified during projection. Values can be
749+
scalars (applied to all elements) or 1D tensors matching the batch size
750+
of X (for per-element fixed values).
746751
747752
Returns:
748-
A ``(batch_shape x) n x d``-dim tensor of projected values.
753+
A ``(batch_shape x) n x d``-dim tensor of projected values.
749754
"""
750755
if inequality_constraints is None and equality_constraints is None:
751756
return X
752-
bounds_scipy = make_scipy_bounds(
753-
X=X, lower_bounds=bounds[0], upper_bounds=bounds[1]
754-
)
757+
758+
d = X.shape[-1]
759+
lb = _arrayify(bounds[0].expand_as(X)).flatten()
760+
ub = _arrayify(bounds[1].expand_as(X)).flatten()
761+
762+
# If there are fixed features, constrain those dimensions by setting their
763+
# bounds to equal the current value. This prevents the optimizer from
764+
# modifying them during projection. We use fix_features to apply the fixed
765+
# values to X, then extract the values for setting the bounds.
766+
if fixed_features:
767+
X_fixed = apply_fix_features(X, fixed_features, replace_current_value=True)
768+
# Set bounds for fixed dimensions to match the fixed values
769+
X_fixed_flat = _arrayify(X_fixed).flatten()
770+
for idx in fixed_features.keys():
771+
# For each row in the flattened structure, set bounds at dimension idx
772+
n_rows = X.numel() // d
773+
for i in range(n_rows):
774+
flat_idx = i * d + idx
775+
lb[flat_idx] = X_fixed_flat[flat_idx]
776+
ub[flat_idx] = X_fixed_flat[flat_idx]
777+
778+
bounds_scipy = Bounds(lb=lb, ub=ub, keep_feasible=True)
755779
constraints = make_scipy_linear_constraints(
756780
shapeX=X.shape,
757781
inequality_constraints=inequality_constraints,
@@ -789,6 +813,6 @@ def grad_objective(x: np.ndarray):
789813
)
790814

791815
if not result.success:
792-
raise RuntimeError(f"Optimization failed: {result.message}")
816+
raise CandidateGenerationError(f"Optimization failed: {result.message}")
793817

794818
return torch.from_numpy(result.x).to(X).view(X.shape)

test/optim/test_optimize_mixed.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,8 @@ def test_optimize_acqf_mixed_continuous_relaxation(self) -> None:
14401440
discrete_dims=discrete_dims,
14411441
max_discrete_values=max_discrete_values or MAX_DISCRETE_VALUES,
14421442
post_processing_func=post_processing_func,
1443+
inequality_constraints=None,
1444+
equality_constraints=None,
14431445
)
14441446
discrete_call_args = wrapped_discrete.call_args.kwargs
14451447
expected_dims = [0, 4] if max_discrete_values is None else [0]
@@ -1516,3 +1518,113 @@ def org_post_proc_func(X: Tensor) -> Tensor:
15161518
# Check that generated points are rounded.
15171519
self.assertEqual(X.shape, torch.Size([4, train_X.shape[-1]]))
15181520
self.assertAllClose(X[..., all_integer_dims], X[..., all_integer_dims].round())
1521+
1522+
def test_setup_continuous_relaxation_excludes_constrained_dims(self) -> None:
1523+
"""Test that _setup_continuous_relaxation keeps constrained discrete dims."""
1524+
for dtype in (torch.float, torch.double):
1525+
# Setup: 3 discrete dimensions
1526+
# - Dim 0: Low cardinality (2 values) - kept regardless
1527+
# - Dim 1: High cardinality (50 values), participates in constraint - kept
1528+
# - Dim 2: High cardinality (50 values), not constrained - relaxed
1529+
discrete_dims: dict[int, list[float]] = {
1530+
0: [0.0, 1.0], # Low cardinality - should be kept
1531+
1: list(range(50)), # High cardinality, constrained - should be kept
1532+
2: list(range(50)), # High cardinality, not constrained - relaxed
1533+
}
1534+
max_discrete_values = 20
1535+
# Constraint on dim 1: x[1] >= 10
1536+
inequality_constraints = [
1537+
(
1538+
torch.tensor([1], dtype=torch.long, device=self.device),
1539+
torch.tensor([1.0], dtype=dtype, device=self.device),
1540+
10.0,
1541+
)
1542+
]
1543+
# Execute: call _setup_continuous_relaxation
1544+
dims_kept, post_processing_func = _setup_continuous_relaxation(
1545+
discrete_dims=discrete_dims,
1546+
max_discrete_values=max_discrete_values,
1547+
post_processing_func=None,
1548+
inequality_constraints=inequality_constraints,
1549+
)
1550+
# Assert: dims 0 and 1 are kept (low cardinality and constrained)
1551+
self.assertIn(0, dims_kept)
1552+
self.assertIn(1, dims_kept)
1553+
# Assert: dim 2 is NOT in dims_kept (relaxed)
1554+
self.assertNotIn(2, dims_kept)
1555+
# Assert: post_processing_func is not None since dim 2 was relaxed
1556+
self.assertIsNotNone(post_processing_func)
1557+
# Assert: post_processing_func rounds dim 2 but not dims 0 or 1
1558+
X = torch.tensor(
1559+
[0.4, 25.3, 30.7], # dim 0, 1, 2 with non-integer values
1560+
dtype=dtype,
1561+
device=self.device,
1562+
)
1563+
X_processed = post_processing_func(X)
1564+
# Dim 0 and 1 should remain unchanged (not rounded by this func)
1565+
self.assertAllClose(
1566+
X_processed[0], torch.tensor(0.4, dtype=dtype, device=self.device)
1567+
)
1568+
self.assertAllClose(
1569+
X_processed[1], torch.tensor(25.3, dtype=dtype, device=self.device)
1570+
)
1571+
# Dim 2 should be rounded to nearest valid value
1572+
self.assertAllClose(
1573+
X_processed[2], torch.tensor(31.0, dtype=dtype, device=self.device)
1574+
)
1575+
1576+
def test_optimize_acqf_mixed_alternating_constrained_discrete_dims(self) -> None:
1577+
"""Test full workflow produces valid discrete values with constrained dims.
1578+
1579+
Uses non-contiguous choices [8, 16, 24, 32, 40, 48] to exercise the failure
1580+
mode where rounding to nearest integer (e.g. 47) differs from rounding to
1581+
nearest valid choice (48).
1582+
"""
1583+
for dtype in (torch.float, torch.double):
1584+
# Setup: GP model with posterior mean as acquisition function
1585+
d = 2 # 1 continuous + 1 discrete dimension
1586+
train_X = torch.rand(5, d, dtype=dtype, device=self.device)
1587+
# Non-contiguous discrete values: multiples of 8 from 8 to 48
1588+
valid_choices = [8.0, 16.0, 24.0, 32.0, 40.0, 48.0]
1589+
train_X[:, 1] = torch.tensor(
1590+
[valid_choices[i % len(valid_choices)] for i in range(5)],
1591+
dtype=dtype,
1592+
device=self.device,
1593+
)
1594+
train_Y = train_X.sum(dim=-1, keepdim=True)
1595+
model = SingleTaskGP(train_X, train_Y)
1596+
acqf = PosteriorMean(model=model)
1597+
# Define bounds: [0, 1] for continuous, [8, 48] for discrete
1598+
bounds = torch.tensor(
1599+
[[0.0, 8.0], [1.0, 48.0]], dtype=dtype, device=self.device
1600+
)
1601+
# Non-contiguous discrete dimension (6 values)
1602+
discrete_dims: dict[int, list[float]] = {1: valid_choices}
1603+
# Constraint: x[1] >= 20 (discrete dim must be at least 20)
1604+
inequality_constraints = [
1605+
(
1606+
torch.tensor([1], dtype=torch.long, device=self.device),
1607+
torch.tensor([1.0], dtype=dtype, device=self.device),
1608+
20.0,
1609+
)
1610+
]
1611+
X, _ = optimize_acqf_mixed_alternating(
1612+
acq_function=acqf,
1613+
bounds=bounds,
1614+
discrete_dims=discrete_dims,
1615+
q=1,
1616+
num_restarts=2,
1617+
raw_samples=32,
1618+
inequality_constraints=inequality_constraints,
1619+
options={"max_discrete_values": 2, "maxiter_alternating": 4},
1620+
)
1621+
# Assert: discrete value is within the valid set (not just rounded int)
1622+
valid_choices_tensor = torch.tensor(
1623+
valid_choices, dtype=dtype, device=self.device
1624+
)
1625+
self.assertTrue(
1626+
torch.all(torch.isin(X[..., 1], valid_choices_tensor)),
1627+
f"Returned candidate {X[..., 1].item()} not in {valid_choices}",
1628+
)
1629+
# Assert: constraint is satisfied (x[1] >= 20)
1630+
self.assertTrue(torch.all(X[..., 1] >= 20.0 - 1e-6))

test/optim/test_parameter_constraints.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,9 @@ def test_project_to_feasible_space_via_slsqp_exception(self, _: mock.Mock) -> No
10031003
bounds = torch.tensor([[0.0, 0.0], [2.0, 2.0]], device=self.device)
10041004

10051005
X = torch.tensor([[1.0, 1.0]], device=self.device)
1006-
with self.assertRaisesRegex(RuntimeError, "Optimization failed: failed reason"):
1006+
with self.assertRaisesRegex(
1007+
CandidateGenerationError, "Optimization failed: failed reason"
1008+
):
10071009
project_to_feasible_space_via_slsqp(
10081010
X=X,
10091011
bounds=bounds,
@@ -1015,3 +1017,92 @@ def test_project_to_feasible_space_via_slsqp_exception(self, _: mock.Mock) -> No
10151017
)
10161018
],
10171019
)
1020+
1021+
def test_project_to_feasible_space_with_scalar_fixed_features(self) -> None:
1022+
"""Test projection preserves scalar fixed_features values."""
1023+
for dtype in (torch.float, torch.double):
1024+
tol = get_constraint_tolerance(dtype=dtype)
1025+
# Setup: 3D search space, bounds [[0, 0, 0], [2, 2, 2]]
1026+
bounds = torch.tensor(
1027+
[[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], dtype=dtype, device=self.device
1028+
)
1029+
# Constraint: x[0] + x[1] >= 1.5
1030+
inequality_constraints = [
1031+
(
1032+
torch.tensor([0, 1], dtype=torch.long, device=self.device),
1033+
torch.tensor([1.0, 1.0], dtype=dtype, device=self.device),
1034+
1.5,
1035+
)
1036+
]
1037+
# Infeasible point X = [[0.3, 0.3, 1.0]] (0.6 < 1.5)
1038+
X = torch.tensor([[0.3, 0.3, 1.0]], dtype=dtype, device=self.device)
1039+
# fixed_features = {0: 0.3} (scalar)
1040+
fixed_features: dict[int, float | torch.Tensor] = {0: 0.3}
1041+
# Execute: project to feasible space with fixed_features
1042+
projected = project_to_feasible_space_via_slsqp(
1043+
X=X,
1044+
bounds=bounds,
1045+
inequality_constraints=inequality_constraints,
1046+
fixed_features=fixed_features,
1047+
)
1048+
# Assert: x[0] remains at 0.3 (fixed)
1049+
self.assertAllClose(
1050+
projected[0, 0], torch.tensor(0.3, dtype=dtype, device=self.device)
1051+
)
1052+
# Assert: constraint is satisfied (x[0] + x[1] >= 1.5)
1053+
self.assertGreaterEqual(
1054+
(projected[0, 0] + projected[0, 1]).item(), 1.5 - tol
1055+
)
1056+
# Assert: bounds are respected
1057+
self.assertTrue(torch.all(projected >= bounds[0] - tol))
1058+
self.assertTrue(torch.all(projected <= bounds[1] + tol))
1059+
1060+
def test_project_to_feasible_space_with_batched_fixed_features(self) -> None:
1061+
"""Test projection preserves batched (tensor) fixed_features values."""
1062+
for dtype in (torch.float, torch.double):
1063+
tol = get_constraint_tolerance(dtype=dtype)
1064+
# Setup: 3D search space, bounds [[0, 0, 0], [2, 2, 2]]
1065+
bounds = torch.tensor(
1066+
[[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], dtype=dtype, device=self.device
1067+
)
1068+
# Constraint: x[0] + x[1] >= 1.5
1069+
inequality_constraints = [
1070+
(
1071+
torch.tensor([0, 1], dtype=torch.long, device=self.device),
1072+
torch.tensor([1.0, 1.0], dtype=dtype, device=self.device),
1073+
1.5,
1074+
)
1075+
]
1076+
# Batch of 3 infeasible points (all violate x[0] + x[1] >= 1.5)
1077+
# X must be 3D: batch x q x d when using tensor fixed_features
1078+
X = torch.tensor(
1079+
[
1080+
[[0.2, 0.3, 1.0]], # batch 0, q=1
1081+
[[0.4, 0.5, 0.5]], # batch 1, q=1
1082+
[[0.1, 0.2, 1.5]], # batch 2, q=1
1083+
],
1084+
dtype=dtype,
1085+
device=self.device,
1086+
) # Shape: [3, 1, 3]
1087+
# fixed_features = {0: tensor([0.2, 0.4, 0.1])} (different per batch)
1088+
fixed_values = torch.tensor(
1089+
[0.2, 0.4, 0.1], dtype=dtype, device=self.device
1090+
)
1091+
fixed_features: dict[int, float | torch.Tensor] = {0: fixed_values}
1092+
# Execute: project to feasible space with batched fixed_features
1093+
projected = project_to_feasible_space_via_slsqp(
1094+
X=X,
1095+
bounds=bounds,
1096+
inequality_constraints=inequality_constraints,
1097+
fixed_features=fixed_features,
1098+
)
1099+
# Assert: each batch element preserves its respective fixed value for x[0]
1100+
self.assertAllClose(projected[:, 0, 0], fixed_values)
1101+
# Assert: constraint is satisfied for each batch element
1102+
for i in range(3):
1103+
self.assertGreaterEqual(
1104+
(projected[i, 0, 0] + projected[i, 0, 1]).item(), 1.5 - tol
1105+
)
1106+
# Assert: bounds are respected
1107+
self.assertTrue(torch.all(projected >= bounds[0] - tol))
1108+
self.assertTrue(torch.all(projected <= bounds[1] + tol))

0 commit comments

Comments
 (0)