From 52529e1130c9fb80e9bde8ecf5a3d5445a3c2e12 Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Thu, 10 Aug 2023 07:23:19 -0700
Subject: [PATCH] fix constraint handling in single objective MBM (#1973)

Summary:
X-link: https://github.com/facebook/Ax/pull/1771

Pull Request resolved: https://github.com/pytorch/botorch/pull/1973

Currently, constraints are not used in single-objective acquisition functions (AFs) in MBM due to a name mismatch between `outcome_constraints` and `constraints`.

Reviewed By: SebastianAment

Differential Revision: D48176978

fbshipit-source-id: 9495708002c11a874bb6b8c06327f0f4643039df
---
 botorch/acquisition/input_constructors.py   |  27 ++--
 botorch/acquisition/utils.py                |  17 +-
 test/acquisition/test_input_constructors.py |  62 ++++---
 test/acquisition/test_utils.py              | 169 +++++++++++++-------
 4 files changed, 178 insertions(+), 97 deletions(-)

diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py
index ae000201f1..ec226451c9 100644
--- a/botorch/acquisition/input_constructors.py
+++ b/botorch/acquisition/input_constructors.py
@@ -97,7 +97,6 @@
 from botorch.optim.optimize import optimize_acqf
 from botorch.sampling.base import MCSampler
 from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
-from botorch.utils.constraints import get_outcome_constraint_transforms
 from botorch.utils.containers import BotorchContainer
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.multi_objective.box_decompositions.non_dominated import (
@@ -718,7 +717,7 @@ def construct_inputs_qLogNEI(
             X_baseline=X_baseline,
             prune_baseline=prune_baseline,
             cache_root=cache_root,
-            constraint=constraints,
+            constraints=constraints,
             eta=eta,
         ),
         "fat": fat,
@@ -853,11 +852,12 @@ def construct_inputs_EHVI(
     training_data: MaybeDict[SupervisedDataset],
     objective_thresholds: Tensor,
     objective: Optional[AnalyticMultiOutputObjective] = None,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for `ExpectedHypervolumeImprovement` constructor."""
     num_objectives = objective_thresholds.shape[0]
-    if kwargs.get("outcome_constraints") is not None:
+    if constraints is not None:
         raise NotImplementedError("EHVI does not yet support outcome constraints.")
 
     X = _get_dataset_field(
@@ -914,6 +914,7 @@ def construct_inputs_qEHVI(
     training_data: MaybeDict[SupervisedDataset],
     objective_thresholds: Tensor,
     objective: Optional[MCMultiOutputObjective] = None,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for `qExpectedHypervolumeImprovement` constructor."""
@@ -928,15 +929,10 @@ def construct_inputs_qEHVI(
     # compute posterior mean (for ref point computation and ref Pareto frontier)
     with torch.no_grad():
         Y_pmean = model.posterior(X).mean
-
-    outcome_constraints = kwargs.pop("outcome_constraints", None)
     # For HV-based acquisition functions we pass the constraint transform directly
-    if outcome_constraints is None:
-        cons_tfs = None
-    else:
-        cons_tfs = get_outcome_constraint_transforms(outcome_constraints)
+    if constraints is not None:
         # Adjust `Y_pmean` to contain feasible points only.
- feas = torch.stack([c(Y_pmean) <= 0 for c in cons_tfs], dim=-1).all(dim=-1) + feas = torch.stack([c(Y_pmean) <= 0 for c in constraints], dim=-1).all(dim=-1) Y_pmean = Y_pmean[feas] if objective is None: @@ -962,7 +958,7 @@ def construct_inputs_qEHVI( add_qehvi_kwargs = { "sampler": sampler, "X_pending": kwargs.get("X_pending"), - "constraints": cons_tfs, + "constraints": constraints, "eta": kwargs.get("eta", 1e-3), } return {**ehvi_kwargs, **add_qehvi_kwargs} @@ -975,6 +971,7 @@ def construct_inputs_qNEHVI( objective_thresholds: Tensor, objective: Optional[MCMultiOutputObjective] = None, X_baseline: Optional[Tensor] = None, + constraints: Optional[List[Callable[[Tensor], Tensor]]] = None, **kwargs: Any, ) -> Dict[str, Any]: r"""Construct kwargs for `qNoisyExpectedHypervolumeImprovement` constructor.""" @@ -991,16 +988,12 @@ def construct_inputs_qNEHVI( if objective is None: objective = IdentityMCMultiOutputObjective() - outcome_constraints = kwargs.pop("outcome_constraints", None) - if outcome_constraints is None: - cons_tfs = None - else: + if constraints is not None: if isinstance(objective, RiskMeasureMCObjective): raise UnsupportedError( "Outcome constraints are not supported with risk measures. " "Use a feasibility-weighted risk measure instead." ) - cons_tfs = get_outcome_constraint_transforms(outcome_constraints) sampler = kwargs.get("sampler") if sampler is None and isinstance(model, GPyTorchModel): @@ -1021,7 +1014,7 @@ def construct_inputs_qNEHVI( "X_baseline": X_baseline, "sampler": sampler, "objective": objective, - "constraints": cons_tfs, + "constraints": constraints, "X_pending": kwargs.get("X_pending"), "eta": kwargs.get("eta", 1e-3), "prune_baseline": kwargs.get("prune_baseline", True), diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 13cf482cef..5364ef5fd5 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -306,13 +306,12 @@ def compute_best_feasible_objective( is_feasible = compute_feasibility_indicator( constraints=constraints, samples=samples ) # sample_shape x batch_shape x q - if is_feasible.any(): - obj = torch.where(is_feasible, obj, -torch.inf) - with torch.no_grad(): - return obj.amax(dim=-1, keepdim=True) + + if is_feasible.any(dim=-1).all(): + infeasible_value = -torch.inf elif infeasible_obj is not None: - return infeasible_obj.expand(*obj.shape[:-1], 1) + infeasible_value = infeasible_obj.item() else: if model is None: @@ -323,12 +322,16 @@ def compute_best_feasible_objective( raise ValueError( "Must specify `X_baseline` when no feasible observation exists." 
) - return _estimate_objective_lower_bound( + infeasible_value = _estimate_objective_lower_bound( model=model, objective=objective, posterior_transform=posterior_transform, X=X_baseline, - ).expand(*obj.shape[:-1], 1) + ).item() + + obj = torch.where(is_feasible, obj, infeasible_value) + with torch.no_grad(): + return obj.amax(dim=-1, keepdim=True) def _estimate_objective_lower_bound( diff --git a/test/acquisition/test_input_constructors.py b/test/acquisition/test_input_constructors.py index 16a8c659e0..3ed199ab15 100644 --- a/test/acquisition/test_input_constructors.py +++ b/test/acquisition/test_input_constructors.py @@ -390,7 +390,6 @@ def test_construct_inputs_qEI(self): self.assertTrue(torch.equal(kwargs["objective"].weights, objective.weights)) self.assertTrue(torch.equal(kwargs["X_pending"], X_pending)) self.assertIsNone(kwargs["sampler"]) - self.assertIsNone(kwargs["constraints"]) self.assertIsInstance(kwargs["eta"], float) self.assertTrue(kwargs["eta"] < 1) multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1) @@ -406,6 +405,20 @@ def test_construct_inputs_qEI(self): best_f=best_f_expected, ) self.assertEqual(kwargs["best_f"], best_f_expected) + # test passing constraints + outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]])) + constraints = get_outcome_constraint_transforms( + outcome_constraints=outcome_constraints + ) + kwargs = c( + model=mock_model, + training_data=self.blockX_multiY, + objective=objective, + X_pending=X_pending, + best_f=best_f_expected, + constraints=constraints, + ) + self.assertIs(kwargs["constraints"], constraints) # testing qLogEI input constructor log_constructor = get_acqf_input_constructor(qLogExpectedImprovement) @@ -415,6 +428,7 @@ def test_construct_inputs_qEI(self): objective=objective, X_pending=X_pending, best_f=best_f_expected, + constraints=constraints, ) # includes strict superset of kwargs tested above self.assertTrue(kwargs.items() <= log_kwargs.items()) @@ -423,6 +437,7 @@ def test_construct_inputs_qEI(self): self.assertEqual(log_kwargs["tau_max"], TAU_MAX) self.assertTrue("tau_relu" in log_kwargs) self.assertEqual(log_kwargs["tau_relu"], TAU_RELU) + self.assertIs(log_kwargs["constraints"], constraints) def test_construct_inputs_qNEI(self): c = get_acqf_input_constructor(qNoisyExpectedImprovement) @@ -441,11 +456,16 @@ def test_construct_inputs_qNEI(self): with self.assertRaisesRegex(ValueError, "Field `X` must be shared"): c(model=mock_model, training_data=self.multiX_multiY) X_baseline = torch.rand(2, 2) + outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]])) + constraints = get_outcome_constraint_transforms( + outcome_constraints=outcome_constraints + ) kwargs = c( model=mock_model, training_data=self.blockX_blockY, X_baseline=X_baseline, prune_baseline=False, + constraints=constraints, ) self.assertEqual(kwargs["model"], mock_model) self.assertIsNone(kwargs["objective"]) @@ -453,17 +473,19 @@ def test_construct_inputs_qNEI(self): self.assertIsNone(kwargs["sampler"]) self.assertFalse(kwargs["prune_baseline"]) self.assertTrue(torch.equal(kwargs["X_baseline"], X_baseline)) - self.assertIsNone(kwargs["constraints"]) self.assertIsInstance(kwargs["eta"], float) self.assertTrue(kwargs["eta"] < 1) + self.assertIs(kwargs["constraints"], constraints) # testing qLogNEI input constructor log_constructor = get_acqf_input_constructor(qLogNoisyExpectedImprovement) + log_kwargs = log_constructor( model=mock_model, training_data=self.blockX_blockY, X_baseline=X_baseline, 
prune_baseline=False, + constraints=constraints, ) # includes strict superset of kwargs tested above self.assertTrue(kwargs.items() <= log_kwargs.items()) @@ -472,6 +494,7 @@ def test_construct_inputs_qNEI(self): self.assertEqual(log_kwargs["tau_max"], TAU_MAX) self.assertTrue("tau_relu" in log_kwargs) self.assertEqual(log_kwargs["tau_relu"], TAU_RELU) + self.assertIs(log_kwargs["constraints"], constraints) def test_construct_inputs_qPI(self): c = get_acqf_input_constructor(qProbabilityOfImprovement) @@ -499,7 +522,6 @@ def test_construct_inputs_qPI(self): self.assertTrue(torch.equal(kwargs["X_pending"], X_pending)) self.assertIsNone(kwargs["sampler"]) self.assertEqual(kwargs["tau"], 1e-2) - self.assertIsNone(kwargs["constraints"]) self.assertIsInstance(kwargs["eta"], float) self.assertTrue(kwargs["eta"] < 1) multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1) @@ -507,6 +529,10 @@ def test_construct_inputs_qPI(self): self.assertEqual(kwargs["best_f"], best_f_expected) # Check explicitly specifying `best_f`. best_f_expected = best_f_expected - 1 # Random value. + outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]])) + constraints = get_outcome_constraint_transforms( + outcome_constraints=outcome_constraints + ) kwargs = c( model=mock_model, training_data=self.blockX_multiY, @@ -514,8 +540,10 @@ def test_construct_inputs_qPI(self): X_pending=X_pending, tau=1e-2, best_f=best_f_expected, + constraints=constraints, ) self.assertEqual(kwargs["best_f"], best_f_expected) + self.assertIs(kwargs["constraints"], constraints) def test_construct_inputs_qUCB(self): c = get_acqf_input_constructor(qUpperConfidenceBound) @@ -564,7 +592,7 @@ def test_construct_inputs_EHVI(self): model=mock_model, training_data=self.blockX_blockY, objective_thresholds=objective_thresholds, - outcome_constraints=mock.Mock(), + constraints=mock.Mock(), ) # test with Y_pmean supplied explicitly @@ -702,13 +730,16 @@ def test_construct_inputs_qEHVI(self): weights = torch.rand(2) obj = WeightedMCMultiOutputObjective(weights=weights) outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]])) + constraints = get_outcome_constraint_transforms( + outcome_constraints=outcome_constraints + ) X_pending = torch.rand(1, 2) kwargs = c( model=mm, training_data=self.blockX_blockY, objective_thresholds=objective_thresholds, objective=obj, - outcome_constraints=outcome_constraints, + constraints=constraints, X_pending=X_pending, alpha=0.05, eta=1e-2, @@ -723,11 +754,7 @@ def test_construct_inputs_qEHVI(self): Y_expected = mean[:1] * weights self.assertTrue(torch.equal(partitioning._neg_Y, -Y_expected)) self.assertTrue(torch.equal(kwargs["X_pending"], X_pending)) - cons_tfs = kwargs["constraints"] - self.assertEqual(len(cons_tfs), 1) - cons_eval = cons_tfs[0](mean) - cons_eval_expected = torch.tensor([-0.25, 0.5]) - self.assertTrue(torch.equal(cons_eval, cons_eval_expected)) + self.assertIs(kwargs["constraints"], constraints) self.assertEqual(kwargs["eta"], 1e-2) # Test check for block designs @@ -737,7 +764,7 @@ def test_construct_inputs_qEHVI(self): training_data=self.multiX_multiY, objective_thresholds=objective_thresholds, objective=obj, - outcome_constraints=outcome_constraints, + constraints=constraints, X_pending=X_pending, alpha=0.05, eta=1e-2, @@ -798,6 +825,9 @@ def test_construct_inputs_qNEHVI(self): X_baseline = torch.rand(2, 2) sampler = IIDNormalSampler(sample_shape=torch.Size([4])) outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]])) + constraints = 
get_outcome_constraint_transforms( + outcome_constraints=outcome_constraints + ) X_pending = torch.rand(1, 2) kwargs = c( model=mock_model, @@ -806,7 +836,7 @@ def test_construct_inputs_qNEHVI(self): objective=objective, X_baseline=X_baseline, sampler=sampler, - outcome_constraints=outcome_constraints, + constraints=constraints, X_pending=X_pending, eta=1e-2, prune_baseline=True, @@ -823,11 +853,7 @@ def test_construct_inputs_qNEHVI(self): self.assertIsInstance(sampler_, IIDNormalSampler) self.assertEqual(sampler_.sample_shape, torch.Size([4])) self.assertEqual(kwargs["objective"], objective) - cons_tfs_expected = get_outcome_constraint_transforms(outcome_constraints) - cons_tfs = kwargs["constraints"] - self.assertEqual(len(cons_tfs), 1) - test_Y = torch.rand(1, 2) - self.assertTrue(torch.equal(cons_tfs[0](test_Y), cons_tfs_expected[0](test_Y))) + self.assertIs(kwargs["constraints"], constraints) self.assertTrue(torch.equal(kwargs["X_pending"], X_pending)) self.assertEqual(kwargs["eta"], 1e-2) self.assertTrue(kwargs["prune_baseline"]) @@ -844,7 +870,7 @@ def test_construct_inputs_qNEHVI(self): training_data=self.blockX_blockY, objective_thresholds=objective_thresholds, objective=MultiOutputExpectation(n_w=3), - outcome_constraints=outcome_constraints, + constraints=constraints, ) for use_preprocessing in (True, False): obj = MultiOutputExpectation( diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index 162e2b03a2..0ec4b748ee 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -864,75 +864,134 @@ def test_compute_best_feasible_objective(self): tkwargs = {"dtype": dtype, "device": self.device} n = 5 X = torch.arange(n, **tkwargs).view(-1, 1) - means = torch.arange(n, **tkwargs).view(-1, 1) - samples = means - variances = torch.tensor( - [0.09, 0.25, 0.36, 0.25, 0.09], **tkwargs - ).view(-1, 1) - mm = MockModel( - MockPosterior(mean=means, variance=variances, samples=samples) - ) + for batch_shape, sample_shape in itertools.product( + (torch.Size([]), torch.Size([2])), + (torch.Size([1]), torch.Size([3])), + ): + means = torch.arange(n, **tkwargs).view(-1, 1) + if len(batch_shape) > 0: + view_means = means.view(1, *means.shape) + means = view_means.expand(batch_shape + means.shape) + if sample_shape[0] == 1: + samples = means.unsqueeze(0) + else: + samples = torch.stack([means, means + 1, means + 4], dim=0) + variances = torch.tensor( + [0.09, 0.25, 0.36, 0.25, 0.09], **tkwargs + ).view(-1, 1) + mm = MockModel(MockPosterior(mean=means, variance=variances)) - # testing all feasible points - obj = means.squeeze(-1) - constraints = [lambda samples: -torch.ones_like(samples[..., 0])] - best_f = compute_best_feasible_objective( - samples=means, obj=obj, constraints=constraints - ) - self.assertAllClose(best_f, obj.amax(dim=-1, keepdim=True)) - - # testing with some infeasible points - con_cutoff = 3.0 - best_f = compute_best_feasible_objective( - samples=means, - obj=obj, - constraints=[ - lambda samples: samples[..., 0] - (con_cutoff + 1 / 2) - ], - ) - # only first three points are feasible - self.assertAllClose(best_f, torch.tensor([con_cutoff], **tkwargs)) + # testing all feasible points + obj = samples.squeeze(-1) + constraints = [lambda samples: -torch.ones_like(samples[..., 0])] + best_f = compute_best_feasible_objective( + samples=samples, obj=obj, constraints=constraints + ) + self.assertAllClose(best_f, obj.amax(dim=-1, keepdim=True)) - # testing with no feasible points and infeasible obj - infeasible_obj = 
torch.tensor(torch.pi, **tkwargs)
-            best_f = compute_best_feasible_objective(
-                samples=means,
-                obj=obj,
-                constraints=[lambda X: torch.ones_like(X[..., 0])],
-                infeasible_obj=infeasible_obj,
-            )
-            self.assertAllClose(best_f, infeasible_obj.unsqueeze(0))
+                # testing with some infeasible points
+                con_cutoff = 3.0
+                best_f = compute_best_feasible_objective(
+                    samples=samples,
+                    obj=obj,
+                    constraints=[
+                        lambda samples: samples[..., 0] - (con_cutoff + 1 / 2)
+                    ],
+                    model=mm,
+                    X_baseline=X,
+                )
 
-            # testing with no feasible points and not infeasible obj
-            def objective(Y, X):
-                return Y.squeeze(-1) - 5.0
+                if sample_shape[0] == 3:
+                    # under some samples, all baseline points are infeasible,
+                    # so best_f is set to the negative infeasible cost for
+                    # samples where no point is feasible
+                    expected_best_f = torch.tensor(
+                        [
+                            3.0,
+                            3.0,
+                            -get_infeasible_cost(
+                                X=X,
+                                model=mm,
+                            ).item(),
+                        ],
+                        **tkwargs,
+                    ).view(-1, 1)
+                    if len(batch_shape) > 0:
+                        expected_best_f = expected_best_f.unsqueeze(1)
+                        expected_best_f = expected_best_f.expand(
+                            *sample_shape, *batch_shape, 1
+                        )
+                else:
+                    expected_best_f = torch.full(
+                        sample_shape + batch_shape + torch.Size([1]),
+                        con_cutoff,
+                        **tkwargs,
+                    )
+                self.assertAllClose(best_f, expected_best_f)
+                # test some feasible points with infeasible obj
+                if sample_shape[0] == 3:
+                    best_f = compute_best_feasible_objective(
+                        samples=samples,
+                        obj=obj,
+                        constraints=[
+                            lambda samples: samples[..., 0] - (con_cutoff + 1 / 2)
+                        ],
+                        infeasible_obj=torch.ones(1, **tkwargs),
+                    )
+                    expected_best_f[-1] = 1
+                    self.assertAllClose(best_f, expected_best_f)
 
-            best_f = compute_best_feasible_objective(
-                samples=means,
-                obj=obj,
-                constraints=[lambda X: torch.ones_like(X[..., 0])],
-                model=mm,
-                X_baseline=X,
-                objective=objective,
-            )
-            self.assertAllClose(
-                best_f, -get_infeasible_cost(X=X, model=mm, objective=objective)
-            )
+                # testing with no feasible points and infeasible obj
+                infeasible_obj = torch.tensor(torch.pi, **tkwargs)
+                expected_best_f = torch.full(
+                    sample_shape + batch_shape + torch.Size([1]),
+                    torch.pi,
+                    **tkwargs,
+                )
 
-            with self.assertRaisesRegex(ValueError, "Must specify `model`"):
                 best_f = compute_best_feasible_objective(
-                    samples=means,
+                    samples=samples,
                     obj=obj,
                     constraints=[lambda X: torch.ones_like(X[..., 0])],
-                    X_baseline=X,
+                    infeasible_obj=infeasible_obj,
                 )
-            with self.assertRaisesRegex(ValueError, "Must specify `X_baseline`"):
+                self.assertAllClose(best_f, expected_best_f)
+
+                # testing with no feasible points and no infeasible obj
+                def objective(Y, X):
+                    return Y.squeeze(-1) - 5.0
+
                 best_f = compute_best_feasible_objective(
-                    samples=means,
+                    samples=samples,
                     obj=obj,
                     constraints=[lambda X: torch.ones_like(X[..., 0])],
                     model=mm,
+                    X_baseline=X,
+                    objective=objective,
+                )
+                expected_best_f = torch.full(
+                    sample_shape + batch_shape + torch.Size([1]),
+                    -get_infeasible_cost(X=X, model=mm, objective=objective).item(),
+                    **tkwargs,
                 )
+                self.assertAllClose(best_f, expected_best_f)
+
+            with self.assertRaisesRegex(ValueError, "Must specify `model`"):
+                best_f = compute_best_feasible_objective(
+                    samples=means,
+                    obj=obj,
+                    constraints=[lambda X: torch.ones_like(X[..., 0])],
+                    X_baseline=X,
+                )
+            with self.assertRaisesRegex(
+                ValueError, "Must specify `X_baseline`"
+            ):
+                best_f = compute_best_feasible_objective(
+                    samples=means,
+                    obj=obj,
+                    constraints=[lambda X: torch.ones_like(X[..., 0])],
+                    model=mm,
+                )
 
     def test_get_infeasible_cost(self):
         for dtype in (torch.float, torch.double):
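
For context, a minimal sketch of the calling convention this fix standardizes on, assuming a toy two-output `SingleTaskGP` and a `LinearMCObjective` (neither is part of the patch): convert matrix-form `outcome_constraints` `(A, b)`, meaning `A @ y <= b`, into a list of callables with `get_outcome_constraint_transforms`, then pass that list to the acquisition function as `constraints`, the keyword that was previously dropped because of the `constraint`/`outcome_constraints` mismatch.

import torch

from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.objective import LinearMCObjective
from botorch.models import SingleTaskGP
from botorch.utils.constraints import get_outcome_constraint_transforms

# Toy two-output model: outcome 0 is the objective, outcome 1 is constrained.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 2, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

# Matrix form A @ y <= b: here "outcome 1 must be <= 0.5", the same constraint
# shape used in the tests above.
outcome_constraints = (
    torch.tensor([[0.0, 1.0]], dtype=torch.double),
    torch.tensor([[0.5]], dtype=torch.double),
)
# Callable form consumed by the MC acquisition functions: each callable maps a
# `... x q x m` tensor of outcome samples to values that are negative iff the
# corresponding point is feasible.
constraints = get_outcome_constraint_transforms(outcome_constraints)

acqf = qLogNoisyExpectedImprovement(
    model=model,
    X_baseline=train_X,
    # maximize outcome 0 only
    objective=LinearMCObjective(weights=torch.tensor([1.0, 0.0], dtype=torch.double)),
    constraints=constraints,  # forwarded correctly after this fix
)
value = acqf(torch.rand(1, 1, 2, dtype=torch.double))  # evaluate a q=1 candidate

The negative-means-feasible convention is the same one `compute_feasibility_indicator` and the `feas = torch.stack([c(Y_pmean) <= 0 ...])` check in this patch rely on.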