Skip to content

Commit 8569f4f

Browse files
ItsMrLinfacebook-github-bot
authored andcommitted
Allow continuous relaxation of inequality-constrained ordinal dims
Summary: `_setup_continuous_relaxation` in `optimize_mixed.py` blanket-excludes all constrained discrete dimensions from continuous relaxation, forcing them into discrete local search even when they have high cardinality. This is overly conservative for inequality constraints and causes severe performance degradation. **Problem:** When ordinal parameters (e.g., integers 0-50) participate in linear inequality constraints (e.g., `x1 + x2 + x3 <= 100`), they are kept as discrete dims regardless of cardinality. In mixed search spaces, this inflates the discrete combination count (e.g., 51^4 x 20 = 135M), forces `optimize_acqf_mixed_alternating`, and with default optimizer budgets (`raw_samples=1024`, `maxiter_init=100`, `maxiter_alternating=64`) across many sequential candidates, produces ~900K+ acquisition function evaluations -- taking hours instead of minutes. **Fix:** Only exclude dimensions participating in **equality** constraints from continuous relaxation. Inequality-constrained dims are now eligible for relaxation because `_optimize_acqf_batch` already handles rounding-induced constraint violations: it applies `post_processing_func` (rounding), checks feasibility, projects infeasible candidates back via SLSQP, and re-rounds (see `optimize.py` lines 528-574). **Why this is safe for inequality but not equality constraints:** - *Inequality*: The feasible region has interior. SLSQP easily finds a nearby feasible point after rounding, and re-rounding a near-integer point rarely re-violates. - *Equality*: The feasible set is a hyperplane. Rounding pushes points off the hyperplane, SLSQP projects back (to a non-integer point), re-rounding pushes off again -- the cycle may not converge. Differential Revision: D99304800
1 parent b16b28f commit 8569f4f

2 files changed

Lines changed: 67 additions & 20 deletions

File tree

botorch/optim/optimize_mixed.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,22 @@ def _setup_continuous_relaxation(
130130
``discrete_dims`` and ``post_processing_func`` is updated to round
131131
them to the nearest integer.
132132
133-
Dimensions that participate in constraints are NOT relaxed, as rounding
134-
after projection could violate those constraints.
133+
Dimensions that participate in equality constraints are NOT relaxed,
134+
as rounding can push points off the constraint hyperplane and the
135+
project-round cycle may not converge. Dimensions in inequality
136+
constraints ARE eligible for relaxation because the feasible region
137+
has interior and ``_optimize_acqf_batch`` already handles any
138+
rounding-induced violations via SLSQP projection and re-rounding.
135139
"""
136140

137-
# Identify dimensions involved in constraints
141+
# Only exclude dimensions involved in equality constraints.
142+
# Inequality-constrained dims can be safely relaxed: if rounding
143+
# violates the constraint, _optimize_acqf_batch projects back via
144+
# SLSQP and re-rounds (see optimize.py lines 528-574).
138145
constrained_dims: set[int] = set()
139-
for constraints in [inequality_constraints, equality_constraints]:
140-
if constraints is not None:
141-
for indices, _, _ in constraints:
142-
constrained_dims.update(indices.tolist())
146+
if equality_constraints is not None:
147+
for indices, _, _ in equality_constraints:
148+
constrained_dims.update(indices.tolist())
143149

144150
dims_to_relax, dims_to_keep = {}, {}
145151
for index, values in discrete_dims.items():

test/optim/test_optimize_mixed.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,55 +1612,96 @@ def org_post_proc_func(X: Tensor) -> Tensor:
16121612
self.assertAllClose(X[..., all_integer_dims], X[..., all_integer_dims].round())
16131613

16141614
def test_setup_continuous_relaxation_excludes_constrained_dims(self) -> None:
1615-
"""Test that _setup_continuous_relaxation keeps constrained discrete dims."""
1615+
"""Test that _setup_continuous_relaxation only excludes equality-constrained
1616+
dims from relaxation. Inequality-constrained dims are eligible for
1617+
relaxation since _optimize_acqf_batch handles rounding-induced violations
1618+
via SLSQP projection.
1619+
"""
16161620
for dtype in (torch.float, torch.double):
16171621
# Setup: 3 discrete dimensions
16181622
# - Dim 0: Low cardinality (2 values) - kept regardless
1619-
# - Dim 1: High cardinality (50 values), participates in constraint - kept
1623+
# - Dim 1: High cardinality (50 values), inequality constraint - relaxed
16201624
# - Dim 2: High cardinality (50 values), not constrained - relaxed
16211625
discrete_dims: dict[int, list[float]] = {
16221626
0: [0.0, 1.0], # Low cardinality - should be kept
1623-
1: list(range(50)), # High cardinality, constrained - should be kept
1627+
1: list(range(50)), # High card, ineq constrained - relaxed
16241628
2: list(range(50)), # High cardinality, not constrained - relaxed
16251629
}
16261630
max_discrete_values = 20
1627-
# Constraint on dim 1: x[1] >= 10
1631+
# Inequality constraint on dim 1: x[1] >= 10
16281632
inequality_constraints = [
16291633
(
16301634
torch.tensor([1], dtype=torch.long, device=self.device),
16311635
torch.tensor([1.0], dtype=dtype, device=self.device),
16321636
10.0,
16331637
)
16341638
]
1635-
# Execute: call _setup_continuous_relaxation
16361639
dims_kept, post_processing_func = _setup_continuous_relaxation(
16371640
discrete_dims=discrete_dims,
16381641
max_discrete_values=max_discrete_values,
16391642
post_processing_func=None,
16401643
inequality_constraints=inequality_constraints,
16411644
)
1642-
# Assert: dims 0 and 1 are kept (low cardinality and constrained)
1645+
# Dim 0 kept (low cardinality), dims 1 and 2 relaxed
16431646
self.assertIn(0, dims_kept)
1644-
self.assertIn(1, dims_kept)
1645-
# Assert: dim 2 is NOT in dims_kept (relaxed)
1647+
self.assertNotIn(1, dims_kept)
16461648
self.assertNotIn(2, dims_kept)
1647-
# Assert: post_processing_func is not None since dim 2 was relaxed
16481649
self.assertIsNotNone(post_processing_func)
1649-
# Assert: post_processing_func rounds dim 2 but not dims 0 or 1
1650+
# post_processing_func rounds dims 1 and 2 but not dim 0
16501651
X = torch.tensor(
1651-
[0.4, 25.3, 30.7], # dim 0, 1, 2 with non-integer values
1652+
[0.4, 25.3, 30.7],
16521653
dtype=dtype,
16531654
device=self.device,
16541655
)
16551656
X_processed = post_processing_func(X)
1656-
# Dim 0 and 1 should remain unchanged (not rounded by this func)
16571657
self.assertAllClose(
16581658
X_processed[0], torch.tensor(0.4, dtype=dtype, device=self.device)
16591659
)
1660+
self.assertAllClose(
1661+
X_processed[1], torch.tensor(25.0, dtype=dtype, device=self.device)
1662+
)
1663+
self.assertAllClose(
1664+
X_processed[2], torch.tensor(31.0, dtype=dtype, device=self.device)
1665+
)
1666+
1667+
def test_setup_continuous_relaxation_excludes_equality_constrained(self) -> None:
1668+
"""Test that equality-constrained dims are excluded from relaxation."""
1669+
for dtype in (torch.float, torch.double):
1670+
discrete_dims: dict[int, list[float]] = {
1671+
0: [0.0, 1.0], # Low cardinality - kept
1672+
1: list(range(50)), # High card, equality constrained - kept
1673+
2: list(range(50)), # High card, unconstrained - relaxed
1674+
}
1675+
max_discrete_values = 20
1676+
equality_constraints = [
1677+
(
1678+
torch.tensor([1], dtype=torch.long, device=self.device),
1679+
torch.tensor([1.0], dtype=dtype, device=self.device),
1680+
25.0,
1681+
)
1682+
]
1683+
dims_kept, post_processing_func = _setup_continuous_relaxation(
1684+
discrete_dims=discrete_dims,
1685+
max_discrete_values=max_discrete_values,
1686+
post_processing_func=None,
1687+
equality_constraints=equality_constraints,
1688+
)
1689+
# Dims 0 and 1 kept; dim 2 relaxed
1690+
self.assertIn(0, dims_kept)
1691+
self.assertIn(1, dims_kept)
1692+
self.assertNotIn(2, dims_kept)
1693+
self.assertIsNotNone(post_processing_func)
1694+
X = torch.tensor(
1695+
[0.4, 25.3, 30.7],
1696+
dtype=dtype,
1697+
device=self.device,
1698+
)
1699+
X_processed = post_processing_func(X)
1700+
# Dim 1 should NOT be rounded (equality-constrained, kept discrete)
16601701
self.assertAllClose(
16611702
X_processed[1], torch.tensor(25.3, dtype=dtype, device=self.device)
16621703
)
1663-
# Dim 2 should be rounded to nearest valid value
1704+
# Dim 2 should be rounded
16641705
self.assertAllClose(
16651706
X_processed[2], torch.tensor(31.0, dtype=dtype, device=self.device)
16661707
)

0 commit comments

Comments
 (0)