Skip to content

Commit 5b9bf62

Browse files
David Erikssonfacebook-github-bot
authored andcommitted
Use discrete_values and post_processing_func in optimize_with_nsgaii (#4971)
Summary: `optimize_with_nsgaii` may produce invalid candidates as it currently doesn't support discrete parameters (treated as floats), inequality constraints (completely ignored), and post processing functions (also ignored). These features were added to BoTorch as part of this diff stack. Here we update Ax to pass down the relevant parameters to `optimize_with_nsgaii` which ensures the generated candidates are within the search space. Reviewed By: sdaulton, saitcakmak Differential Revision: D95243318
1 parent 05f34d5 commit 5b9bf62

2 files changed

Lines changed: 93 additions & 11 deletions

File tree

ax/generators/torch/botorch_modular/acquisition.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@
4747
from botorch.acquisition.input_constructors import get_acqf_input_constructor
4848
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
4949
from botorch.acquisition.logei import qLogProbabilityOfFeasibility
50-
from botorch.acquisition.multioutput_acquisition import MultiOutputAcquisitionFunction
50+
from botorch.acquisition.multioutput_acquisition import (
51+
MultiOutputAcquisitionFunction,
52+
MultiOutputAcquisitionFunctionWrapper,
53+
)
5154
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
5255
from botorch.exceptions.errors import BotorchError, InputDataError
5356
from botorch.generation.sampling import SamplingStrategy
@@ -66,7 +69,7 @@
6669
)
6770
from botorch.optim.parameter_constraints import evaluate_feasibility
6871
from botorch.utils.constraints import get_outcome_constraint_transforms
69-
from pyre_extensions import none_throws
72+
from pyre_extensions import assert_is_instance, none_throws
7073
from torch import Tensor
7174

7275
try:
@@ -816,18 +819,18 @@ def optimize(
816819
)
817820
elif optimizer == "optimize_with_nsgaii":
818821
if optimize_with_nsgaii is not None:
819-
# TODO: support post_processing_func
822+
acqf = assert_is_instance(
823+
self.acqf, MultiOutputAcquisitionFunctionWrapper
824+
)
820825
candidates, acqf_values = optimize_with_nsgaii(
821826
acq_function=self.acqf,
822827
bounds=bounds,
823828
q=n,
824829
fixed_features=fixed_features,
825-
# We use pyre-ignore here to avoid a circular import.
826-
# pyre-ignore [6]: Incompatible parameter type [6]: In call `len`,
827-
# for 1st positional argument, expected
828-
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor,
829-
# Module]`.
830-
num_objectives=len(self.acqf.acqfs),
830+
inequality_constraints=inequality_constraints,
831+
num_objectives=len(acqf.acqfs),
832+
discrete_choices=discrete_choices if discrete_choices else None,
833+
post_processing_func=rounding_func,
831834
**optimizer_options_with_defaults,
832835
)
833836
else:

ax/generators/torch/tests/test_acquisition.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2122,6 +2122,9 @@ def test_optimize_mixed(self) -> None:
21222122
def test_optimize_acqf_mixed_alternating(self) -> None:
21232123
pass
21242124

2125+
def test_no_pruning_with_qLogProbabilityOfFeasibility(self) -> None:
2126+
pass
2127+
21252128
def test_select_from_candidate_set(self) -> None:
21262129
pass
21272130

@@ -2200,7 +2203,9 @@ def test_optimize(self) -> None:
22002203

22012204
acquisition = self.get_acquisition_function(fixed_features=self.fixed_features)
22022205
n = 5
2203-
optimizer_options = {"max_gen": 3, "population_size": 10}
2206+
# Use more generations and larger population to reliably find feasible
2207+
# candidates that satisfy the inequality constraint
2208+
optimizer_options = {"max_gen": 10, "population_size": 50}
22042209
with (
22052210
mock.patch(
22062211
f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse
@@ -2226,8 +2231,11 @@ def test_optimize(self) -> None:
22262231
acq_function=acquisition.acqf,
22272232
bounds=mock.ANY,
22282233
q=n,
2229-
num_objectives=2,
22302234
fixed_features=self.fixed_features,
2235+
inequality_constraints=self.inequality_constraints,
2236+
num_objectives=2,
2237+
discrete_choices=mock.ANY,
2238+
post_processing_func=self.rounding_func,
22312239
**optimizer_options,
22322240
)
22332241
# can't use assert_called_with on bounds due to ambiguous bool comparison
@@ -2242,6 +2250,77 @@ def test_optimize(self) -> None:
22422250
)
22432251
)
22442252

