2 changes: 2 additions & 0 deletions botorch_community/acquisition/__init__.py
@@ -6,6 +6,7 @@
from botorch_community.acquisition.bayesian_active_learning import (
qBayesianQueryByComittee,
qBayesianVarianceReduction,
qHyperparameterInformedPredictiveExploration,
qStatisticalDistanceActiveLearning,
)

@@ -23,6 +24,7 @@
"LogRegionalExpectedImprovement",
"qBayesianQueryByComittee",
"qBayesianVarianceReduction",
"qHyperparameterInformedPredictiveExploration",
"qLogRegionalExpectedImprovement",
"qSelfCorrectingBayesianOptimization",
"qStatisticalDistanceActiveLearning",
197 changes: 196 additions & 1 deletion botorch_community/acquisition/bayesian_active_learning.py
@@ -33,13 +33,21 @@

from __future__ import annotations

from typing import Optional
import math

from typing import Literal, Optional

import torch
from botorch.acquisition.acquisition import MCSamplerMixin
from botorch.acquisition.bayesian_active_learning import (
FullyBayesianAcquisitionFunction,
qBayesianActiveLearningByDisagreement,
)
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
from botorch.optim import optimize_acqf
from botorch.sampling.base import MCSampler
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import (
average_over_ensemble_models,
concatenate_pending_points,
@@ -51,6 +59,7 @@


SAMPLE_DIM = -4
TWO_PI_E = 2 * math.pi * math.e
DISTANCE_METRICS = {
"hellinger": mvn_hellinger_distance,
"kl_divergence": mvn_kl_divergence,
@@ -159,3 +168,189 @@ def forward(self, X: Tensor) -> Tensor:
# squeeze output dim - batch dim computed and reduced inside of dist
# MCMC dim is averaged in decorator
return dist.squeeze(-1)


class qExpectedPredictiveInformationGain(FullyBayesianAcquisitionFunction):
def __init__(
self,
model: SaasFullyBayesianSingleTaskGP,
mc_points: Tensor,
X_pending: Tensor | None = None,
) -> None:
"""Expected predictive information gain for active learning.

Computes the mutual information between candidate queries and a test set
(typically MC samples over the design space).

Args:
model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
mc_points: A `N x d` tensor of points to use for MC-integrating the
posterior entropy (test set).
X_pending: A `m x d`-dim Tensor of `m` design points.
"""
super().__init__(model)
if mc_points.ndim != 2:
raise ValueError(
f"mc_points must be a 2-dimensional tensor, but got shape "
f"{mc_points.shape}"
)
self.register_buffer("mc_points", mc_points)
self.set_X_pending(X_pending)

@concatenate_pending_points
@t_batch_mode_transform()
@average_over_ensemble_models
def forward(self, X: Tensor) -> Tensor:
"""Evaluate test set information gain.

Args:
X: A `batch_shape x q x d`-dim Tensor of input points.

Returns:
A Tensor of information gain values.
"""
# Get the posterior for the candidate points
posterior = self.model.posterior(X, observation_noise=True)
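# The observation noise at X is recovered as the difference between the
# posterior variances with and without observation noise.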
noise = (
posterior.variance
- self.model.posterior(X, observation_noise=False).variance
)
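# For a GP, the conditional (look-ahead) variance does not depend on the
# fantasized observation values, so the posterior mean serves as a
# convenient fantasy observation.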
cond_Y = posterior.mean

# Condition the model on the candidate observations
cond_X = X.unsqueeze(-3).expand(*[cond_Y.shape[:-1] + X.shape[-1:]])
conditional_model = self.model.condition_on_observations(
X=cond_X,
Y=cond_Y,
noise=noise,
)

# Evaluate posterior variance at test set with and without conditioning
uncond_var = self.model.posterior(
self.mc_points, observation_noise=True
).variance
cond_var = conditional_model.posterior(
self.mc_points, observation_noise=True
).variance

# Compute information gain as reduction in entropy
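# Differential entropy of a Gaussian: H = 0.5 * log(2 * pi * e * sigma^2),
# which is where the TWO_PI_E constant enters.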
prev_entropy = torch.log(uncond_var * TWO_PI_E).sum(-1) / 2
post_entropy = torch.log(cond_var * TWO_PI_E).sum(-1) / 2
return (prev_entropy - post_entropy).mean(-1)


class qHyperparameterInformedPredictiveExploration(
FullyBayesianAcquisitionFunction, MCSamplerMixin
):
def __init__(
self,
model: SaasFullyBayesianSingleTaskGP,
mc_points: Tensor,
bounds: Tensor,
sampler: MCSampler | None = None,
posterior_transform: PosteriorTransform | None = None,
X_pending: Tensor | None = None,
num_samples: int = 512,
beta: float | None = None,
beta_tuning_method: Literal["sobol", "optimize"] = "sobol",
) -> None:
"""Hyperparameter-informed Predictive Exploration acquisition function.

This acquisition function combines the mutual information between the
subsequent queries and a test set (predictive information gain) with the
mutual information between observations and hyperparameters (BALD), weighted
by a tuning factor. This balances exploration of the design space with
reduction of hyperparameter uncertainty.

The acquisition function is computed as:
beta * BALD + TSIG
where beta is either provided or automatically tuned.

Args:
model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
mc_points: A `N x d` tensor of points to use for MC-integrating the
posterior entropy (test set). Usually, these are qMC samples on
the whole design space.
bounds: A `2 x d` tensor of bounds for the design space, used for
beta tuning.
sampler: The sampler used for drawing samples to approximate the entropy
of the Gaussian mixture posterior. If None, a default sampler is used.
posterior_transform: An optional PosteriorTransform for the model posterior.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
submitted for evaluation but have not yet been observed.
num_samples: Number of samples to use for MC estimation of entropy.
beta: Fixed tuning factor. If None, it is automatically computed on the
first forward pass based on the batch size q.
beta_tuning_method: Method for tuning beta when beta is None. Options are
"optimize" (optimize the BALD acquisition function to find beta) or
"sobol" (set beta to the mean TSIG/BALD ratio at Sobol samples).
"""
super().__init__(model=model)
MCSamplerMixin.__init__(self)
if mc_points.ndim != 2:
raise ValueError(
f"mc_points must be a 2-dimensional tensor, but got shape "
f"{mc_points.shape}"
)
self.set_X_pending(X_pending)
self.num_samples = num_samples
self.beta_tuning_method = beta_tuning_method
self.register_buffer("mc_points", mc_points)
self.register_buffer("bounds", bounds)
self.sampler = sampler
self.posterior_transform = posterior_transform
self._tuning_factor: float | None = beta
self._tuning_factor_q: int | None = None

def _compute_tuning_factor(self, q: int) -> None:
"""Compute the tuning factor beta for weighting BALD vs TSIG."""
if self.beta_tuning_method == "sobol":
draws = draw_sobol_samples(
bounds=self.bounds,
q=q,
n=1,
).squeeze(0)
# Compute the TSIG/BALD ratio at the Sobol samples
tsig_val = qExpectedPredictiveInformationGain.forward(
self,
draws,
)
bald_val = qBayesianActiveLearningByDisagreement.forward(self, draws)
self._tuning_factor = (tsig_val / (bald_val + 1e-8)).mean().item()
elif self.beta_tuning_method == "optimize":
# Optimize to find the best tuning factor
bald_acqf = qBayesianActiveLearningByDisagreement(
model=self.model,
sampler=self.sampler,
)
_, bald_val = optimize_acqf(
bald_acqf,
bounds=self.bounds,
q=q,
num_restarts=1,
raw_samples=128,
options={"batch_limit": 16},
)
self._tuning_factor = bald_val.mean().item()
self._tuning_factor_q = q

@concatenate_pending_points
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
"""Evaluate the acquisition function at X.

Args:
X: A `batch_shape x q x d`-dim Tensor of input points.

Returns:
A `batch_shape`-dim Tensor of acquisition values.
"""
q = X.shape[-2]
# Compute tuning factor if not set or if q has changed
if self._tuning_factor is None or self._tuning_factor_q != q:
self._compute_tuning_factor(q)

tsig = qExpectedPredictiveInformationGain.forward(self, X)
bald = qBayesianActiveLearningByDisagreement.forward(self, X)
# Since both acquisition functions are averaged over the ensemble,
# we do not average over the ensemble again here.
return self._tuning_factor * bald + tsig
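A minimal usage sketch of the new acquisition function, mirroring the test setup below; the NUTS fitting settings and the 2-d toy problem are illustrative only and not part of this PR (assumes pyro is available for NUTS fitting):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.optim import optimize_acqf
from botorch.utils.sampling import draw_sobol_samples
from botorch_community.acquisition.bayesian_active_learning import (
    qHyperparameterInformedPredictiveExploration,
)

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.sin(6 * train_X).sum(dim=-1, keepdim=True)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
fit_fully_bayesian_model_nuts(model, warmup_steps=64, num_samples=64, thinning=4)

# Test set used to MC-integrate the posterior entropy over the design space.
mc_points = draw_sobol_samples(bounds=bounds, n=64, q=1).squeeze(-2)

acqf = qHyperparameterInformedPredictiveExploration(
    model=model,
    mc_points=mc_points,
    bounds=bounds,  # used for tuning beta when beta is None
)
candidates, value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=2,
    num_restarts=8,
    raw_samples=128,
)

Leaving beta unset triggers the automatic tuning on the first forward pass (here via the default "sobol" method).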
107 changes: 104 additions & 3 deletions test_community/acquisition/test_bayesian_active_learning.py
@@ -7,11 +7,14 @@
from itertools import product

import torch
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.test_helpers import get_fully_bayesian_model
from botorch.utils.testing import BotorchTestCase
from botorch_community.acquisition.bayesian_active_learning import (
qBayesianQueryByComittee,
qBayesianVarianceReduction,
qExpectedPredictiveInformationGain,
qHyperparameterInformedPredictiveExploration,
qStatisticalDistanceActiveLearning,
)

@@ -72,14 +75,112 @@ def test_q_statistical_distance_active_learning(self):
# assess shape
self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])


class TestQExpectedPredictiveInformationGain(BotorchTestCase):
def test_q_expected_predictive_information_gain(self):
torch.manual_seed(1)
tkwargs = {"device": self.device, "dtype": torch.double}
input_dim = 2

model = get_fully_bayesian_model(
train_X=torch.rand(4, input_dim, **tkwargs),
train_Y=torch.rand(4, 1, **tkwargs),
num_models=3,
**tkwargs,
)
bounds = torch.tensor([[0.0] * input_dim, [1.0] * input_dim], **tkwargs)
mc_points = draw_sobol_samples(bounds=bounds, n=16, q=1).squeeze(-2)

acq = qExpectedPredictiveInformationGain(model=model, mc_points=mc_points)
test_X = torch.rand(4, 2, input_dim, **tkwargs)
acq_X = acq(test_X)
self.assertEqual(acq_X.shape, test_X.shape[:-2])
self.assertTrue((acq_X >= 0).all())

# test that mc_points must be 2-dimensional
with self.assertRaises(ValueError):
    qExpectedPredictiveInformationGain(
        model=model,
        mc_points=mc_points.unsqueeze(0),  # 3D tensor
    )


class TestQHyperparameterInformedPredictiveExploration(BotorchTestCase):
def test_q_hyperparameter_informed_predictive_exploration(self):
torch.manual_seed(1)
tkwargs = {"device": self.device}
num_objectives = 1
num_models = 3
for (
dtype,
standardize_model,
infer_noise,
) in product(
(torch.float, torch.double),
(False, True),
(True,),
):
tkwargs["dtype"] = dtype
input_dim = 2
train_X = torch.rand(4, input_dim, **tkwargs)
train_Y = torch.rand(4, num_objectives, **tkwargs)

model = get_fully_bayesian_model(
train_X=train_X,
train_Y=train_Y,
num_models=num_models,
standardize_model=standardize_model,
infer_noise=infer_noise,
**tkwargs,
)

bounds = torch.tensor([[0.0] * input_dim, [1.0] * input_dim], **tkwargs)
mc_points = draw_sobol_samples(bounds=bounds, n=16, q=1).squeeze(-2)

# test with fixed beta
acq = qHyperparameterInformedPredictiveExploration(
model=model,
mc_points=mc_points,
bounds=bounds,
beta=1.0,
)

test_Xs = [
torch.rand(4, 1, input_dim, **tkwargs),
torch.rand(4, 3, input_dim, **tkwargs),
]

for test_X in test_Xs:
acq_X = acq(test_X)
# assess shape
self.assertTrue(acq_X.shape == test_X.shape[:-2])
self.assertTrue((acq_X > 0).all())

# test beta tuning (beta=None) and re-tuning when q changes
for beta_tuning_method in ("sobol", "optimize"):
acq_tuned = qHyperparameterInformedPredictiveExploration(
model=model,
mc_points=mc_points,
bounds=bounds,
beta_tuning_method=beta_tuning_method,
)
# first forward pass computes tuning factor for q=1
acq_tuned(torch.rand(4, 1, input_dim, **tkwargs))
tuning_factor_q1 = acq_tuned._tuning_factor
# second forward pass with different q recomputes tuning factor
acq_tuned(torch.rand(4, 3, input_dim, **tkwargs))
tuning_factor_q3 = acq_tuned._tuning_factor
self.assertNotEqual(tuning_factor_q1, tuning_factor_q3)

# test that mc_points must be 2-dimensional
with self.assertRaises(ValueError):
qHyperparameterInformedPredictiveExploration(
model=model,
mc_points=mc_points.unsqueeze(0), # 3D tensor
bounds=bounds,
)


class TestQBayesianQueryByComittee(BotorchTestCase):
def test_q_bayesian_query_by_comittee(self):
torch.manual_seed(1)