From a546b5e7187a3f8d8af403d6bdb78f536daebe3a Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 7 Oct 2024 10:07:55 -0700 Subject: [PATCH] Add a helper for evaluating feasibility of a set of points (#2565) Summary: Adds a helper for evaluating the feasibility of intra-point parameter constraints on a given tensor. Differential Revision: D63909338 --- botorch/optim/parameter_constraints.py | 57 +++++++++++++- test/optim/test_parameter_constraints.py | 94 ++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 1 deletion(-) diff --git a/botorch/optim/parameter_constraints.py b/botorch/optim/parameter_constraints.py index 0f85235561..f7ecd82433 100644 --- a/botorch/optim/parameter_constraints.py +++ b/botorch/optim/parameter_constraints.py @@ -11,7 +11,6 @@ from __future__ import annotations from collections.abc import Callable - from functools import partial from typing import Union @@ -26,6 +25,7 @@ str, Union[str, Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]] ] NLC_TOL = -1e-6 +INTRA_POINT_CONST_ERR: str = "Only intra-point constraints are supported." def make_scipy_bounds( @@ -601,3 +601,58 @@ def make_scipy_nonlinear_inequality_constraints( shapeX=shapeX, ) return scipy_nonlinear_inequality_constraints + + +def evaluate_feasibility( + X: Tensor, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None, +) -> Tensor: + r"""Evaluate feasibility of a set of points. Only supports intra-point constraints. + + Args: + X: A tensor of points of shape `batch_shape x d`. + inequality_constraints: A list of tuples (indices, coefficients, rhs), + with each tuple encoding an inequality constraint of the form + `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and + `coefficients` should be torch tensors. See the docstring of + `make_scipy_linear_constraints` for an example. + equality_constraints: A list of tuples (indices, coefficients, rhs), + with each tuple encoding an equality constraint of the form + `\sum_i (X[indices[i]] * coefficients[i]) = rhs`. See the docstring of + `make_scipy_linear_constraints` for an example. + nonlinear_inequality_constraints: A list of tuples representing the nonlinear + inequality constraints. The first element in the tuple is a callable + representing a constraint of the form `callable(x) >= 0`. The `callable()` + takes in an one-dimensional tensor of shape `d` and returns a scalar. The + second element is a boolean, indicating if it is an intra-point or + inter-point constraint (`True` for intra-point. `False` for + inter-point). Only `True` is supported here. For more information on + intra-point vs inter-point constraints, see the docstring of the + `inequality_constraints` argument to `optimize_acqf()`. + + Returns: + A boolean tensor of shape `batch` denoting whether each point is feasible. + """ + is_feasible = torch.ones(X.shape[:-1], device=X.device, dtype=torch.bool) + if inequality_constraints is not None: + for idx, coef, rhs in inequality_constraints: + if idx.ndim != 1: + raise UnsupportedError(INTRA_POINT_CONST_ERR) + is_feasible &= (X[..., idx] * coef).sum(dim=-1) >= rhs + if equality_constraints is not None: + for idx, coef, rhs in equality_constraints: + if idx.ndim != 1: + raise UnsupportedError(INTRA_POINT_CONST_ERR) + is_feasible &= (X[..., idx] * coef).sum(dim=-1) == rhs + if nonlinear_inequality_constraints is not None: + for const, intra in nonlinear_inequality_constraints: + if not intra: + raise UnsupportedError(INTRA_POINT_CONST_ERR) + is_feasible &= torch.tensor( + [const(x) >= NLC_TOL for x in X.view(-1, X.shape[-1])], + device=X.device, + dtype=torch.bool, + ).view_as(is_feasible) + return is_feasible diff --git a/test/optim/test_parameter_constraints.py b/test/optim/test_parameter_constraints.py index 435c99fcb0..5ce6cdc4f8 100644 --- a/test/optim/test_parameter_constraints.py +++ b/test/optim/test_parameter_constraints.py @@ -17,6 +17,8 @@ _make_linear_constraints, _make_nonlinear_constraints, eval_lin_constraint, + evaluate_feasibility, + INTRA_POINT_CONST_ERR, lin_constraint_jac, make_scipy_bounds, make_scipy_linear_constraints, @@ -528,6 +530,98 @@ def test_generate_unfixed_lin_constraints(self): eq=eq, ) + def test_evaluate_feasibility_intra_point_checks(self) -> None: + # Check that `evaluate_feasibility` raises an error if inter-point + # constraints are used. + X = torch.ones(3, 2, device=self.device) + inter_cons = ( + torch.tensor([[0, 0], [1, 0]], device=self.device), + torch.tensor([1.0, -1.0], device=self.device), + 0, + ) + for const_arg in ( + {"inequality_constraints": [inter_cons]}, + {"equality_constraints": [inter_cons]}, + {"nonlinear_inequality_constraints": [(None, False)]}, + ): + with self.assertRaisesRegex(UnsupportedError, INTRA_POINT_CONST_ERR): + evaluate_feasibility(X=X, **const_arg) + + def test_evaluate_feasibility(self) -> None: + # Check that the feasibility is evaluated correctly. + X = torch.tensor( + [ + [[1.0, 1.0, 1.0]], + [[1.0, 1.0, 3.0]], + [[2.0, 2.0, 1.0]], + [[2.0, 2.0, 5.0]], + [[3.0, 3.0, 3.0]], + ], + device=self.device, + ) + # X[..., 2] * 4 >= 5. + inequality_constraints = [ + ( + torch.tensor([2], device=self.device), + torch.tensor([4], device=self.device), + 5.0, + ) + ] + # X[..., 0] + X[..., 1] == 4. + equality_constraints = [ + ( + torch.tensor([0, 1], device=self.device), + torch.ones(2, device=self.device), + 4.0, + ) + ] + + # sum(X, dim=-1) < 4. + def nlc1(x): + return 4 - x.sum(dim=-1) + + # Only inequality. + self.assertAllClose( + evaluate_feasibility( + X=X, + inequality_constraints=inequality_constraints, + ), + torch.tensor( + [[False], [True], [False], [True], [True]], device=self.device + ), + ) + # Only equality. + self.assertAllClose( + evaluate_feasibility( + X=X, + equality_constraints=equality_constraints, + ), + torch.tensor( + [[False], [False], [True], [True], [False]], device=self.device + ), + ) + # Both inequality and equality. + self.assertAllClose( + evaluate_feasibility( + X=X, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + ), + torch.tensor( + [[False], [False], [False], [True], [False]], device=self.device + ), + ) + # Nonlinear inequality. + self.assertAllClose( + evaluate_feasibility( + X=X, + nonlinear_inequality_constraints=[(nlc1, True)], + ), + torch.tensor( + [[True], [False], [False], [False], [False]], device=self.device + ), + ) + class TestMakeScipyBounds(BotorchTestCase): def test_make_scipy_bounds(self):