Skip to content

Commit 26a1941

Browse files
David Eriksson authored and facebook-github-bot committed
Validate candidates generated by Acquisition.optimize
Summary: This adds validation to `Acquisition.optimize` to make sure the generated candidates satisfy the bounds, values specified by choice parameters, and potential parameter constraints. This will allow us to catch potential downstream errors in the acquisition function optimization in BoTorch. Differential Revision: D95078354
1 parent eb34abe commit 26a1941

File tree

3 files changed

+357
-74
lines changed

3 files changed

+357
-74
lines changed

ax/generators/torch/botorch_modular/acquisition.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ax.generators.torch.botorch_modular.utils import (
3232
_fix_map_key_to_target,
3333
_objective_threshold_to_outcome_constraints,
34+
validate_candidates,
3435
)
3536
from ax.generators.torch.botorch_moo_utils import infer_objective_thresholds
3637
from ax.generators.torch.utils import (
@@ -742,6 +743,17 @@ def optimize(
742743
X_avoid=X_observed,
743744
**optimizer_options_with_defaults,
744745
)
746+
# Validate candidates before returning
747+
validate_candidates(
748+
candidates=candidates,
749+
bounds=bounds,
750+
discrete_choices=ssd.discrete_choices
751+
if ssd.discrete_choices
752+
else None,
753+
inequality_constraints=inequality_constraints,
754+
feature_names=ssd.feature_names,
755+
task_features=ssd.task_features,
756+
)
745757
n_candidates = candidates.shape[0]
746758
return (
747759
candidates,
@@ -856,6 +868,16 @@ def optimize(
856868
inequality_constraints=inequality_constraints,
857869
fixed_features=fixed_features,
858870
)
871+
# Validate candidates before returning
872+
validate_candidates(
873+
candidates=candidates,
874+
bounds=bounds,
875+
discrete_choices=discrete_choices if discrete_choices else None,
876+
inequality_constraints=inequality_constraints,
877+
feature_names=search_space_digest.feature_names,
878+
task_features=search_space_digest.task_features,
879+
)
880+
859881
n_candidates = candidates.shape[0]
860882
return candidates, acqf_values, arm_weights[:n_candidates] * n_candidates / n
861883

ax/generators/torch/botorch_modular/utils.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
qLogNoisyExpectedHypervolumeImprovement,
3535
)
3636
from botorch.acquisition.multi_objective.parego import qLogNParEGO
37+
from botorch.exceptions.errors import BotorchError, CandidateGenerationError
3738
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
3839
from botorch.models import PairwiseLaplaceMarginalLogLikelihood
3940
from botorch.models.fully_bayesian import (
@@ -52,6 +53,11 @@
5253
from botorch.models.pairwise_gp import PairwiseGP
5354
from botorch.models.transforms.input import InputTransform, Normalize
5455
from botorch.models.transforms.outcome import OutcomeTransform
56+
from botorch.optim.parameter_constraints import (
57+
evaluate_feasibility,
58+
get_constraint_tolerance,
59+
)
60+
from botorch.optim.utils import columnwise_clamp
5561
from botorch.utils.constraints import get_outcome_constraint_transforms
5662
from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset
5763
from botorch.utils.dispatcher import Dispatcher
@@ -874,3 +880,93 @@ def get_all_task_values_from_ssd(search_space_digest: SearchSpaceDigest) -> list
874880
task_feature = search_space_digest.task_features[0]
875881
task_bounds = search_space_digest.bounds[task_feature]
876882
return list(range(int(task_bounds[0]), int(task_bounds[1] + 1)))
883+
884+
885+
def _format_discrete_value(val: float, allowed_values: Sequence[float]) -> str:
886+
"""Format a discrete value for display alongside allowed values.
887+
888+
If all allowed values are integers, formats val as int (via rounding).
889+
Otherwise formats as float with 4 decimal places.
890+
"""
891+
if all(float(v).is_integer() for v in allowed_values):
892+
return str(int(round(val)))
893+
return f"{val:.4f}"
894+
895+
896+
def validate_candidates(
    candidates: Tensor,
    bounds: Tensor,
    discrete_choices: Mapping[int, Sequence[float]] | None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None,
    feature_names: list[str] | None = None,
    task_features: list[int] | None = None,
) -> None:
    """Validate candidates satisfy bounds, discrete, and inequality constraints.

    Args:
        candidates: A `n x d`-dim Tensor of candidates to validate.
        bounds: A `2 x d`-dim Tensor of lower and upper bounds.
        discrete_choices: A mapping from parameter indices to allowed discrete values.
        inequality_constraints: A list of tuples (indices, coefficients, rhs),
            representing inequality constraints of the form
            `sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
        feature_names: Optional list of feature names for better error messages.
        task_features: Optional list of task feature indices to skip discrete value
            validation for. Task features can have values for new tasks that are
            not in the training data and are handled separately via fixed_features.

    Raises:
        CandidateGenerationError: If any candidate violates constraints.
    """
    # 1. Bounds validation
    try:
        columnwise_clamp(
            candidates, lower=bounds[0], upper=bounds[1], raise_on_violation=True
        )
    except BotorchError as e:
        # Chain explicitly so the underlying BoTorch violation is preserved
        # as the direct cause of the error surfaced to the caller.
        raise CandidateGenerationError(f"Candidate violates bounds: {e}") from e

    # 2. Discrete value validation (skip task features)
    # Values are compared against the allowed choices with `torch.isclose`
    # under the optimizer's constraint tolerance, so candidates that would
    # round/cast to an allowed value (e.g. via IntToFloat.untransform's
    # int(round(value))) are accepted.
    task_features_set = set(task_features) if task_features else set()
    if discrete_choices:
        tol = get_constraint_tolerance(candidates.dtype)
        for dim, allowed_values in discrete_choices.items():
            # Skip task features as they can have values for new tasks not in
            # training data and are handled separately via fixed_features
            if dim in task_features_set:
                continue
            allowed = torch.tensor(
                allowed_values, device=candidates.device, dtype=candidates.dtype
            )
            candidate_vals = candidates[..., dim].flatten()
            # Vectorized check: (num_candidates, num_allowed) -> any match per candidate
            is_valid = torch.isclose(
                candidate_vals.unsqueeze(-1), allowed.unsqueeze(0), atol=tol
            ).any(dim=-1)
            if not is_valid.all():
                # Report only the first offending candidate for a concise message.
                invalid_idx = int(torch.where(~is_valid)[0][0].item())
                val_float = candidate_vals[invalid_idx].item()
                dim_name = feature_names[dim] if feature_names else f"dim {dim}"
                raise CandidateGenerationError(
                    f"Invalid discrete value "
                    f"{_format_discrete_value(val_float, allowed_values)} for "
                    f"{dim_name}. Allowed: {list(allowed_values)}"
                )

    # 3. Inequality constraint validation
    if inequality_constraints:
        is_feasible = evaluate_feasibility(
            X=candidates.unsqueeze(-2),  # Add q dimension
            inequality_constraints=inequality_constraints,
        )
        if not is_feasible.all():
            infeasible_indices = torch.where(~is_feasible)[0].tolist()
            raise CandidateGenerationError(
                f"Candidates violate inequality constraints. "
                f"Infeasible candidate indices: {infeasible_indices}. "
                f"Number of constraints: {len(inequality_constraints)}."
            )

0 commit comments

Comments
 (0)