Skip to content

Commit

Permalink
fix constraint handling in single objective MBM (#1973)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#1771

Pull Request resolved: #1973

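Currently, constraints are not used in single-objective acquisition functions (AFs) in the Modular BoTorch Model (MBM) due to a name mismatch between `outcome_constraints` and `constraints`. This change makes the acquisition function input constructors consistently accept `constraints` (a list of callables) instead of reading `outcome_constraints` from `kwargs`.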

Reviewed By: SebastianAment

Differential Revision: D48176978

fbshipit-source-id: 9495708002c11a874bb6b8c06327f0f4643039df
  • Loading branch information
sdaulton authored and facebook-github-bot committed Aug 10, 2023
1 parent 3506538 commit 52529e1
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 97 deletions.
27 changes: 10 additions & 17 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@
from botorch.optim.optimize import optimize_acqf
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.containers import BotorchContainer
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
Expand Down Expand Up @@ -718,7 +717,7 @@ def construct_inputs_qLogNEI(
X_baseline=X_baseline,
prune_baseline=prune_baseline,
cache_root=cache_root,
constraint=constraints,
constraints=constraints,
eta=eta,
),
"fat": fat,
Expand Down Expand Up @@ -853,11 +852,12 @@ def construct_inputs_EHVI(
training_data: MaybeDict[SupervisedDataset],
objective_thresholds: Tensor,
objective: Optional[AnalyticMultiOutputObjective] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for `ExpectedHypervolumeImprovement` constructor."""
num_objectives = objective_thresholds.shape[0]
if kwargs.get("outcome_constraints") is not None:
if constraints is not None:
raise NotImplementedError("EHVI does not yet support outcome constraints.")

X = _get_dataset_field(
Expand Down Expand Up @@ -914,6 +914,7 @@ def construct_inputs_qEHVI(
training_data: MaybeDict[SupervisedDataset],
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for `qExpectedHypervolumeImprovement` constructor."""
Expand All @@ -928,15 +929,10 @@ def construct_inputs_qEHVI(
# compute posterior mean (for ref point computation ref pareto frontier)
with torch.no_grad():
Y_pmean = model.posterior(X).mean

outcome_constraints = kwargs.pop("outcome_constraints", None)
# For HV-based acquisition functions we pass the constraint transform directly
if outcome_constraints is None:
cons_tfs = None
else:
cons_tfs = get_outcome_constraint_transforms(outcome_constraints)
if constraints is not None:
# Adjust `Y_pmean` to contain feasible points only.
feas = torch.stack([c(Y_pmean) <= 0 for c in cons_tfs], dim=-1).all(dim=-1)
feas = torch.stack([c(Y_pmean) <= 0 for c in constraints], dim=-1).all(dim=-1)
Y_pmean = Y_pmean[feas]

if objective is None:
Expand All @@ -962,7 +958,7 @@ def construct_inputs_qEHVI(
add_qehvi_kwargs = {
"sampler": sampler,
"X_pending": kwargs.get("X_pending"),
"constraints": cons_tfs,
"constraints": constraints,
"eta": kwargs.get("eta", 1e-3),
}
return {**ehvi_kwargs, **add_qehvi_kwargs}
Expand All @@ -975,6 +971,7 @@ def construct_inputs_qNEHVI(
objective_thresholds: Tensor,
objective: Optional[MCMultiOutputObjective] = None,
X_baseline: Optional[Tensor] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for `qNoisyExpectedHypervolumeImprovement` constructor."""
Expand All @@ -991,16 +988,12 @@ def construct_inputs_qNEHVI(
if objective is None:
objective = IdentityMCMultiOutputObjective()

outcome_constraints = kwargs.pop("outcome_constraints", None)
if outcome_constraints is None:
cons_tfs = None
else:
if constraints is not None:
if isinstance(objective, RiskMeasureMCObjective):
raise UnsupportedError(
"Outcome constraints are not supported with risk measures. "
"Use a feasibility-weighted risk measure instead."
)
cons_tfs = get_outcome_constraint_transforms(outcome_constraints)

sampler = kwargs.get("sampler")
if sampler is None and isinstance(model, GPyTorchModel):
Expand All @@ -1021,7 +1014,7 @@ def construct_inputs_qNEHVI(
"X_baseline": X_baseline,
"sampler": sampler,
"objective": objective,
"constraints": cons_tfs,
"constraints": constraints,
"X_pending": kwargs.get("X_pending"),
"eta": kwargs.get("eta", 1e-3),
"prune_baseline": kwargs.get("prune_baseline", True),
Expand Down
17 changes: 10 additions & 7 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,12 @@ def compute_best_feasible_objective(
is_feasible = compute_feasibility_indicator(
constraints=constraints, samples=samples
) # sample_shape x batch_shape x q
if is_feasible.any():
obj = torch.where(is_feasible, obj, -torch.inf)
with torch.no_grad():
return obj.amax(dim=-1, keepdim=True)

if is_feasible.any(dim=-1).all():
infeasible_value = -torch.inf

elif infeasible_obj is not None:
return infeasible_obj.expand(*obj.shape[:-1], 1)
infeasible_value = infeasible_obj.item()

else:
if model is None:
Expand All @@ -323,12 +322,16 @@ def compute_best_feasible_objective(
raise ValueError(
"Must specify `X_baseline` when no feasible observation exists."
)
return _estimate_objective_lower_bound(
infeasible_value = _estimate_objective_lower_bound(
model=model,
objective=objective,
posterior_transform=posterior_transform,
X=X_baseline,
).expand(*obj.shape[:-1], 1)
).item()

obj = torch.where(is_feasible, obj, infeasible_value)
with torch.no_grad():
return obj.amax(dim=-1, keepdim=True)


def _estimate_objective_lower_bound(
Expand Down
62 changes: 44 additions & 18 deletions test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,6 @@ def test_construct_inputs_qEI(self):
self.assertTrue(torch.equal(kwargs["objective"].weights, objective.weights))
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
self.assertIsNone(kwargs["sampler"])
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
Expand All @@ -406,6 +405,20 @@ def test_construct_inputs_qEI(self):
best_f=best_f_expected,
)
self.assertEqual(kwargs["best_f"], best_f_expected)
# test passing constraints
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
kwargs = c(
model=mock_model,
training_data=self.blockX_multiY,
objective=objective,
X_pending=X_pending,
best_f=best_f_expected,
constraints=constraints,
)
self.assertIs(kwargs["constraints"], constraints)

# testing qLogEI input constructor
log_constructor = get_acqf_input_constructor(qLogExpectedImprovement)
Expand All @@ -415,6 +428,7 @@ def test_construct_inputs_qEI(self):
objective=objective,
X_pending=X_pending,
best_f=best_f_expected,
constraints=constraints,
)
# includes strict superset of kwargs tested above
self.assertTrue(kwargs.items() <= log_kwargs.items())
Expand All @@ -423,6 +437,7 @@ def test_construct_inputs_qEI(self):
self.assertEqual(log_kwargs["tau_max"], TAU_MAX)
self.assertTrue("tau_relu" in log_kwargs)
self.assertEqual(log_kwargs["tau_relu"], TAU_RELU)
self.assertIs(log_kwargs["constraints"], constraints)

def test_construct_inputs_qNEI(self):
c = get_acqf_input_constructor(qNoisyExpectedImprovement)
Expand All @@ -441,29 +456,36 @@ def test_construct_inputs_qNEI(self):
with self.assertRaisesRegex(ValueError, "Field `X` must be shared"):
c(model=mock_model, training_data=self.multiX_multiY)
X_baseline = torch.rand(2, 2)
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
kwargs = c(
model=mock_model,
training_data=self.blockX_blockY,
X_baseline=X_baseline,
prune_baseline=False,
constraints=constraints,
)
self.assertEqual(kwargs["model"], mock_model)
self.assertIsNone(kwargs["objective"])
self.assertIsNone(kwargs["X_pending"])
self.assertIsNone(kwargs["sampler"])
self.assertFalse(kwargs["prune_baseline"])
self.assertTrue(torch.equal(kwargs["X_baseline"], X_baseline))
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
self.assertIs(kwargs["constraints"], constraints)

# testing qLogNEI input constructor
log_constructor = get_acqf_input_constructor(qLogNoisyExpectedImprovement)

log_kwargs = log_constructor(
model=mock_model,
training_data=self.blockX_blockY,
X_baseline=X_baseline,
prune_baseline=False,
constraints=constraints,
)
# includes strict superset of kwargs tested above
self.assertTrue(kwargs.items() <= log_kwargs.items())
Expand All @@ -472,6 +494,7 @@ def test_construct_inputs_qNEI(self):
self.assertEqual(log_kwargs["tau_max"], TAU_MAX)
self.assertTrue("tau_relu" in log_kwargs)
self.assertEqual(log_kwargs["tau_relu"], TAU_RELU)
self.assertIs(log_kwargs["constraints"], constraints)

def test_construct_inputs_qPI(self):
c = get_acqf_input_constructor(qProbabilityOfImprovement)
Expand Down Expand Up @@ -499,23 +522,28 @@ def test_construct_inputs_qPI(self):
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
self.assertIsNone(kwargs["sampler"])
self.assertEqual(kwargs["tau"], 1e-2)
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
best_f_expected = objective(multi_Y).max()
self.assertEqual(kwargs["best_f"], best_f_expected)
# Check explicitly specifying `best_f`.
best_f_expected = best_f_expected - 1 # Random value.
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
kwargs = c(
model=mock_model,
training_data=self.blockX_multiY,
objective=objective,
X_pending=X_pending,
tau=1e-2,
best_f=best_f_expected,
constraints=constraints,
)
self.assertEqual(kwargs["best_f"], best_f_expected)
self.assertIs(kwargs["constraints"], constraints)

def test_construct_inputs_qUCB(self):
c = get_acqf_input_constructor(qUpperConfidenceBound)
Expand Down Expand Up @@ -564,7 +592,7 @@ def test_construct_inputs_EHVI(self):
model=mock_model,
training_data=self.blockX_blockY,
objective_thresholds=objective_thresholds,
outcome_constraints=mock.Mock(),
constraints=mock.Mock(),
)

# test with Y_pmean supplied explicitly
Expand Down Expand Up @@ -702,13 +730,16 @@ def test_construct_inputs_qEHVI(self):
weights = torch.rand(2)
obj = WeightedMCMultiOutputObjective(weights=weights)
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
X_pending = torch.rand(1, 2)
kwargs = c(
model=mm,
training_data=self.blockX_blockY,
objective_thresholds=objective_thresholds,
objective=obj,
outcome_constraints=outcome_constraints,
constraints=constraints,
X_pending=X_pending,
alpha=0.05,
eta=1e-2,
Expand All @@ -723,11 +754,7 @@ def test_construct_inputs_qEHVI(self):
Y_expected = mean[:1] * weights
self.assertTrue(torch.equal(partitioning._neg_Y, -Y_expected))
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
cons_tfs = kwargs["constraints"]
self.assertEqual(len(cons_tfs), 1)
cons_eval = cons_tfs[0](mean)
cons_eval_expected = torch.tensor([-0.25, 0.5])
self.assertTrue(torch.equal(cons_eval, cons_eval_expected))
self.assertIs(kwargs["constraints"], constraints)
self.assertEqual(kwargs["eta"], 1e-2)

# Test check for block designs
Expand All @@ -737,7 +764,7 @@ def test_construct_inputs_qEHVI(self):
training_data=self.multiX_multiY,
objective_thresholds=objective_thresholds,
objective=obj,
outcome_constraints=outcome_constraints,
constraints=constraints,
X_pending=X_pending,
alpha=0.05,
eta=1e-2,
Expand Down Expand Up @@ -798,6 +825,9 @@ def test_construct_inputs_qNEHVI(self):
X_baseline = torch.rand(2, 2)
sampler = IIDNormalSampler(sample_shape=torch.Size([4]))
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
constraints = get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
)
X_pending = torch.rand(1, 2)
kwargs = c(
model=mock_model,
Expand All @@ -806,7 +836,7 @@ def test_construct_inputs_qNEHVI(self):
objective=objective,
X_baseline=X_baseline,
sampler=sampler,
outcome_constraints=outcome_constraints,
constraints=constraints,
X_pending=X_pending,
eta=1e-2,
prune_baseline=True,
Expand All @@ -823,11 +853,7 @@ def test_construct_inputs_qNEHVI(self):
self.assertIsInstance(sampler_, IIDNormalSampler)
self.assertEqual(sampler_.sample_shape, torch.Size([4]))
self.assertEqual(kwargs["objective"], objective)
cons_tfs_expected = get_outcome_constraint_transforms(outcome_constraints)
cons_tfs = kwargs["constraints"]
self.assertEqual(len(cons_tfs), 1)
test_Y = torch.rand(1, 2)
self.assertTrue(torch.equal(cons_tfs[0](test_Y), cons_tfs_expected[0](test_Y)))
self.assertIs(kwargs["constraints"], constraints)
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
self.assertEqual(kwargs["eta"], 1e-2)
self.assertTrue(kwargs["prune_baseline"])
Expand All @@ -844,7 +870,7 @@ def test_construct_inputs_qNEHVI(self):
training_data=self.blockX_blockY,
objective_thresholds=objective_thresholds,
objective=MultiOutputExpectation(n_w=3),
outcome_constraints=outcome_constraints,
constraints=constraints,
)
for use_preprocessing in (True, False):
obj = MultiOutputExpectation(
Expand Down
Loading

0 comments on commit 52529e1

Please sign in to comment.