-
Notifications
You must be signed in to change notification settings - Fork 417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Variational Bayesian last layer models as surrogate models #2754
base: main
Are you sure you want to change the base?
Changes from all commits
f599471
6b82184
1b0f899
2606ac3
873163f
f4563ca
0944f57
cf9e6b1
cbf568e
2541cef
348749d
66ff26d
ad2fb22
9119db3
d01e1fd
353196c
630cbc2
8ad0620
92d3247
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import scipy | ||
|
||
import torch | ||
|
||
from botorch.logging import logger | ||
|
||
from botorch_community.models.blls import AbstractBLLModel | ||
from torch.func import grad | ||
|
||
|
||
class BLLMaxPosteriorSampling: | ||
def __init__( | ||
self, | ||
model: AbstractBLLModel, | ||
num_restarts: int = 10, | ||
bounds: torch.Tensor | None = None, | ||
discrete_inputs: bool = False, | ||
): | ||
""" | ||
Implements Maximum Posterior Sampling for Bayesian Linear Last (VBLL) models. | ||
|
||
This class provides functionality to sample from the posterior distribution of a | ||
BLL model, with optional optimization to refine the sampling process. | ||
|
||
Args: | ||
model: The VBLL model from which posterior samples are drawn. Must be an | ||
instance of `AbstractBLLModel`. | ||
num_restarts: Number of restarts for optimization-based sampling. | ||
Defaults to 10. | ||
bounds: Tensor of shape (2, num_inputs) specifying the lower and upper | ||
bounds for sampling. If None, defaults to [(0, 1)] for each input | ||
dimension. | ||
discrete_inputs: If True, assumes the input space is discrete and will be | ||
provided in __call__. Defaults to False. | ||
|
||
Raises: | ||
ValueError: | ||
If the provided `model` is not an instance of `AbstractBLLModel`. | ||
|
||
Notes: | ||
- If `bounds` is not provided, the default range [0,1] is assumed for each | ||
input dimension. | ||
""" | ||
if not isinstance(model, AbstractBLLModel): | ||
raise ValueError( | ||
f"Model must be an instance of AbstractBLLModel, is {type(model)}" | ||
) | ||
|
||
self.model = model | ||
self.device = model.device | ||
self.discrete_inputs = discrete_inputs | ||
self.num_restarts = num_restarts | ||
|
||
if bounds is None: | ||
# Default bounds [0,1] for each input dimension | ||
self.bounds = [(0, 1)] * self.model.num_inputs | ||
self.lb = torch.zeros( | ||
self.model.num_inputs, dtype=torch.float64, device=torch.device("cpu") | ||
) | ||
self.ub = torch.ones( | ||
self.model.num_inputs, dtype=torch.float64, device=torch.device("cpu") | ||
) | ||
else: | ||
# Ensure bounds are on CPU for compatibility with scipy.optimize.minimize | ||
self.lb = bounds[0, :].cpu() | ||
self.ub = bounds[1, :].cpu() | ||
self.bounds = [tuple(bound) for bound in bounds.T.cpu().tolist()] | ||
|
||
def __call__( | ||
self, X_cand: torch.Tensor = None, num_samples: int = 1 | ||
) -> torch.Tensor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a docstring describing this method, including the expected input and return shapes of the tensors? |
||
if self.discrete_inputs and X_cand is None: | ||
raise ValueError("X_cand must be provided if `discrete_inputs` is True.") | ||
|
||
if X_cand is not None and not self.discrete_inputs: | ||
raise ValueError("X_cand is provided but `discrete_inputs` is False.") | ||
|
||
X_next = torch.empty( | ||
num_samples, self.model.num_inputs, dtype=torch.float64, device=self.device | ||
) | ||
|
||
# get max of sampled functions at candidate points for each function | ||
for i in range(num_samples): | ||
f = self.model.sample() | ||
|
||
if self.discrete_inputs: | ||
# evaluate sample path at candidate points and select best | ||
Y_cand = f(X_cand) | ||
else: | ||
# optimize sample path | ||
X_cand, Y_cand = _optimize_sample_path( | ||
f=f, | ||
num_restarts=self.num_restarts, | ||
bounds=self.bounds, | ||
lb=self.lb, | ||
ub=self.ub, | ||
device=self.device, | ||
) | ||
|
||
# select the best candidate | ||
X_next[i, :] = X_cand[Y_cand.argmax()] | ||
|
||
# ensure that the next point is within the bounds, | ||
# scipy minimize can sometimes return points outside the bounds | ||
X_next = torch.clamp(X_next, self.lb.to(self.device), self.ub.to(self.device)) | ||
return X_next | ||
|
||
|
||
def _optimize_sample_path( | ||
f: torch.nn.Module, | ||
num_restarts: int, | ||
bounds: list[tuple[float, float]], | ||
lb: torch.Tensor, | ||
ub: torch.Tensor, | ||
device: torch.device, | ||
) -> tuple[torch.Tensor, torch.Tensor]: | ||
"""Helper function to optimize the sample path of a BLL network. | ||
|
||
Args: | ||
f: The sample to optimize. | ||
num_restarts: Number of restarts for optimization-based sampling. | ||
bounds: List of tuples specifying the lower and upper bounds for each input | ||
dimension. | ||
lb: Lower bounds for each input dimension. | ||
ub: Upper bounds for each input dimension. | ||
device: Device on which to store the candidate points. | ||
|
||
Returns: | ||
Candidate points and corresponding function values. | ||
""" | ||
X_cand = torch.empty(num_restarts, f.num_inputs, dtype=torch.float64, device=device) | ||
Y_cand = torch.empty( | ||
num_restarts, f.num_outputs, dtype=torch.float64, device=device | ||
) | ||
|
||
# create numpy wrapper around the sampled function, note we aim to maximize | ||
def func(x): | ||
return -f(torch.from_numpy(x).to(device)).detach().cpu().numpy() | ||
|
||
# get gradient and create wrapper | ||
grad_f = grad(lambda x: f(x).mean()) | ||
|
||
def grad_func(x): | ||
return -grad_f(torch.from_numpy(x).to(device)).detach().cpu().numpy() | ||
|
||
# generate random initial conditions | ||
x0s = np.random.rand(num_restarts, f.num_inputs) | ||
|
||
for j in range(num_restarts): | ||
# map to bounds | ||
x0 = lb + (ub - lb) * x0s[j] | ||
|
||
# optimize sample path | ||
res = scipy.optimize.minimize( | ||
func, x0, jac=grad_func, bounds=bounds, method="L-BFGS-B" | ||
) | ||
|
||
if not res.success: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work reliably? I've run into cases where the line search fails in odd ways sometimes, if there is a chance of this happening in any of the samples and we use a lot of samples then this could cause frequent failures. If this doesn't happen then great, o/w we may want to handle this, possibly by retrying the optimization There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, actually very reliably, so I have not had any problems with it. Should we still add the additional "safety layer"? Though I suspect that if this fails, something in the model is off. May be nice to include this pointer in an error maybe? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By this pointer you mean my comment above? I guess right now you're not erroring out but just warning, so that should be fine. Though if the majority of of the |
||
logger.warning(f"Optimization failed with message: {res.message}") | ||
|
||
# store the candidate | ||
X_cand[j, :] = torch.from_numpy(res.x).to(dtype=torch.float64) | ||
Y_cand[j] = torch.tensor([-res.fun], dtype=torch.float64) | ||
|
||
return X_cand, Y_cand | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
import torch | ||
import torch.nn as nn | ||
from botorch.models.model import Model | ||
from botorch_community.models.vbll_helper import DenseNormal, Normal | ||
from torch import Tensor | ||
|
||
|
||
class AbstractBLLModel(Model, ABC): | ||
def __init__(self): | ||
"""Abstract class for Bayesian Last Layer (BLL) models.""" | ||
super().__init__() | ||
self.model = None | ||
|
||
@property | ||
def num_outputs(self) -> int: | ||
return self.model.num_outputs | ||
|
||
@property | ||
def num_inputs(self): | ||
return self.model.num_inputs | ||
|
||
@property | ||
def device(self): | ||
return self.model.device | ||
|
||
@abstractmethod | ||
def __call__(self, X: Tensor) -> Normal | DenseNormal: | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def fit(self, *args, **kwargs): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def sample(self, sample_shape: torch.Size | None = None) -> nn.Module: | ||
raise NotImplementedError | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a
SamplingStrategy
interface here that that defines the API for this (and subclassestorch.nn.Module
) but doesn't support not providing anX_cand
. I assume this is why you didn't subclass from that?I don't think it's necessary to use this if it would require changes to
SamplingStrategy
but maybe leave a comment here that we may want to consider doing that in the future.