Skip to content

Commit 2685f43

Browse files
David Eriksson and meta-codesync[bot]
authored and committed
Use discrete_values and post_processing_func in optimize_with_nsgaii (#4971)
Summary: Pull Request resolved: #4971. `optimize_with_nsgaii` may produce invalid candidates, as it currently doesn't take into account whether a parameter is float- or integer-valued. This diff addresses this by passing in `rounding_func` by default when `optimize_with_nsgaii` is called through MBM. Reviewed By: sdaulton, saitcakmak. Differential Revision: D95243318
1 parent 1cd0b89 commit 2685f43

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

ax/generators/torch/botorch_modular/acquisition.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@
4747
from botorch.acquisition.input_constructors import get_acqf_input_constructor
4848
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
4949
from botorch.acquisition.logei import qLogProbabilityOfFeasibility
50-
from botorch.acquisition.multioutput_acquisition import MultiOutputAcquisitionFunction
50+
from botorch.acquisition.multioutput_acquisition import (
51+
MultiOutputAcquisitionFunction,
52+
MultiOutputAcquisitionFunctionWrapper,
53+
)
5154
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
5255
from botorch.exceptions.errors import BotorchError, InputDataError
5356
from botorch.generation.sampling import SamplingStrategy
@@ -66,7 +69,7 @@
6669
)
6770
from botorch.optim.parameter_constraints import evaluate_feasibility
6871
from botorch.utils.constraints import get_outcome_constraint_transforms
69-
from pyre_extensions import none_throws
72+
from pyre_extensions import assert_is_instance, none_throws
7073
from torch import Tensor
7174

7275
try:
@@ -816,18 +819,17 @@ def optimize(
816819
)
817820
elif optimizer == "optimize_with_nsgaii":
818821
if optimize_with_nsgaii is not None:
819-
# TODO: support post_processing_func
822+
acqf = assert_is_instance(
823+
self.acqf, MultiOutputAcquisitionFunctionWrapper
824+
)
820825
candidates, acqf_values = optimize_with_nsgaii(
821826
acq_function=self.acqf,
822827
bounds=bounds,
823828
q=n,
824829
fixed_features=fixed_features,
825-
# We use pyre-ignore here to avoid a circular import.
826-
# pyre-ignore [6]: Incompatible parameter type [6]: In call `len`,
827-
# for 1st positional argument, expected
828-
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor,
829-
# Module]`.
830-
num_objectives=len(self.acqf.acqfs),
830+
num_objectives=len(acqf.acqfs),
831+
discrete_choices=discrete_choices if discrete_choices else None,
832+
post_processing_func=rounding_func,
831833
**optimizer_options_with_defaults,
832834
)
833835
else:

ax/generators/torch/tests/test_acquisition.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2228,6 +2228,7 @@ def test_optimize(self) -> None:
22282228
q=n,
22292229
num_objectives=2,
22302230
fixed_features=self.fixed_features,
2231+
post_processing_func=self.rounding_func,
22312232
**optimizer_options,
22322233
)
22332234
# can't use assert_called_with on bounds due to ambiguous bool comparison
@@ -2242,6 +2243,44 @@ def test_optimize(self) -> None:
22422243
)
22432244
)
22442245

2246+
@skip_if_import_error
2247+
def test_optimize_with_nsgaii_applies_rounding(self) -> None:
2248+
"""Test that optimize_with_nsgaii applies the rounding function correctly.
2249+
2250+
This tests a mixed setting where only some parameters are rounded (e.g.,
2251+
integer-valued), while others remain continuous.
2252+
"""
2253+
2254+
# Create a rounding function that only rounds the first parameter (index 0)
2255+
# This simulates a mixed search space with one integer parameter
2256+
def rounding_func(X: Tensor) -> Tensor:
2257+
X_rounded = X.clone()
2258+
X_rounded[..., 0] = X_rounded[..., 0].round()
2259+
return X_rounded
2260+
2261+
acquisition = self.get_acquisition_function(fixed_features=self.fixed_features)
2262+
n = 3
2263+
optimizer_options = {"max_gen": 5, "population_size": 20, "seed": 0}
2264+
2265+
candidates, _, _ = acquisition.optimize(
2266+
n=n,
2267+
search_space_digest=self.search_space_digest,
2268+
inequality_constraints=self.inequality_constraints,
2269+
fixed_features=self.fixed_features,
2270+
rounding_func=rounding_func,
2271+
optimizer_options=optimizer_options,
2272+
)
2273+
2274+
# Assert that the first parameter (index 0) is rounded to integers
2275+
self.assertTrue(
2276+
torch.equal(candidates[:, 0], candidates[:, 0].round()),
2277+
f"First parameter should be rounded but got: {candidates[:, 0]}",
2278+
)
2279+
# Assert that other parameters (indices 1, 2) may have non-integer values
2280+
# (they won't necessarily be non-integers, but they weren't explicitly rounded)
2281+
# The key assertion is that the rounding function was applied to produce
2282+
# integer values for the first parameter
2283+
22452284
def test_evaluate(self) -> None:
22462285
acquisition = self.get_acquisition_function()
22472286
with mock.patch.object(acquisition.acqf, "forward") as mock_forward:

0 commit comments

Comments
 (0)