Skip to content

Commit 58cff45

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Wire equality constraints through generator and acquisition to BoTorch (facebook#5177)
Summary: Pull Request resolved: facebook#5177 Thread equality constraints from TorchOptConfig through to BoTorch's optimize_acqf and related optimizers. This is the key diff that connects Ax's equality constraint representation to BoTorch's SLSQP-based optimizer. - Add `_to_equality_constraints` in `torch/utils.py` — converts (A, b) tensor format to BoTorch's `(indices, coefficients, rhs)` format. No sign negation needed (equality is symmetric). - Update `BoTorchGenerator.gen()` to pass equality constraints to `acqf.optimize()`. - Add `equality_constraints` parameter to `Acquisition.optimize()` and forward to `optimize_acqf`, `optimize_acqf_mixed`, `optimize_acqf_mixed_alternating`. - Raise `ValueError` for discrete optimizers and NSGA-II (unsupported). - Update `validate_candidates` to check equality constraints. - Update `_prune_irrelevant_parameters` and `_remove_infeasible_candidates` to handle equality constraints. - Update `Surrogate.best_point` to pass equality constraints. Reviewed By: bletham Differential Revision: D100256482 fbshipit-source-id: 7739880ce77a3791587044a1d5c046c59c15e43d
1 parent b668431 commit 58cff45

8 files changed

Lines changed: 348 additions & 31 deletions

File tree

ax/generators/torch/botorch_modular/acquisition.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ def optimize(
596596
n: int,
597597
search_space_digest: SearchSpaceDigest,
598598
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
599+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
599600
fixed_features: dict[int, float] | None = None,
600601
rounding_func: Callable[[Tensor], Tensor] | None = None,
601602
optimizer_options: dict[str, Any] | None = None,
@@ -612,6 +613,9 @@ def optimize(
612613
inequality_constraints: A list of tuples (indices, coefficients, rhs),
613614
with each tuple encoding an inequality constraint of the form
614615
``sum_i (X[indices[i]] * coefficients[i]) >= rhs``.
616+
equality_constraints: A list of tuples (indices, coefficients, rhs),
617+
with each tuple encoding an equality constraint of the form
618+
``sum_i (X[indices[i]] * coefficients[i]) = rhs``.
615619
fixed_features: A map `{feature_index: value}` for features that
616620
should be fixed to a particular value during generation.
617621
rounding_func: A function that post-processes an optimization
@@ -664,8 +668,8 @@ def optimize(
664668
# Ax expects `optimize_acqf` to return tensors of a certain shape.
665669
if optimizer_options is not None:
666670
forbidden_optimizer_options = [
667-
"equality_constraints",
668-
"inequality_constraints", # These should be constructed by Ax
671+
"equality_constraints", # Constructed by Ax
672+
"inequality_constraints", # Constructed by Ax
669673
"batch_initial_conditions",
670674
"return_best_only",
671675
"return_full_tree",
@@ -716,6 +720,7 @@ def optimize(
716720
bounds=bounds,
717721
q=n,
718722
inequality_constraints=inequality_constraints,
723+
equality_constraints=equality_constraints,
719724
fixed_features=fixed_features,
720725
post_processing_func=rounding_func,
721726
acq_function_sequence=self.acq_function_sequence,
@@ -727,6 +732,11 @@ def optimize(
727732
"optimize_acqf_discrete",
728733
"optimize_acqf_discrete_local_search",
729734
):
735+
if equality_constraints:
736+
raise ValueError(
737+
"Equality constraints are not supported with discrete "
738+
f"optimizer '{optimizer}'."
739+
)
730740
X_observed = self.X_observed
731741
if self.X_pending is not None:
732742
if X_observed is None:
@@ -805,6 +815,7 @@ def optimize(
805815
discrete_choices=discrete_choices
806816
),
807817
inequality_constraints=inequality_constraints,
818+
equality_constraints=equality_constraints,
808819
post_processing_func=rounding_func,
809820
**optimizer_options_with_defaults,
810821
)
@@ -832,9 +843,15 @@ def optimize(
832843
post_processing_func=rounding_func,
833844
fixed_features=fixed_features,
834845
inequality_constraints=inequality_constraints,
846+
equality_constraints=equality_constraints,
835847
**optimizer_options_with_defaults,
836848
)
837849
elif optimizer == "optimize_with_nsgaii":
850+
if equality_constraints:
851+
raise ValueError(
852+
"Equality constraints are not supported with "
853+
"optimizer 'optimize_with_nsgaii'."
854+
)
838855
if optimize_with_nsgaii is not None:
839856
acqf = assert_is_instance(
840857
self.acqf, MultiOutputAcquisitionFunctionWrapper
@@ -873,6 +890,7 @@ def optimize(
873890
candidates=candidates,
874891
search_space_digest=search_space_digest,
875892
inequality_constraints=inequality_constraints,
893+
equality_constraints=equality_constraints,
876894
fixed_features=fixed_features,
877895
)
878896
# Validate candidates before returning
@@ -883,6 +901,7 @@ def optimize(
883901
inequality_constraints=inequality_constraints,
884902
feature_names=search_space_digest.feature_names,
885903
task_features=search_space_digest.task_features,
904+
equality_constraints=equality_constraints,
886905
)
887906

888907
n_candidates = candidates.shape[0]
@@ -986,6 +1005,7 @@ def _prune_irrelevant_parameters(
9861005
candidates: Tensor,
9871006
search_space_digest: SearchSpaceDigest,
9881007
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
1008+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
9891009
fixed_features: dict[int, float] | None = None,
9901010
) -> tuple[Tensor, Tensor]:
9911011
r"""Prune irrelevant parameters from the candidates using BONSAI.
@@ -1092,6 +1112,7 @@ def _prune_irrelevant_parameters(
10921112
candidates=pruned_candidates,
10931113
indices=indices,
10941114
inequality_constraints=inequality_constraints,
1115+
equality_constraints=equality_constraints,
10951116
)
10961117
if pruned_candidates.shape[0] == 0:
10971118
# no feasible points, continue to
@@ -1205,6 +1226,7 @@ def _remove_infeasible_candidates(
12051226
candidates: Tensor,
12061227
indices: Tensor,
12071228
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
1229+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
12081230
) -> tuple[Tensor, Tensor]:
12091231
r"""Filter out infeasible candidates, based on the parameter constraints.
12101232
@@ -1214,24 +1236,19 @@ def _remove_infeasible_candidates(
12141236
in [0, d-1).
12151237
inequality_constraints: A list of tuples (indices, coefficients, rhs),
12161238
with each tuple encoding an inequality constraint of the form
1217-
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
1218-
`coefficients` should be torch tensors. See the docstring of
1219-
`make_scipy_linear_constraints` for an example. When q=1, or when
1220-
applying the same constraint to each candidate in the batch
1221-
(intra-point constraint), `indices` should be a 1-d tensor.
1222-
For inter-point constraints, in which the constraint is applied to the
1223-
whole batch of candidates, `indices` must be a 2-d tensor, where
1224-
in each row `indices[i] =(k_i, l_i)` the first index `k_i` corresponds
1225-
to the `k_i`-th element of the `q`-batch and the second index `l_i`
1226-
corresponds to the `l_i`-th feature of that element.
1239+
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
1240+
equality_constraints: A list of tuples (indices, coefficients, rhs),
1241+
with each tuple encoding an equality constraint of the form
1242+
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
12271243
12281244
Returns:
1229-
A two-element tuple containing the filter candidates and indices.
1245+
A two-element tuple containing the filtered candidates and indices.
12301246
"""
1231-
if inequality_constraints is not None:
1247+
if inequality_constraints is not None or equality_constraints is not None:
12321248
is_feasible = evaluate_feasibility(
12331249
X=candidates,
12341250
inequality_constraints=inequality_constraints,
1251+
equality_constraints=equality_constraints,
12351252
)
12361253
candidates = candidates[is_feasible]
12371254
indices = indices[is_feasible]

ax/generators/torch/botorch_modular/generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ModelConfig,
2727
)
2828
from ax.generators.torch.utils import (
29+
_to_equality_constraints,
2930
_to_inequality_constraints,
3031
get_feature_importances_from_botorch_model,
3132
get_rounding_func,
@@ -401,6 +402,9 @@ def gen(
401402
inequality_constraints=_to_inequality_constraints(
402403
linear_constraints=torch_opt_config.linear_constraints
403404
),
405+
equality_constraints=_to_equality_constraints(
406+
equality_constraints=torch_opt_config.equality_constraints
407+
),
404408
fixed_features=torch_opt_config.fixed_features,
405409
rounding_func=botorch_rounding_func,
406410
optimizer_options=assert_is_instance(

ax/generators/torch/botorch_modular/surrogate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
use_model_list,
4343
)
4444
from ax.generators.torch.utils import (
45+
_to_equality_constraints,
4546
_to_inequality_constraints,
4647
pick_best_out_of_sample_point_acqf_class,
4748
predict_from_model,
@@ -1059,6 +1060,9 @@ def best_out_of_sample_point(
10591060
inequality_constraints=_to_inequality_constraints(
10601061
linear_constraints=torch_opt_config.linear_constraints
10611062
),
1063+
equality_constraints=_to_equality_constraints(
1064+
equality_constraints=torch_opt_config.equality_constraints
1065+
),
10621066
fixed_features=torch_opt_config.fixed_features,
10631067
)
10641068
return candidates[0], acqf_value

ax/generators/torch/botorch_modular/utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,8 +892,9 @@ def validate_candidates(
892892
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None,
893893
feature_names: list[str] | None = None,
894894
task_features: list[int] | None = None,
895+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
895896
) -> None:
896-
"""Validate candidates satisfy bounds, discrete, and inequality constraints.
897+
"""Validate candidates satisfy bounds, discrete, and linear constraints.
897898
898899
Args:
899900
candidates: A `n x d`-dim Tensor of candidates to validate.
@@ -906,6 +907,9 @@ def validate_candidates(
906907
task_features: Optional list of task feature indices to skip discrete value
907908
validation for. Task features can be fixed to new task values via
908909
fixed_features that are not in the search space's discrete_choices.
910+
equality_constraints: A list of tuples (indices, coefficients, rhs),
911+
representing equality constraints of the form
912+
`sum_i (X[indices[i]] * coefficients[i]) = rhs`.
909913
910914
Raises:
911915
CandidateGenerationError: If any candidate violates constraints.
@@ -959,3 +963,17 @@ def validate_candidates(
959963
f"Infeasible candidate indices: {infeasible_indices}. "
960964
f"Number of constraints: {len(inequality_constraints)}."
961965
)
966+
967+
# 4. Equality constraint validation
968+
if equality_constraints:
969+
is_feasible = evaluate_feasibility(
970+
X=candidates.unsqueeze(-2), # Add q dimension
971+
equality_constraints=equality_constraints,
972+
)
973+
if not is_feasible.all():
974+
infeasible_indices = torch.where(~is_feasible)[0].tolist()
975+
raise CandidateGenerationError(
976+
f"Candidates violate equality constraints. "
977+
f"Infeasible candidate indices: {infeasible_indices}. "
978+
f"Number of constraints: {len(equality_constraints)}."
979+
)

0 commit comments

Comments
 (0)