
Commit 6565d3f

jduerholt authored and facebook-github-bot committed
Probabilities of Feasibility for Classifier based constraints in Acquisition Functions (pytorch#2776)
Summary:

## Motivation

Using classifiers as output constraints in MC-based acquisition functions is a topic discussed at least in pytorch#725 and pytorch#2700. The current solution is to take the probabilities from the classifier and to unsigmoid them. This is a very unintuitive approach, especially as sometimes a sigmoid and sometimes a fatmoid is used.

This PR introduces a new attribute for `SampleReducingMCAcquisitionFunction`s named `probabilities_of_feasibility` that expects a callable returning a tensor holding values between zero and one, where one means feasible and zero infeasible. Currently, it is only implemented in the abstract `SampleReducingMCAcquisitionFunction` using the additional attribute.

As the `constraints` argument is just a special case of the `probabilities_of_feasibility` argument, in which the output of the callable is not applied directly to the objective but further processed by a sigmoid or fatmoid, one could also think about uniting both functionalities into one argument and modifying `fat` to `List[bool | None] | bool`, indicating whether a fatmoid, a sigmoid, or nothing is applied. When the user provides just a bool, either a fatmoid or a sigmoid is applied to all constraints. This approach would also have the advantage that only `compute_smoothed_feasibility_indicator` needs to be modified and almost nothing for the individual acqfs (besides updating the types for `constraints`). Furthermore, it follows the approach that we took when we implemented individual `eta`s for the constraints. So I would favor this one over the one actually outlined in the code ;) I am looking forward to your ideas on this.

SebastianAment: In pytorch#2700, you mention that from your perspective the `probabilities_of_feasibility` would not be applied on a per-sample basis like the regular constraints. Why? Also, in the community notebook by FrankWanger using the unsigmoid trick, it is applied on a per-sample basis. I would keep it on the per-sample basis, and if a classifier for some reason does not return the probabilities on a per-sample basis, it would be the task of the user to expand the tensor accordingly. What do you think?

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: pytorch#2776

Test Plan: Unit tests, but not yet added due to the draft status and pending architecture choices.

cc: Balandat

Reviewed By: saitcakmak

Differential Revision: D72342434

Pulled By: Balandat

fbshipit-source-id: 6fe6d7201d1a9388dde90e0a46f087f06dba958a
1 parent 86ef733 commit 6565d3f
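
To make the status quo concrete before looking at the diff, here is a minimal sketch of the "unsigmoid" workaround described in the summary. `unsigmoid_constraint` is a hypothetical helper invented for illustration, not a BoTorch API:

```python
import torch


# Hypothetical helper illustrating the "unsigmoid" trick: a classifier emits
# a probability p in (0, 1); logit(p) scaled by -eta is fed to the
# sigmoid-based constraint machinery so that sigmoid(-c / eta) recovers ~p.
def unsigmoid_constraint(p: torch.Tensor, eta: float = 1e-3) -> torch.Tensor:
    return -eta * torch.logit(p)


p = torch.tensor([0.9, 0.1])
c = unsigmoid_constraint(p)
print(torch.sigmoid(-c / 1e-3))  # tensor([0.9000, 0.1000]): round-trips p
```

Under the convention introduced by this PR, the callable instead returns the probability itself and is flagged so that no sigmoid/fatmoid is applied downstream.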

File tree: 4 files changed, +207 -94 lines changed

botorch/acquisition/monte_carlo.py (+4 -2)

```diff
@@ -189,7 +189,7 @@ def __init__(
         q_reduction: SampleReductionProtocol = torch.amax,
         constraints: list[Callable[[Tensor], Tensor]] | None = None,
         eta: Tensor | float = 1e-3,
-        fat: bool = False,
+        fat: list[bool | None] | bool = False,
     ):
         r"""Constructor of SampleReducingMCAcquisitionFunction.

@@ -228,7 +228,9 @@ def __init__(
                 approximation to the constraint indicators. For more details, on this
                 parameter, see the docs of `compute_smoothed_feasibility_indicator`.
             fat: Wether to apply a fat-tailed smooth approximation to the feasibility
-                indicator or the canonical sigmoid approximation.
+                indicator or the canonical sigmoid approximation. For more details,
+                on this parameter, see the docs of
+                `compute_smoothed_feasibility_indicator`.
         """
         if constraints is not None and isinstance(objective, ConstrainedMCObjective):
             raise ValueError(
```
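
For illustration, a sketch of how the widened `fat` annotation could be exercised through a toy subclass. This assumes a BoTorch build that includes this change; `FirstOutputAcqf` and the synthetic model are invented for the example, and the objective is kept positive because constraint weighting requires non-negative unconstrained acquisition values:

```python
import torch
from torch import Tensor

from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
from botorch.acquisition.objective import GenericMCObjective
from botorch.models import SingleTaskGP


class FirstOutputAcqf(SampleReducingMCAcquisitionFunction):
    """Toy subclass: the per-sample utility is the objective value itself, so
    all the interesting work is the constraint weighting in the base class."""

    def _sample_forward(self, obj: Tensor) -> Tensor:
        return obj


train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.randn(8, 2, dtype=torch.double)  # output 0: objective, output 1: constraint
model = SingleTaskGP(train_X, train_Y)

acqf = FirstOutputAcqf(
    model=model,
    # exp() keeps utilities positive, as required for constraint weighting
    objective=GenericMCObjective(lambda Y, X=None: Y[..., 0].exp()),
    constraints=[
        lambda Y: Y[..., 1],                 # classic convention: <= 0 is feasible
        lambda Y: torch.sigmoid(Y[..., 1]),  # stand-in "classifier" returning values in [0, 1]
    ],
    eta=torch.tensor([1e-3, 1e-3]),  # the second entry is ignored when its fat is None
    fat=[True, None],  # fatmoid for constraint 0, raw probabilities for constraint 1
)
print(acqf(torch.rand(4, 3, 2, dtype=torch.double)).shape)  # torch.Size([4])
```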

botorch/utils/objective.py (+42 -12)

```diff
@@ -136,7 +136,7 @@ def compute_smoothed_feasibility_indicator(
     samples: Tensor,
     eta: Tensor | float,
     log: bool = False,
-    fat: bool = False,
+    fat: list[bool | None] | bool = False,
 ) -> Tensor:
     r"""Computes the smoothed feasibility indicator of a list of constraints.

@@ -149,33 +149,63 @@

     Args:
         constraints: A list of callables, each mapping a Tensor of size `b x q x m`
-            to a Tensor of size `b x q`, where negative values imply feasibility.
-            This callable must support broadcasting. Only relevant for multi-
-            output models (`m` > 1).
+            to a Tensor of size `b x q`. The `fat` keyword defines how the callable
+            is further processed. By default a sigmoid or fatmoid transformation is
+            applied where negative values imply feasibility.
+            The applied transformation maps the feasibility indicator of the
+            constraint from the interval [-inf, inf] to the interval [0, 1].
+            If `None` is provided for `fat`, no transformation is applied and it
+            is expected that the constraint callable delivers values in the
+            interval [0, 1] without further processing that can be interpreted as
+            probabilities of feasibility directly. This is especially useful
+            for using classifiers as constraints. The callable must support
+            broadcasting. Only relevant for multi-output models (`m` > 1).
         samples: A `n_samples x b x q x m` Tensor of samples drawn from the posterior.
-        eta: The temperature parameter for the sigmoid function. Can be either a float
-            or a 1-dim tensor. In case of a float the same eta is used for every
-            constraint in constraints. In case of a tensor the length of the tensor
-            must match the number of provided constraints. The i-th constraint is
-            then estimated with the i-th eta value.
+        eta: The temperature parameter for the sigmoid/fatmoid function. Can be either
+            a float or a 1-dim tensor. In case of a float the same eta is used for
+            every constraint in constraints. In case of a tensor the length of the
+            tensor must match the number of provided constraints. The i-th constraint
+            is then estimated with the i-th eta value. In case no fatmoid/sigmoid is
+            applied, eta is ignored.
         log: Toggles the computation of the log-feasibility indicator.
         fat: Toggles the computation of the fat-tailed feasibility indicator.
+            Can be either a list or a boolean. If case of a boolean, the same
+            feasibility indicator is used for all constraints. If a list is provided,
+            the length of the list must match the number of provided constraints.
+            The i-th constraint is then associated with the i-th fat value. In case,
+            the i-th fat value is `None`, no fatmoid/sigmoid transformation is
+            applied to the i-th constraint and it is assumed that the constraint
+            by itself delivers values in the interval [0, 1]. This is especially useful
+            for using classifiers as constraints. If a boolean is provided and its
+            value is `True`, a fatmoid transformation is applied, if its value is
+            `False`, a sigmoid transformation is applied.
+

     Returns:
         A `n_samples x b x q`-dim tensor of feasibility indicator values.
     """
     if type(eta) is not Tensor:
         eta = torch.full((len(constraints),), eta)
+    if type(fat) is not list:
+        fat = [fat] * len(constraints)
     if len(eta) != len(constraints):
         raise ValueError(
             "Number of provided constraints and number of provided etas do not match."
         )
+    if len(fat) != len(constraints):
+        raise ValueError(
+            "Number of provided constraints and number of provided fats do not match."
+        )
     if not (eta > 0).all():
         raise ValueError("eta must be positive.")
     is_feasible = torch.zeros_like(samples[..., 0])
-    log_sigmoid = log_fatmoid if fat else logexpit
-    for constraint, e in zip(constraints, eta):
-        is_feasible = is_feasible + log_sigmoid(-constraint(samples) / e)
+
+    for constraint, eta_, fat_ in zip(constraints, eta, fat):
+        if fat_ is None:
+            is_feasible = is_feasible + constraint(samples).log()
+        else:
+            log_sigmoid = log_fatmoid if fat_ else logexpit
+            is_feasible = is_feasible + log_sigmoid(-constraint(samples) / eta_)

     return is_feasible if log else is_feasible.exp()
```