2253+
@skip_if_import_error
2254+
def test_optimize_with_nsgaii_features(self) -> None:
2255+
"""Test that optimize_with_nsgaii correctly handles all features.
2256+
2257+
This tests that candidates generated by optimize_with_nsgaii:
2258+
1. Apply the post_processing_func (rounding) correctly
2259+
2. Respect parameter-space inequality constraints
2260+
3. Respect discrete parameter choices
2261+
"""
2262+
# Create a search space digest with irregularly-spaced discrete choices
2263+
# for dimension 0 (irregular spacing ensures simple rounding won't work)
2264+
discrete_search_space_digest = SearchSpaceDigest(
2265+
feature_names=self.feature_names,
2266+
bounds=[(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)],
2267+
target_values={2: 1.0},
2268+
ordinal_features=[0],
2269+
discrete_choices={0: [0.0, 2.0, 5.0, 10.0]},
2270+
)
2271+
2272+
# Rounding function that rounds the third parameter (index 2)
2273+
def rounding_func(X: Tensor) -> Tensor:
2274+
X_rounded = X.clone()
2275+
X_rounded[..., 2] = X_rounded[..., 2].round()
2276+
return X_rounded
2277+
2278+
acquisition = self.get_acquisition_function(fixed_features=self.fixed_features)
2279+
n = 5
2280+
optimizer_options = {"max_gen": 5, "population_size": 20, "seed": 0}
2281+
2282+
candidates, _, _ = acquisition.optimize(
2283+
n=n,
2284+
search_space_digest=discrete_search_space_digest,
2285+
inequality_constraints=self.inequality_constraints,
2286+
fixed_features=self.fixed_features,
2287+
rounding_func=rounding_func,
2288+
optimizer_options=optimizer_options,
2289+
)
2290+
2291+
# 1. Verify post_processing_func: dimension 2 should be rounded
2292+
self.assertTrue(
2293+
torch.equal(candidates[:, 2], candidates[:, 2].round()),
2294+
f"Third parameter should be rounded but got: {candidates[:, 2]}",
2295+
)
2296+
2297+
# 2. Verify inequality constraints: -x0 + x1 >= 1
2298+
indices, coefficients, rhs = self.inequality_constraints[0]
2299+
for i in range(candidates.shape[0]):
2300+
constraint_value = (
2301+
coefficients[0] * candidates[i, indices[0]]
2302+
+ coefficients[1] * candidates[i, indices[1]]
2303+
)
2304+
self.assertGreaterEqual(
2305+
constraint_value.item(),
2306+
rhs,
2307+
f"Candidate {i} violates inequality constraint: "
2308+
f"{constraint_value.item()} < {rhs}",
2309+
)
2310+
2311+
# 3. Verify discrete choices: dimension 0 should only have allowed values
2312+
allowed_values = torch.tensor(
2313+
discrete_search_space_digest.discrete_choices[0], **self.tkwargs
2314+
)
2315+
for i in range(candidates.shape[0]):
2316+
val = candidates[i, 0]
2317+
is_valid = torch.any(torch.isclose(val, allowed_values))
2318+
self.assertTrue(
2319+
is_valid,
2320+
f"Candidate {i} has invalid discrete value {val.item()} "
2321+
f"for dimension 0. Allowed: {allowed_values.tolist()}",
2322+
)
2323+
22452324
def test_evaluate(self) -> None:
22462325
acquisition = self.get_acquisition_function()
22472326
with mock.patch.object(acquisition.acqf, "forward") as mock_forward:

0 commit comments

Comments
 (0)