@@ -596,6 +596,7 @@ def optimize(
596596 n : int ,
597597 search_space_digest : SearchSpaceDigest ,
598598 inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
599+ equality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
599600 fixed_features : dict [int , float ] | None = None ,
600601 rounding_func : Callable [[Tensor ], Tensor ] | None = None ,
601602 optimizer_options : dict [str , Any ] | None = None ,
@@ -612,6 +613,9 @@ def optimize(
612613 inequality_constraints: A list of tuples (indices, coefficients, rhs),
613614 with each tuple encoding an inequality constraint of the form
614615 ``sum_i (X[indices[i]] * coefficients[i]) >= rhs``.
616+ equality_constraints: A list of tuples (indices, coefficients, rhs),
617+ with each tuple encoding an equality constraint of the form
618+ ``sum_i (X[indices[i]] * coefficients[i]) = rhs``.
615619 fixed_features: A map `{feature_index: value}` for features that
616620 should be fixed to a particular value during generation.
617621 rounding_func: A function that post-processes an optimization
@@ -664,8 +668,8 @@ def optimize(
664668 # Ax expects `optimize_acqf` to return tensors of a certain shape.
665669 if optimizer_options is not None :
666670 forbidden_optimizer_options = [
667- "equality_constraints" ,
668- "inequality_constraints" , # These should be constructed by Ax
671+ "equality_constraints" , # Constructed by Ax
672+ "inequality_constraints" , # Constructed by Ax
669673 "batch_initial_conditions" ,
670674 "return_best_only" ,
671675 "return_full_tree" ,
@@ -716,6 +720,7 @@ def optimize(
716720 bounds = bounds ,
717721 q = n ,
718722 inequality_constraints = inequality_constraints ,
723+ equality_constraints = equality_constraints ,
719724 fixed_features = fixed_features ,
720725 post_processing_func = rounding_func ,
721726 acq_function_sequence = self .acq_function_sequence ,
@@ -727,6 +732,11 @@ def optimize(
727732 "optimize_acqf_discrete" ,
728733 "optimize_acqf_discrete_local_search" ,
729734 ):
735+ if equality_constraints :
736+ raise ValueError (
737+ "Equality constraints are not supported with discrete "
738+ f"optimizer '{ optimizer } '."
739+ )
730740 X_observed = self .X_observed
731741 if self .X_pending is not None :
732742 if X_observed is None :
@@ -805,6 +815,7 @@ def optimize(
805815 discrete_choices = discrete_choices
806816 ),
807817 inequality_constraints = inequality_constraints ,
818+ equality_constraints = equality_constraints ,
808819 post_processing_func = rounding_func ,
809820 ** optimizer_options_with_defaults ,
810821 )
@@ -832,9 +843,15 @@ def optimize(
832843 post_processing_func = rounding_func ,
833844 fixed_features = fixed_features ,
834845 inequality_constraints = inequality_constraints ,
846+ equality_constraints = equality_constraints ,
835847 ** optimizer_options_with_defaults ,
836848 )
837849 elif optimizer == "optimize_with_nsgaii" :
850+ if equality_constraints :
851+ raise ValueError (
852+ "Equality constraints are not supported with "
853+ "optimizer 'optimize_with_nsgaii'."
854+ )
838855 if optimize_with_nsgaii is not None :
839856 acqf = assert_is_instance (
840857 self .acqf , MultiOutputAcquisitionFunctionWrapper
@@ -873,6 +890,7 @@ def optimize(
873890 candidates = candidates ,
874891 search_space_digest = search_space_digest ,
875892 inequality_constraints = inequality_constraints ,
893+ equality_constraints = equality_constraints ,
876894 fixed_features = fixed_features ,
877895 )
878896 # Validate candidates before returning
@@ -883,6 +901,7 @@ def optimize(
883901 inequality_constraints = inequality_constraints ,
884902 feature_names = search_space_digest .feature_names ,
885903 task_features = search_space_digest .task_features ,
904+ equality_constraints = equality_constraints ,
886905 )
887906
888907 n_candidates = candidates .shape [0 ]
@@ -986,6 +1005,7 @@ def _prune_irrelevant_parameters(
9861005 candidates : Tensor ,
9871006 search_space_digest : SearchSpaceDigest ,
9881007 inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
1008+ equality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
9891009 fixed_features : dict [int , float ] | None = None ,
9901010 ) -> tuple [Tensor , Tensor ]:
9911011 r"""Prune irrelevant parameters from the candidates using BONSAI.
@@ -1092,6 +1112,7 @@ def _prune_irrelevant_parameters(
10921112 candidates = pruned_candidates ,
10931113 indices = indices ,
10941114 inequality_constraints = inequality_constraints ,
1115+ equality_constraints = equality_constraints ,
10951116 )
10961117 if pruned_candidates .shape [0 ] == 0 :
10971118 # no feasible points, continue to
@@ -1205,6 +1226,7 @@ def _remove_infeasible_candidates(
12051226 candidates : Tensor ,
12061227 indices : Tensor ,
12071228 inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
1229+ equality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
12081230) -> tuple [Tensor , Tensor ]:
12091231 r"""Filter out infeasible candidates, based on the parameter constraints.
12101232
@@ -1214,24 +1236,19 @@ def _remove_infeasible_candidates(
12141236 in [0, d-1).
12151237 inequality_constraints: A list of tuples (indices, coefficients, rhs),
12161238 with each tuple encoding an inequality constraint of the form
1217- `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
1218- `coefficients` should be torch tensors. See the docstring of
1219- `make_scipy_linear_constraints` for an example. When q=1, or when
1220- applying the same constraint to each candidate in the batch
1221- (intra-point constraint), `indices` should be a 1-d tensor.
1222- For inter-point constraints, in which the constraint is applied to the
1223- whole batch of candidates, `indices` must be a 2-d tensor, where
1224- in each row `indices[i] =(k_i, l_i)` the first index `k_i` corresponds
1225- to the `k_i`-th element of the `q`-batch and the second index `l_i`
1226- corresponds to the `l_i`-th feature of that element.
1239+ `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
1240+ equality_constraints: A list of tuples (indices, coefficients, rhs),
1241+ with each tuple encoding an equality constraint of the form
1242+ `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
12271243
12281244 Returns:
1229- A two-element tuple containing the filter candidates and indices.
1245+ A two-element tuple containing the filtered candidates and indices.
12301246 """
1231- if inequality_constraints is not None :
1247+ if inequality_constraints is not None or equality_constraints is not None :
12321248 is_feasible = evaluate_feasibility (
12331249 X = candidates ,
12341250 inequality_constraints = inequality_constraints ,
1251+ equality_constraints = equality_constraints ,
12351252 )
12361253 candidates = candidates [is_feasible ]
12371254 indices = indices [is_feasible ]
0 commit comments