[Feature] Variational Bayesian last layer models as surrogate models #2754
Closed
Commits (26)
f599471  add vbll surrogate and notebook (brunzema)
6b82184  update notebook and base implementation (brunzema)
1b0f899  update optim config (brunzema)
2606ac3  add tests for vbll model (brunzema)
873163f  update tutorial for vblls (brunzema)
f4563ca  update docsting vbll (brunzema)
0944f57  remove vbll repo dependency and add relevant code (brunzema)
cf9e6b1  update tutorial including a comparison between TS and logEI to demons… (brunzema)
cbf568e  add vbll helper module for standard regression layer (brunzema)
2541cef  include PR comments, fix docstrings (brunzema)
348749d  add additional test for forward methods of VBLL and sample network (brunzema)
66ff26d  update tutorial (brunzema)
ad2fb22  minor refactors (brunzema)
9119db3  Merge branch 'pytorch:main' into feature/add_vbll_surrogates (brunzema)
d01e1fd  Merge branch 'pytorch:main' into feature/add_vbll_surrogates (brunzema)
353196c  include PR comments, address flake8 and µfmt (brunzema)
630cbc2  resolve cyclic imports by pulling out the abstract BLL (brunzema)
8ad0620  fix import error bll TS (brunzema)
92d3247  clean up docstrings, sorry (brunzema)
75b3993  Merge branch 'pytorch:main' into feature/add_vbll_surrogates (brunzema)
dd30bf9  add test for bll posterior sampling (brunzema)
2858347  add tests for bll posterior (brunzema)
20c4fcd  modify vbll tests to include all possible parameterizations (brunzema)
ae1dca7  increase test coverage for vblls (brunzema)
d9f5d2a  Merge branch 'pytorch:main' into feature/add_vbll_surrogates (brunzema)
7380910  add mock patch for VBLL TS + add tests for vbll helper distributions (brunzema)
@@ -0,0 +1,200 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import numpy as np
import scipy

import torch

from botorch.logging import logger

from botorch_community.models.blls import AbstractBLLModel
from torch.func import grad


class BLLMaxPosteriorSampling:
    """TODO: Integrate into `SamplingStrategy` (note that here we don't require
    X_cand to be passed but numerically optimize the sample paths by default)."""

    def __init__(
        self,
        model: AbstractBLLModel,
        num_restarts: int = 10,
        bounds: torch.Tensor | None = None,
        discrete_inputs: bool = False,
    ):
        """
        Implements maximum posterior sampling for variational Bayesian last layer
        (VBLL) models.

        This class provides functionality to sample from the posterior distribution
        of a BLL model, with optional optimization to refine the sampling process.

        Args:
            model: The VBLL model from which posterior samples are drawn. Must be an
                instance of `AbstractBLLModel`.
            num_restarts: Number of restarts for optimization-based sampling.
                Defaults to 10.
            bounds: Tensor of shape `2 x d` specifying the lower and upper
                bounds for sampling. If None, defaults to [(0, 1)] for each input
                dimension.
            discrete_inputs: If True, assumes the input space is discrete and that
                candidate points will be provided in `__call__`. Defaults to False.

        Raises:
            ValueError:
                If the provided `model` is not an instance of `AbstractBLLModel`.

        Notes:
            - If `bounds` is not provided, the default range [0, 1] is assumed for
              each input dimension.
        """
        if not isinstance(model, AbstractBLLModel):
            raise ValueError(
                f"Model must be an instance of AbstractBLLModel, is {type(model)}"
            )

        self.model = model
        self.device = model.device
        self.discrete_inputs = discrete_inputs
        self.num_restarts = num_restarts

        if bounds is None:
            # Default bounds [0, 1] for each input dimension
            self.bounds = [(0, 1)] * self.model.num_inputs
            self.lb = torch.zeros(
                self.model.num_inputs, dtype=torch.float64, device=torch.device("cpu")
            )
            self.ub = torch.ones(
                self.model.num_inputs, dtype=torch.float64, device=torch.device("cpu")
            )
        else:
            # Ensure bounds are on CPU for compatibility with scipy.optimize.minimize
            self.lb = bounds[0, :].cpu()
            self.ub = bounds[1, :].cpu()
            self.bounds = [tuple(bound) for bound in bounds.T.cpu().tolist()]

    def __call__(
        self, X_cand: torch.Tensor | None = None, num_samples: int = 1
    ) -> torch.Tensor:
        r"""Sample from a Bayesian last layer model posterior.

        Args:
            X_cand: A `N x d`-dim Tensor from which to sample (in the `N`
                dimension) according to the maximum posterior value under the
                objective. NOTE: `X_cand` is only accepted if `discrete_inputs`
                is `True`!
            num_samples: The number of samples to draw.

        Returns:
            A `num_samples x d`-dim Tensor of maximizers of the sampled posterior
            paths, where `X_next[i, :]` is the `i`-th sample.
        """
        if self.discrete_inputs and X_cand is None:
            raise ValueError("X_cand must be provided if `discrete_inputs` is True.")

        if X_cand is not None and not self.discrete_inputs:
            raise ValueError("X_cand is provided but `discrete_inputs` is False.")

        X_next = torch.empty(
            num_samples, self.model.num_inputs, dtype=torch.float64, device=self.device
        )

        # get max of sampled functions at candidate points for each function
        for i in range(num_samples):
            f = self.model.sample()

            if self.discrete_inputs:
                # evaluate sample path at candidate points and select best
                Y_cand = f(X_cand)
            else:
                # optimize sample path
                X_cand, Y_cand = _optimize_sample_path(
                    f=f,
                    num_restarts=self.num_restarts,
                    bounds=self.bounds,
                    lb=self.lb,
                    ub=self.ub,
                    device=self.device,
                )

            # select the best candidate
            X_next[i, :] = X_cand[Y_cand.argmax()]

        # ensure that the next point is within the bounds;
        # scipy minimize can sometimes return points outside the bounds
        X_next = torch.clamp(X_next, self.lb.to(self.device), self.ub.to(self.device))
        return X_next


def _optimize_sample_path(
    f: torch.nn.Module,
    num_restarts: int,
    bounds: list[tuple[float, float]],
    lb: torch.Tensor,
    ub: torch.Tensor,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Helper function to optimize the sample path of a BLL network.

    Args:
        f: The sampled function (sample path) to optimize.
        num_restarts: Number of restarts for optimization-based sampling.
        bounds: List of tuples specifying the lower and upper bounds for each input
            dimension.
        lb: Lower bounds for each input dimension.
        ub: Upper bounds for each input dimension.
        device: Device on which to store the candidate points.

    Returns:
        Candidate points and corresponding function values as a tuple of a
        `num_restarts x d`-dim Tensor and a `num_restarts x num_outputs`-dim Tensor,
        respectively.
    """
    X_cand = torch.empty(
        num_restarts, f.num_inputs, dtype=torch.float64, device=device
    )
    Y_cand = torch.empty(
        num_restarts, f.num_outputs, dtype=torch.float64, device=device
    )

    # create a numpy wrapper around the sampled function; note that we maximize
    def func(x):
        return -f(torch.from_numpy(x).to(device)).detach().cpu().numpy()

    # get gradient and create wrapper
    grad_f = grad(lambda x: f(x).mean())

    def grad_func(x):
        return -grad_f(torch.from_numpy(x).to(device)).detach().cpu().numpy()

    # generate random initial conditions
    x0s = np.random.rand(num_restarts, f.num_inputs)

    lb = lb.numpy()
    ub = ub.numpy()

    optimization_successful = False
    for j in range(num_restarts):
        # map initial conditions to the bounds
        x0 = lb + (ub - lb) * x0s[j]

        # optimize sample path
        res = scipy.optimize.minimize(
            func, x0, jac=grad_func, bounds=bounds, method="L-BFGS-B"
        )

        # check if optimization was successful
        if res.success:
            optimization_successful = True
        else:
            logger.warning(f"Optimization failed with message: {res.message}")

        # store the candidate
        X_cand[j, :] = torch.from_numpy(res.x).to(dtype=torch.float64)
        Y_cand[j] = torch.tensor([-res.fun], dtype=torch.float64)

    if not optimization_successful:
        raise RuntimeError("All optimization attempts on the sample path failed.")

    return X_cand, Y_cand
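For context, `BLLMaxPosteriorSampling` (defined above) is meant to serve as the Thompson sampling step of a Bayesian optimization loop: fit a BLL surrogate, then use maximizers of posterior sample paths as the next candidates. The sketch below illustrates the intended call pattern; the `VBLLModel` class name, its import path, and its constructor and `fit` signatures are assumptions made purely for illustration, and any concrete `AbstractBLLModel` subclass from this PR would take their place.

# Hypothetical usage sketch -- `VBLLModel`, its import path, and its signatures
# are assumptions; substitute the concrete VBLL surrogate added in this PR.
# `BLLMaxPosteriorSampling` is the class defined in the file above.
import torch

from botorch_community.models.vblls import VBLLModel  # assumed import path

train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = torch.sin(6.0 * train_X).sum(dim=-1, keepdim=True)

model = VBLLModel(in_features=2, out_features=1)  # assumed constructor signature
model.fit(train_X, train_Y)  # assumed fit signature

# Continuous inputs (default): sample paths are maximized via L-BFGS-B restarts.
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64)
sampler = BLLMaxPosteriorSampling(model=model, bounds=bounds, num_restarts=5)
X_next = sampler(num_samples=3)  # 3 x 2 tensor of candidate maximizers

# Discrete inputs: restrict the search to a fixed candidate set instead.
discrete_sampler = BLLMaxPosteriorSampling(model=model, discrete_inputs=True)
X_cand = torch.rand(512, 2, dtype=torch.float64)
X_next_discrete = discrete_sampler(X_cand=X_cand, num_samples=3)

Note that in the continuous case each of the `num_samples` draws triggers its own round of `num_restarts` L-BFGS-B runs, so the cost scales with `num_samples * num_restarts` sample-path optimizations.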
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from botorch.models.model import Model
from botorch_community.models.vbll_helper import DenseNormal, Normal
from torch import Tensor


class AbstractBLLModel(Model, ABC):
    def __init__(self):
        """Abstract class for Bayesian Last Layer (BLL) models."""
        super().__init__()
        self.model = None

    @property
    def num_outputs(self) -> int:
        return self.model.num_outputs

    @property
    def num_inputs(self):
        return self.model.num_inputs

    @property
    def device(self):
        return self.model.device

    @abstractmethod
    def __call__(self, X: Tensor) -> Normal | DenseNormal:
        pass  # pragma: no cover

    @abstractmethod
    def fit(self, *args, **kwargs):
        pass  # pragma: no cover

    @abstractmethod
    def sample(self, sample_shape: torch.Size | None = None) -> nn.Module:
        pass  # pragma: no cover
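To make the contract above concrete, here is a minimal illustrative subclass. Everything named `Toy`/`_Toy` is invented for this sketch and is not part of the PR (the PR wraps a variational BLL network and returns the `Normal`/`DenseNormal` helpers from `vbll_helper`); the sketch only implements the pieces that `BLLMaxPosteriorSampling` relies on, namely `sample()` plus the `num_inputs`, `num_outputs`, and `device` properties delegated to `self.model`.

# Illustrative sketch only -- the `Toy*` classes below are assumptions made to
# demonstrate the abstract interface, not code from this PR.
from __future__ import annotations

import torch
import torch.nn as nn

from botorch_community.models.blls import AbstractBLLModel


class _ToyBackbone(nn.Module):
    """Plain MLP exposing the attributes that `AbstractBLLModel` delegates to."""

    def __init__(self, num_inputs: int = 2, num_outputs: int = 1):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.net = nn.Sequential(
            nn.Linear(num_inputs, 32), nn.ELU(), nn.Linear(32, num_outputs)
        )

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # cast to the network dtype and back so float64 inputs are supported
        return self.net(X.to(torch.float32)).to(X.dtype)


class ToyBLLModel(AbstractBLLModel):
    """Deterministic stand-in for a BLL surrogate (no posterior over weights)."""

    def __init__(self, num_inputs: int = 2, num_outputs: int = 1):
        super().__init__()
        # `num_inputs`, `num_outputs`, and `device` are delegated to this module
        self.model = _ToyBackbone(num_inputs, num_outputs)

    def __call__(self, X: torch.Tensor) -> torch.distributions.Normal:
        # the PR's models return the `Normal`/`DenseNormal` helpers from
        # `vbll_helper`; a plain torch Normal stands in for them here
        mean = self.model(X)
        return torch.distributions.Normal(mean, torch.ones_like(mean))

    def fit(self, train_X: torch.Tensor, train_y: torch.Tensor) -> None:
        pass  # a real implementation would train the network here

    def sample(self, sample_shape: torch.Size | None = None) -> nn.Module:
        # a real BLL model draws last-layer weights from its posterior; the toy
        # model simply returns its deterministic backbone as the "sample path"
        return self.model

    def posterior(self, X, **kwargs):
        # `botorch.models.model.Model.posterior` is abstract and must be overridden
        # to instantiate the class; the PR adds a proper BLL posterior, which this
        # sketch does not need for the sampling strategy above
        raise NotImplementedError

Under these assumptions, `BLLMaxPosteriorSampling(model=ToyBLLModel(), num_restarts=2)(num_samples=2)` should return a `2 x 2` candidate tensor, which shows how the two files in this PR fit together.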