Skip to content

Commit 6f07eff

Browse files
David Erikssonfacebook-github-bot
authored andcommitted
Use discrete_values and post_processing_func in optimize_with_nsgaii (facebook#4971)
Summary: `optimize_with_nsgaii` may produce invalid candidates as it currently doesn't support discrete parameters (treated as floats), inequality constraints (completely ignored), and post processing functions (also ignored). These features were added to BoTorch as part of this diff stack. Here we update Ax to pass down the relevant parameters to `optimize_with_nsgaii` which ensures the generated candidates are within the search space. Reviewed By: sdaulton, saitcakmak Differential Revision: D95243318
1 parent 452ddfb commit 6f07eff

File tree

2 files changed

+87
-10
lines changed

2 files changed

+87
-10
lines changed

ax/generators/torch/botorch_modular/acquisition.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@
4747
from botorch.acquisition.input_constructors import get_acqf_input_constructor
4848
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
4949
from botorch.acquisition.logei import qLogProbabilityOfFeasibility
50-
from botorch.acquisition.multioutput_acquisition import MultiOutputAcquisitionFunction
50+
from botorch.acquisition.multioutput_acquisition import (
51+
MultiOutputAcquisitionFunction,
52+
MultiOutputAcquisitionFunctionWrapper,
53+
)
5154
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
5255
from botorch.exceptions.errors import BotorchError, InputDataError
5356
from botorch.generation.sampling import SamplingStrategy
@@ -66,7 +69,7 @@
6669
)
6770
from botorch.optim.parameter_constraints import evaluate_feasibility
6871
from botorch.utils.constraints import get_outcome_constraint_transforms
69-
from pyre_extensions import none_throws
72+
from pyre_extensions import assert_is_instance, none_throws
7073
from torch import Tensor
7174

7275
try:
@@ -816,18 +819,18 @@ def optimize(
816819
)
817820
elif optimizer == "optimize_with_nsgaii":
818821
if optimize_with_nsgaii is not None:
819-
# TODO: support post_processing_func
822+
acqf = assert_is_instance(
823+
self.acqf, MultiOutputAcquisitionFunctionWrapper
824+
)
820825
candidates, acqf_values = optimize_with_nsgaii(
821826
acq_function=self.acqf,
822827
bounds=bounds,
823828
q=n,
824829
fixed_features=fixed_features,
825-
# We use pyre-ignore here to avoid a circular import.
826-
# pyre-ignore [6]: Incompatible parameter type [6]: In call `len`,
827-
# for 1st positional argument, expected
828-
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor,
829-
# Module]`.
830-
num_objectives=len(self.acqf.acqfs),
830+
inequality_constraints=inequality_constraints,
831+
num_objectives=len(acqf.acqfs),
832+
discrete_choices=discrete_choices if discrete_choices else None,
833+
post_processing_func=rounding_func,
831834
**optimizer_options_with_defaults,
832835
)
833836
else:

ax/generators/torch/tests/test_acquisition.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2226,8 +2226,11 @@ def test_optimize(self) -> None:
22262226
acq_function=acquisition.acqf,
22272227
bounds=mock.ANY,
22282228
q=n,
2229-
num_objectives=2,
22302229
fixed_features=self.fixed_features,
2230+
inequality_constraints=self.inequality_constraints,
2231+
num_objectives=2,
2232+
discrete_choices=mock.ANY,
2233+
post_processing_func=self.rounding_func,
22312234
**optimizer_options,
22322235
)
22332236
# can't use assert_called_with on bounds due to ambiguous bool comparison
@@ -2242,6 +2245,77 @@ def test_optimize(self) -> None:
22422245
)
22432246
)
22442247

2248+
@skip_if_import_error
2249+
def test_optimize_with_nsgaii_features(self) -> None:
2250+
"""Test that optimize_with_nsgaii correctly handles all features.
2251+
2252+
This tests that candidates generated by optimize_with_nsgaii:
2253+
1. Apply the post_processing_func (rounding) correctly
2254+
2. Respect parameter-space inequality constraints
2255+
3. Respect discrete parameter choices
2256+
"""
2257+
# Create a search space digest with irregularly-spaced discrete choices
2258+
# for dimension 0 (irregular spacing ensures simple rounding won't work)
2259+
discrete_search_space_digest = SearchSpaceDigest(
2260+
feature_names=self.feature_names,
2261+
bounds=[(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)],
2262+
target_values={2: 1.0},
2263+
ordinal_features=[0],
2264+
discrete_choices={0: [0.0, 2.0, 5.0, 10.0]},
2265+
)
2266+
2267+
# Rounding function that rounds the third parameter (index 2)
2268+
def rounding_func(X: Tensor) -> Tensor:
2269+
X_rounded = X.clone()
2270+
X_rounded[..., 2] = X_rounded[..., 2].round()
2271+
return X_rounded
2272+
2273+
acquisition = self.get_acquisition_function(fixed_features=self.fixed_features)
2274+
n = 5
2275+
optimizer_options = {"max_gen": 5, "population_size": 20, "seed": 0}
2276+
2277+
candidates, _, _ = acquisition.optimize(
2278+
n=n,
2279+
search_space_digest=discrete_search_space_digest,
2280+
inequality_constraints=self.inequality_constraints,
2281+
fixed_features=self.fixed_features,
2282+
rounding_func=rounding_func,
2283+
optimizer_options=optimizer_options,
2284+
)
2285+
2286+
# 1. Verify post_processing_func: dimension 2 should be rounded
2287+
self.assertTrue(
2288+
torch.equal(candidates[:, 2], candidates[:, 2].round()),
2289+
f"Third parameter should be rounded but got: {candidates[:, 2]}",
2290+
)
2291+
2292+
# 2. Verify inequality constraints: -x0 + x1 >= 1
2293+
indices, coefficients, rhs = self.inequality_constraints[0]
2294+
for i in range(candidates.shape[0]):
2295+
constraint_value = (
2296+
coefficients[0] * candidates[i, indices[0]]
2297+
+ coefficients[1] * candidates[i, indices[1]]
2298+
)
2299+
self.assertGreaterEqual(
2300+
constraint_value.item(),
2301+
rhs,
2302+
f"Candidate {i} violates inequality constraint: "
2303+
f"{constraint_value.item()} < {rhs}",
2304+
)
2305+
2306+
# 3. Verify discrete choices: dimension 0 should only have allowed values
2307+
allowed_values = torch.tensor(
2308+
discrete_search_space_digest.discrete_choices[0], **self.tkwargs
2309+
)
2310+
for i in range(candidates.shape[0]):
2311+
val = candidates[i, 0]
2312+
is_valid = torch.any(torch.isclose(val, allowed_values))
2313+
self.assertTrue(
2314+
is_valid,
2315+
f"Candidate {i} has invalid discrete value {val.item()} "
2316+
f"for dimension 0. Allowed: {allowed_values.tolist()}",
2317+
)
2318+
22452319
def test_evaluate(self) -> None:
22462320
acquisition = self.get_acquisition_function()
22472321
with mock.patch.object(acquisition.acqf, "forward") as mock_forward:

0 commit comments

Comments
 (0)