|
34 | 34 | qLogNoisyExpectedHypervolumeImprovement, |
35 | 35 | ) |
36 | 36 | from botorch.acquisition.multi_objective.parego import qLogNParEGO |
| 37 | +from botorch.exceptions.errors import BotorchError, CandidateGenerationError |
37 | 38 | from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll |
38 | 39 | from botorch.models import PairwiseLaplaceMarginalLogLikelihood |
39 | 40 | from botorch.models.fully_bayesian import ( |
|
52 | 53 | from botorch.models.pairwise_gp import PairwiseGP |
53 | 54 | from botorch.models.transforms.input import InputTransform, Normalize |
54 | 55 | from botorch.models.transforms.outcome import OutcomeTransform |
| 56 | +from botorch.optim.parameter_constraints import ( |
| 57 | + evaluate_feasibility, |
| 58 | + get_constraint_tolerance, |
| 59 | +) |
| 60 | +from botorch.optim.utils import columnwise_clamp |
55 | 61 | from botorch.utils.constraints import get_outcome_constraint_transforms |
56 | 62 | from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset |
57 | 63 | from botorch.utils.dispatcher import Dispatcher |
@@ -874,3 +880,93 @@ def get_all_task_values_from_ssd(search_space_digest: SearchSpaceDigest) -> list |
874 | 880 | task_feature = search_space_digest.task_features[0] |
875 | 881 | task_bounds = search_space_digest.bounds[task_feature] |
876 | 882 | return list(range(int(task_bounds[0]), int(task_bounds[1] + 1))) |
| 883 | + |
| 884 | + |
| 885 | +def _format_discrete_value(val: float, allowed_values: Sequence[float]) -> str: |
| 886 | + """Format a discrete value for display alongside allowed values. |
| 887 | +
|
| 888 | + If all allowed values are integers, formats val as int (via rounding). |
| 889 | + Otherwise formats as float with 4 decimal places. |
| 890 | + """ |
| 891 | + if all(float(v).is_integer() for v in allowed_values): |
| 892 | + return str(int(round(val))) |
| 893 | + return f"{val:.4f}" |
| 894 | + |
| 895 | + |
def validate_candidates(
    candidates: Tensor,
    bounds: Tensor,
    discrete_choices: Mapping[int, Sequence[float]] | None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None,
    feature_names: list[str] | None = None,
    task_features: list[int] | None = None,
) -> None:
    """Validate candidates satisfy bounds, discrete, and inequality constraints.

    Args:
        candidates: A `n x d`-dim Tensor of candidates to validate.
        bounds: A `2 x d`-dim Tensor of lower and upper bounds.
        discrete_choices: A mapping from parameter indices to allowed discrete values.
        inequality_constraints: A list of tuples (indices, coefficients, rhs),
            representing inequality constraints of the form
            `sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
        feature_names: Optional list of feature names for better error messages.
        task_features: Optional list of task feature indices to skip discrete value
            validation for. Task features can have values for new tasks that are
            not in the training data and are handled separately via fixed_features.

    Raises:
        CandidateGenerationError: If any candidate violates constraints.
    """

    # 1. Bounds validation
    try:
        columnwise_clamp(
            candidates, lower=bounds[0], upper=bounds[1], raise_on_violation=True
        )
    except BotorchError as e:
        # Chain the cause explicitly (PEP 3134) so the original bounds error
        # and its traceback are preserved for debugging, rather than being
        # reported as "another exception occurred during handling".
        raise CandidateGenerationError(f"Candidate violates bounds: {e}") from e

    # 2. Discrete value validation (skip task features)
    # Use rounding to match Ax's casting behavior: IntToFloat.untransform uses
    # int(round(value)), so we round candidates before checking against allowed
    # values. This ensures validation matches actual post-transform values.
    task_features_set = set(task_features) if task_features else set()
    if discrete_choices:
        # Tolerance is dtype-dependent so float32 candidates are not rejected
        # for representation error alone.
        tol = get_constraint_tolerance(candidates.dtype)
        for dim, allowed_values in discrete_choices.items():
            # Skip task features as they can have values for new tasks not in
            # training data and are handled separately via fixed_features
            if dim in task_features_set:
                continue
            allowed = torch.tensor(
                allowed_values, device=candidates.device, dtype=candidates.dtype
            )
            candidate_vals = candidates[..., dim].flatten()
            # Vectorized check: (num_candidates, num_allowed) -> any match per candidate
            is_valid = torch.isclose(
                candidate_vals.unsqueeze(-1), allowed.unsqueeze(0), atol=tol
            ).any(dim=-1)
            if not is_valid.all():
                # Report only the first offending value to keep the message short.
                invalid_idx = int(torch.where(~is_valid)[0][0].item())
                val_float = candidate_vals[invalid_idx].item()
                dim_name = feature_names[dim] if feature_names else f"dim {dim}"
                raise CandidateGenerationError(
                    f"Invalid discrete value "
                    f"{_format_discrete_value(val_float, allowed_values)} for "
                    f"{dim_name}. Allowed: {list(allowed_values)}"
                )

    # 3. Inequality constraint validation
    if inequality_constraints:
        is_feasible = evaluate_feasibility(
            X=candidates.unsqueeze(-2),  # Add q dimension
            inequality_constraints=inequality_constraints,
        )
        if not is_feasible.all():
            infeasible_indices = torch.where(~is_feasible)[0].tolist()
            raise CandidateGenerationError(
                f"Candidates violate inequality constraints. "
                f"Infeasible candidate indices: {infeasible_indices}. "
                f"Number of constraints: {len(inequality_constraints)}."
            )
0 commit comments