From 0b5214b8b74e73665a0babde46f9e1ae24f72a06 Mon Sep 17 00:00:00 2001 From: David Eriksson Date: Wed, 4 Mar 2026 18:37:05 -0800 Subject: [PATCH] Use discrete_values and post_processing_func in optimize_with_nsgaii (#4971) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/4971 `optimize_with_nsgaii` may produce invalid candidates as it currently doesn't support discrete parameters (treated as floats), inequality constraints (completely ignored), and post processing functions (also ignored). These features were added to BoTorch as part of this diff stack. Here we update Ax to pass down the relevant parameters to `optimize_with_nsgaii` which ensures the generated candidates are within the search space. Reviewed By: sdaulton, saitcakmak Differential Revision: D95243318 --- .../torch/botorch_modular/acquisition.py | 21 ++--- ax/generators/torch/tests/test_acquisition.py | 76 ++++++++++++++++++- 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/ax/generators/torch/botorch_modular/acquisition.py b/ax/generators/torch/botorch_modular/acquisition.py index 31dc47b30ee..2175138cb54 100644 --- a/ax/generators/torch/botorch_modular/acquisition.py +++ b/ax/generators/torch/botorch_modular/acquisition.py @@ -47,7 +47,10 @@ from botorch.acquisition.input_constructors import get_acqf_input_constructor from botorch.acquisition.knowledge_gradient import qKnowledgeGradient from botorch.acquisition.logei import qLogProbabilityOfFeasibility -from botorch.acquisition.multioutput_acquisition import MultiOutputAcquisitionFunction +from botorch.acquisition.multioutput_acquisition import ( + MultiOutputAcquisitionFunction, + MultiOutputAcquisitionFunctionWrapper, +) from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform from botorch.exceptions.errors import BotorchError, InputDataError from botorch.generation.sampling import SamplingStrategy @@ -66,7 +69,7 @@ ) from botorch.optim.parameter_constraints import evaluate_feasibility from botorch.utils.constraints import get_outcome_constraint_transforms -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws from torch import Tensor try: @@ -816,18 +819,18 @@ def optimize( ) elif optimizer == "optimize_with_nsgaii": if optimize_with_nsgaii is not None: - # TODO: support post_processing_func + acqf = assert_is_instance( + self.acqf, MultiOutputAcquisitionFunctionWrapper + ) candidates, acqf_values = optimize_with_nsgaii( acq_function=self.acqf, bounds=bounds, q=n, fixed_features=fixed_features, - # We use pyre-ignore here to avoid a circular import. - # pyre-ignore [6]: Incompatible parameter type [6]: In call `len`, - # for 1st positional argument, expected - # `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor, - # Module]`. - num_objectives=len(self.acqf.acqfs), + inequality_constraints=inequality_constraints, + num_objectives=len(acqf.acqfs), + discrete_choices=discrete_choices if discrete_choices else None, + post_processing_func=rounding_func, **optimizer_options_with_defaults, ) else: diff --git a/ax/generators/torch/tests/test_acquisition.py b/ax/generators/torch/tests/test_acquisition.py index 81ee466cf0d..61bfda3acc4 100644 --- a/ax/generators/torch/tests/test_acquisition.py +++ b/ax/generators/torch/tests/test_acquisition.py @@ -2226,8 +2226,11 @@ def test_optimize(self) -> None: acq_function=acquisition.acqf, bounds=mock.ANY, q=n, - num_objectives=2, fixed_features=self.fixed_features, + inequality_constraints=self.inequality_constraints, + num_objectives=2, + discrete_choices=mock.ANY, + post_processing_func=self.rounding_func, **optimizer_options, ) # can't use assert_called_with on bounds due to ambiguous bool comparison @@ -2242,6 +2245,77 @@ def test_optimize(self) -> None: ) ) + @skip_if_import_error + def test_optimize_with_nsgaii_features(self) -> None: + """Test that optimize_with_nsgaii correctly handles all features. + + This tests that candidates generated by optimize_with_nsgaii: + 1. Apply the post_processing_func (rounding) correctly + 2. Respect parameter-space inequality constraints + 3. Respect discrete parameter choices + """ + # Create a search space digest with irregularly-spaced discrete choices + # for dimension 0 (irregular spacing ensures simple rounding won't work) + discrete_search_space_digest = SearchSpaceDigest( + feature_names=self.feature_names, + bounds=[(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)], + target_values={2: 1.0}, + ordinal_features=[0], + discrete_choices={0: [0.0, 2.0, 5.0, 10.0]}, + ) + + # Rounding function that rounds the third parameter (index 2) + def rounding_func(X: Tensor) -> Tensor: + X_rounded = X.clone() + X_rounded[..., 2] = X_rounded[..., 2].round() + return X_rounded + + acquisition = self.get_acquisition_function(fixed_features=self.fixed_features) + n = 5 + optimizer_options = {"max_gen": 5, "population_size": 20, "seed": 0} + + candidates, _, _ = acquisition.optimize( + n=n, + search_space_digest=discrete_search_space_digest, + inequality_constraints=self.inequality_constraints, + fixed_features=self.fixed_features, + rounding_func=rounding_func, + optimizer_options=optimizer_options, + ) + + # 1. Verify post_processing_func: dimension 2 should be rounded + self.assertTrue( + torch.equal(candidates[:, 2], candidates[:, 2].round()), + f"Third parameter should be rounded but got: {candidates[:, 2]}", + ) + + # 2. Verify inequality constraints: -x0 + x1 >= 1 + indices, coefficients, rhs = self.inequality_constraints[0] + for i in range(candidates.shape[0]): + constraint_value = ( + coefficients[0] * candidates[i, indices[0]] + + coefficients[1] * candidates[i, indices[1]] + ) + self.assertGreaterEqual( + constraint_value.item(), + rhs, + f"Candidate {i} violates inequality constraint: " + f"{constraint_value.item()} < {rhs}", + ) + + # 3. Verify discrete choices: dimension 0 should only have allowed values + allowed_values = torch.tensor( + discrete_search_space_digest.discrete_choices[0], **self.tkwargs + ) + for i in range(candidates.shape[0]): + val = candidates[i, 0] + is_valid = torch.any(torch.isclose(val, allowed_values)) + self.assertTrue( + is_valid, + f"Candidate {i} has invalid discrete value {val.item()} " + f"for dimension 0. Allowed: {allowed_values.tolist()}", + ) + def test_evaluate(self) -> None: acquisition = self.get_acquisition_function() with mock.patch.object(acquisition.acqf, "forward") as mock_forward: