Skip to content

Commit 390b99c

Browse files
David Erikssonfacebook-github-bot
authored andcommitted
Validate candidates generated by Acquisition.optimize (facebook#4972)
Summary: This adds validation to `Acquisition.optimize` to make sure the generated candidates satisfy the bounds, values specified by choice parameters, and potential parameter constraints. This will allow us to catch potential downstream errors in the acquisition function optimization in BoTorch. Reviewed By: saitcakmak Differential Revision: D95078354
1 parent 4e2a2d4 commit 390b99c

File tree

3 files changed

+354
-76
lines changed

3 files changed

+354
-76
lines changed

ax/generators/torch/botorch_modular/acquisition.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ax.generators.torch.botorch_modular.utils import (
3232
_fix_map_key_to_target,
3333
_objective_threshold_to_outcome_constraints,
34+
validate_candidates,
3435
)
3536
from ax.generators.torch.botorch_moo_utils import infer_objective_thresholds
3637
from ax.generators.torch.utils import (
@@ -742,6 +743,17 @@ def optimize(
742743
X_avoid=X_observed,
743744
**optimizer_options_with_defaults,
744745
)
746+
# Validate candidates before returning
747+
validate_candidates(
748+
candidates=candidates,
749+
bounds=bounds,
750+
discrete_choices=ssd.discrete_choices
751+
if ssd.discrete_choices
752+
else None,
753+
inequality_constraints=inequality_constraints,
754+
feature_names=ssd.feature_names,
755+
task_features=ssd.task_features,
756+
)
745757
n_candidates = candidates.shape[0]
746758
return (
747759
candidates,
@@ -858,6 +870,16 @@ def optimize(
858870
inequality_constraints=inequality_constraints,
859871
fixed_features=fixed_features,
860872
)
873+
# Validate candidates before returning
874+
validate_candidates(
875+
candidates=candidates,
876+
bounds=bounds,
877+
discrete_choices=discrete_choices if discrete_choices else None,
878+
inequality_constraints=inequality_constraints,
879+
feature_names=search_space_digest.feature_names,
880+
task_features=search_space_digest.task_features,
881+
)
882+
861883
n_candidates = candidates.shape[0]
862884
return candidates, acqf_values, arm_weights[:n_candidates] * n_candidates / n
863885

ax/generators/torch/botorch_modular/utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
qLogNoisyExpectedHypervolumeImprovement,
3535
)
3636
from botorch.acquisition.multi_objective.parego import qLogNParEGO
37+
from botorch.exceptions.errors import BotorchError, CandidateGenerationError
3738
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
3839
from botorch.models import PairwiseLaplaceMarginalLogLikelihood
3940
from botorch.models.fully_bayesian import (
@@ -52,6 +53,11 @@
5253
from botorch.models.pairwise_gp import PairwiseGP
5354
from botorch.models.transforms.input import InputTransform, Normalize
5455
from botorch.models.transforms.outcome import OutcomeTransform
56+
from botorch.optim.parameter_constraints import (
57+
evaluate_feasibility,
58+
get_constraint_tolerance,
59+
)
60+
from botorch.optim.utils import columnwise_clamp
5561
from botorch.utils.constraints import get_outcome_constraint_transforms
5662
from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset
5763
from botorch.utils.dispatcher import Dispatcher
@@ -874,3 +880,90 @@ def get_all_task_values_from_ssd(search_space_digest: SearchSpaceDigest) -> list
874880
task_feature = search_space_digest.task_features[0]
875881
task_bounds = search_space_digest.bounds[task_feature]
876882
return list(range(int(task_bounds[0]), int(task_bounds[1] + 1)))
883+
884+
885+
def _format_discrete_value(val: float, allowed_values: Sequence[float]) -> str:
886+
"""Format a discrete value for display alongside allowed values.
887+
888+
If all allowed values are integers, formats val as int (via rounding).
889+
Otherwise formats as float with 4 decimal places.
890+
"""
891+
if all(float(v).is_integer() for v in allowed_values):
892+
return str(int(round(val)))
893+
return f"{val:.4f}"
894+
895+
896+
def validate_candidates(
897+
candidates: Tensor,
898+
bounds: Tensor,
899+
discrete_choices: Mapping[int, Sequence[float]] | None,
900+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None,
901+
feature_names: list[str] | None = None,
902+
task_features: list[int] | None = None,
903+
) -> None:
904+
"""Validate candidates satisfy bounds, discrete, and inequality constraints.
905+
906+
Args:
907+
candidates: A `n x d`-dim Tensor of candidates to validate.
908+
bounds: A `2 x d`-dim Tensor of lower and upper bounds.
909+
discrete_choices: A mapping from parameter indices to allowed discrete values.
910+
inequality_constraints: A list of tuples (indices, coefficients, rhs),
911+
representing inequality constraints of the form
912+
`sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
913+
feature_names: Optional list of feature names for better error messages.
914+
task_features: Optional list of task feature indices to skip discrete value
915+
validation for. Task features can be fixed to new task values via
916+
fixed_features that are not in the search space's discrete_choices.
917+
918+
Raises:
919+
CandidateGenerationError: If any candidate violates constraints.
920+
"""
921+
922+
# 1. Bounds validation
923+
try:
924+
columnwise_clamp(
925+
candidates, lower=bounds[0], upper=bounds[1], raise_on_violation=True
926+
)
927+
except BotorchError as e:
928+
raise CandidateGenerationError(f"Candidate violates bounds: {e}")
929+
930+
# 2. Discrete value validation (sk
931+
task_features_set = set(task_features) if task_features else set()
932+
if discrete_choices:
933+
tol = get_constraint_tolerance(candidates.dtype)
934+
for dim, allowed_values in discrete_choices.items():
935+
# Skip task features as they can be fixed to new task values via
936+
# fixed_features that are not in the search space's discrete_choices
937+
if dim in task_features_set:
938+
continue
939+
allowed = torch.tensor(
940+
allowed_values, device=candidates.device, dtype=candidates.dtype
941+
)
942+
candidate_vals = candidates[..., dim].flatten()
943+
# Vectorized check: (num_candidates, num_allowed) -> any match per candidate
944+
is_valid = torch.isclose(
945+
candidate_vals.unsqueeze(-1), allowed.unsqueeze(0), atol=tol
946+
).any(dim=-1)
947+
if not is_valid.all():
948+
invalid_idx = int(torch.where(~is_valid)[0][0].item())
949+
val_float = candidate_vals[invalid_idx].item()
950+
dim_name = feature_names[dim] if feature_names else f"dim {dim}"
951+
raise CandidateGenerationError(
952+
f"Invalid discrete value "
953+
f"{_format_discrete_value(val_float, allowed_values)} for "
954+
f"{dim_name}. Allowed: {list(allowed_values)}"
955+
)
956+
957+
# 3. Inequality constraint validation
958+
if inequality_constraints:
959+
is_feasible = evaluate_feasibility(
960+
X=candidates.unsqueeze(-2), # Add q dimension
961+
inequality_constraints=inequality_constraints,
962+
)
963+
if not is_feasible.all():
964+
infeasible_indices = torch.where(~is_feasible)[0].tolist()
965+
raise CandidateGenerationError(
966+
f"Candidates violate inequality constraints. "
967+
f"Infeasible candidate indices: {infeasible_indices}. "
968+
f"Number of constraints: {len(inequality_constraints)}."
969+
)

0 commit comments

Comments
 (0)