[Feature] Variational Bayesian last layer models as surrogate models #2754

Open: wants to merge 19 commits into base: main
173 changes: 173 additions & 0 deletions botorch_community/acquisition/bll_thompson_sampling.py
@@ -0,0 +1,173 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import numpy as np
import scipy

import torch

from botorch.logging import logger

from botorch_community.models.blls import AbstractBLLModel
from torch.func import grad


class BLLMaxPosteriorSampling:
Review comment (Contributor):

There is a SamplingStrategy interface that defines the API for this (and subclasses torch.nn.Module), but it doesn't support omitting X_cand. I assume this is why you didn't subclass from it?

I don't think it's necessary to use it if that would require changes to SamplingStrategy, but maybe leave a comment here noting that we may want to consider doing so in the future.

    def __init__(
        self,
        model: AbstractBLLModel,
        num_restarts: int = 10,
        bounds: torch.Tensor | None = None,
        discrete_inputs: bool = False,
    ):
        """
        Implements Maximum Posterior Sampling for Bayesian Last Layer (BLL) models.

        This class provides functionality to sample from the posterior distribution of a
        BLL model, with optional optimization to refine the sampling process.

        Args:
            model: The VBLL model from which posterior samples are drawn. Must be an
                instance of `AbstractBLLModel`.
            num_restarts: Number of restarts for optimization-based sampling.
                Defaults to 10.
            bounds: Tensor of shape (2, num_inputs) specifying the lower and upper
                bounds for sampling. If None, defaults to [(0, 1)] for each input
                dimension.
            discrete_inputs: If True, assumes the input space is discrete and will be
                provided in __call__. Defaults to False.

        Raises:
            ValueError:
                If the provided `model` is not an instance of `AbstractBLLModel`.

        Notes:
            - If `bounds` is not provided, the default range [0,1] is assumed for each
              input dimension.
        """
        if not isinstance(model, AbstractBLLModel):
            raise ValueError(
                f"Model must be an instance of AbstractBLLModel, is {type(model)}"
            )

        self.model = model
        self.device = model.device
        self.discrete_inputs = discrete_inputs
        self.num_restarts = num_restarts

        if bounds is None:
            # Default bounds [0,1] for each input dimension
            self.bounds = [(0, 1)] * self.model.num_inputs
            self.lb = torch.zeros(
                self.model.num_inputs, dtype=torch.float64, device=torch.device("cpu")
            )
            self.ub = torch.ones(
                self.model.num_inputs, dtype=torch.float64, device=torch.device("cpu")
            )
        else:
            # Ensure bounds are on CPU for compatibility with scipy.optimize.minimize
            self.lb = bounds[0, :].cpu()
            self.ub = bounds[1, :].cpu()
            self.bounds = [tuple(bound) for bound in bounds.T.cpu().tolist()]

    def __call__(
        self, X_cand: torch.Tensor | None = None, num_samples: int = 1
    ) -> torch.Tensor:
Review comment (Contributor):

Can you add a docstring describing this method, including the expected input and return shapes of the tensors?
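
A docstring along the following lines could address this request. The shapes are inferred from the code in this diff (`X_next` is allocated as `num_samples x num_inputs`, and rows of `X_cand` are selected via `Y_cand.argmax()`), so treat them as assumptions rather than the author's specification. The class name `BLLMaxPosteriorSamplingDocSketch` is hypothetical, and the body is deliberately left unimplemented.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # torch is only needed for the annotations in this stub
    import torch


class BLLMaxPosteriorSamplingDocSketch:
    """Hypothetical stub; only the `__call__` docstring is of interest."""

    def __call__(
        self, X_cand: torch.Tensor | None = None, num_samples: int = 1
    ) -> torch.Tensor:
        """Draw sample paths from the model and return the best point of each.

        Args:
            X_cand: A `n x num_inputs` tensor of candidate points (shape assumed
                from the diff). Required if `discrete_inputs` is True and must
                be None otherwise.
            num_samples: Number of sample paths to draw; each contributes one
                point to the result.

        Returns:
            A `num_samples x num_inputs` tensor whose i-th row is the best point
            found for the i-th sample path, clamped to the bounds.

        Raises:
            ValueError: If `X_cand` is missing while `discrete_inputs` is True,
                or given while `discrete_inputs` is False.
        """
        raise NotImplementedError  # documentation sketch only
```

Because the row selection uses a flat `argmax` over `Y_cand`, the stated shapes implicitly assume a single-output model; a real docstring may want to say so explicitly.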

        if self.discrete_inputs and X_cand is None:
            raise ValueError("X_cand must be provided if `discrete_inputs` is True.")

        if X_cand is not None and not self.discrete_inputs:
            raise ValueError("X_cand is provided but `discrete_inputs` is False.")

        X_next = torch.empty(
            num_samples, self.model.num_inputs, dtype=torch.float64, device=self.device
        )

        # get max of sampled functions at candidate points for each function
        for i in range(num_samples):
            f = self.model.sample()

            if self.discrete_inputs:
                # evaluate sample path at candidate points and select best
                Y_cand = f(X_cand)
            else:
                # optimize sample path
                X_cand, Y_cand = _optimize_sample_path(
                    f=f,
                    num_restarts=self.num_restarts,
                    bounds=self.bounds,
                    lb=self.lb,
                    ub=self.ub,
                    device=self.device,
                )

            # select the best candidate
            X_next[i, :] = X_cand[Y_cand.argmax()]

        # ensure that the next point is within the bounds,
        # scipy minimize can sometimes return points outside the bounds
        X_next = torch.clamp(X_next, self.lb.to(self.device), self.ub.to(self.device))
        return X_next


def _optimize_sample_path(
    f: torch.nn.Module,
    num_restarts: int,
    bounds: list[tuple[float, float]],
    lb: torch.Tensor,
    ub: torch.Tensor,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Helper function to optimize the sample path of a BLL network.

    Args:
        f: The sample to optimize.
        num_restarts: Number of restarts for optimization-based sampling.
        bounds: List of tuples specifying the lower and upper bounds for each input
            dimension.
        lb: Lower bounds for each input dimension.
        ub: Upper bounds for each input dimension.
        device: Device on which to store the candidate points.

    Returns:
        Candidate points and corresponding function values.
    """
    X_cand = torch.empty(num_restarts, f.num_inputs, dtype=torch.float64, device=device)
    Y_cand = torch.empty(
        num_restarts, f.num_outputs, dtype=torch.float64, device=device
    )

    # create numpy wrapper around the sampled function, note we aim to maximize
    def func(x):
        return -f(torch.from_numpy(x).to(device)).detach().cpu().numpy()

    # get gradient and create wrapper
    grad_f = grad(lambda x: f(x).mean())

    def grad_func(x):
        return -grad_f(torch.from_numpy(x).to(device)).detach().cpu().numpy()

    # generate random initial conditions
    x0s = np.random.rand(num_restarts, f.num_inputs)

    for j in range(num_restarts):
        # map to bounds
        x0 = lb + (ub - lb) * x0s[j]

        # optimize sample path
        res = scipy.optimize.minimize(
            func, x0, jac=grad_func, bounds=bounds, method="L-BFGS-B"
        )

        if not res.success:
Review comment (Contributor):

Does this work reliably? I've run into cases where the line search sometimes fails in odd ways. If there is a chance of this happening in any of the samples and we use a lot of samples, this could cause frequent failures. If this doesn't happen, great; otherwise we may want to handle it, possibly by retrying the optimization.

Reply (Contributor Author):

Yes, actually very reliably; I have not had any problems with it. Should we still add the additional "safety layer"? I suspect that if this fails, something in the model is off. It may be nice to include this pointer in an error?

Reply (Contributor):

By "this pointer", do you mean my comment above? Right now you're not erroring out but just warning, so that should be fine. Though if the majority of the num_restarts have an optimization failure, that's probably not good.
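
The retry idea raised in this thread could look roughly like the sketch below. It is not part of the PR: `minimize_with_retries` and its parameters are hypothetical names, and the optimizer is passed in as a callable so the sketch stays independent of any particular library. The only assumption about the returned object is that it exposes `success` and `message` attributes, as `scipy.optimize.OptimizeResult` does.

```python
from __future__ import annotations

import logging
import random
from typing import Any, Callable, Sequence

logger = logging.getLogger(__name__)


def minimize_with_retries(
    minimize_fn: Callable[[list[float]], Any],
    bounds: Sequence[tuple[float, float]],
    max_retries: int = 3,
    rng: random.Random | None = None,
) -> tuple[Any, int]:
    """Run `minimize_fn` from fresh random starts until it reports success.

    Returns the last result together with the number of failed attempts.
    """
    rng = rng or random.Random()
    failures = 0
    result = None
    for _ in range(max_retries + 1):
        # draw a new initial point uniformly within the box bounds
        x0 = [lo + (hi - lo) * rng.random() for lo, hi in bounds]
        result = minimize_fn(x0)
        if result.success:
            break
        failures += 1
        logger.warning("Optimization failed with message: %s", result.message)
    return result, failures
```

With SciPy, `minimize_fn` could be something like `lambda x0: scipy.optimize.minimize(func, x0, jac=grad_func, bounds=bounds, method="L-BFGS-B")`. The returned failure count can be accumulated across restarts to emit a stronger warning when a majority of them fail, which is the case flagged in the last reply.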

            logger.warning(f"Optimization failed with message: {res.message}")

        # store the candidate
        X_cand[j, :] = torch.from_numpy(res.x).to(dtype=torch.float64)
        Y_cand[j] = torch.tensor([-res.fun], dtype=torch.float64)

    return X_cand, Y_cand

45 changes: 45 additions & 0 deletions botorch_community/models/blls.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from botorch.models.model import Model
from botorch_community.models.vbll_helper import DenseNormal, Normal
from torch import Tensor


class AbstractBLLModel(Model, ABC):
    def __init__(self):
        """Abstract class for Bayesian Last Layer (BLL) models."""
        super().__init__()
        self.model = None

    @property
    def num_outputs(self) -> int:
        return self.model.num_outputs

    @property
    def num_inputs(self):
        return self.model.num_inputs

    @property
    def device(self):
        return self.model.device

    @abstractmethod
    def __call__(self, X: Tensor) -> Normal | DenseNormal:
        raise NotImplementedError

    @abstractmethod
    def fit(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def sample(self, sample_shape: torch.Size | None = None) -> nn.Module:
        raise NotImplementedError
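
To illustrate how the abstract `sample()` method and the discrete branch of `BLLMaxPosteriorSampling.__call__` fit together, here is a toy, dependency-free sketch of Thompson sampling with a Bayesian linear last layer. It is a stand-in rather than the PR's model: the feature map, the diagonal Gaussian weight posterior, and the names `sample_path` and `thompson_select` are all assumptions made for illustration.

```python
from __future__ import annotations

import random
from typing import Callable, Sequence


def sample_path(
    mean: Sequence[float], std: Sequence[float], rng: random.Random
) -> Callable[[float], float]:
    """Draw last-layer weights w ~ N(mean, diag(std^2)); return x -> phi(x) . w."""
    weights = [rng.gauss(m, s) for m, s in zip(mean, std)]

    def path(x: float) -> float:
        features = [1.0, x, x * x]  # hypothetical fixed feature map phi(x)
        return sum(w * p for w, p in zip(weights, features))

    return path


def thompson_select(
    candidates: Sequence[float],
    mean: Sequence[float],
    std: Sequence[float],
    num_samples: int,
    rng: random.Random,
) -> list[float]:
    """For each sampled path, pick the candidate with the highest value."""
    picks = []
    for _ in range(num_samples):
        f = sample_path(mean, std, rng)  # analogous to `model.sample()`
        picks.append(max(candidates, key=f))  # analogous to `X_cand[Y_cand.argmax()]`
    return picks
```

Each call to `sample_path` fixes one draw of the last-layer weights, so the returned function is deterministic and can be evaluated at every candidate or handed to an optimizer, which mirrors why `sample()` in the diff returns a module rather than function values.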