Commit e44280e

Carl Hvarfner authored and facebook-github-bot committed
Improvement of qBayesianActiveLearningByDisagreement (#2457)
Summary:
Pull Request resolved: #2457

Improvement of the implementation of qBayesianActiveLearningByDisagreement:
- Uses a Monte Carlo approach for approximating the entropy.
- Does not use concatenate_pending_points, as it is not evident that fantasizing makes sense in the same way as for standard MC acquisition functions.
- Can accept posterior transforms.
- get_model and get_fully_bayesian_model are used in the tests to mirror other tests (e.g. JES and the subsequent active learning acqfs), enabling a move to test_helpers.

Reviewed By: saitcakmak

Differential Revision: D60308502

fbshipit-source-id: 6de1dffc4f497ef4823428b2903b19ff8f0d60d7
1 parent 9ddd9eb commit e44280e
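
For orientation, a minimal usage sketch of the revised acquisition function (not part of the commit): it assumes a SAAS fully Bayesian GP fit with NUTS and exercises the new `sampler` argument; the training data, dimensions, and sample counts below are placeholders.

```python
import torch
from botorch.acquisition.bayesian_active_learning import (
    qBayesianActiveLearningByDisagreement,
)
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.optim import optimize_acqf
from botorch.sampling.normal import IIDNormalSampler

# Illustrative training data: 2d inputs, single outcome.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.sin(train_X.sum(dim=-1, keepdim=True))

# Fit the fully Bayesian model with NUTS (small MCMC budget for illustration).
model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
fit_fully_bayesian_model_nuts(model, warmup_steps=64, num_samples=128, thinning=8)

# The sampler draws the MC samples used to approximate the entropy of the
# Gaussian mixture posterior.
acqf = qBayesianActiveLearningByDisagreement(
    model=model,
    sampler=IIDNormalSampler(sample_shape=torch.Size([64])),
)

candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=2,
    num_restarts=8,
    raw_samples=64,
)
```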

File tree: 3 files changed (+117, -52 lines)


botorch/acquisition/bayesian_active_learning.py (+51, -26)
@@ -22,10 +22,11 @@

 from typing import Optional

-import torch
 from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
-from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.acquisition.objective import PosteriorTransform
+from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
 from botorch.models.model import Model
+from botorch.sampling.base import MCSampler
 from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
 from torch import Tensor

@@ -54,48 +55,72 @@ class qBayesianActiveLearningByDisagreement(
     def __init__(
         self,
         model: SaasFullyBayesianSingleTaskGP,
+        sampler: Optional[MCSampler] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
         X_pending: Optional[Tensor] = None,
     ) -> None:
         """
         Batch implementation [kirsch2019batchbald]_ of BALD [Houlsby2011bald]_,
         which maximizes the mutual information between the next observation and the
-        hyperparameters of the model. Computed by informational lower bound.
+        hyperparameters of the model. Computed by Monte Carlo integration.

         Args:
-            model: A fully bayesian single-outcome model.
-            X_pending: A `batch_shape, m x d`-dim Tensor of `m` design points.
+            model: A fully bayesian model (SaasFullyBayesianSingleTaskGP).
+            sampler: The sampler used for drawing samples to approximate the entropy
+                of the Gaussian Mixture posterior.
+            posterior_transform: A PosteriorTransform. If using a multi-output model,
+                a PosteriorTransform that transforms the multi-output posterior into a
+                single-output posterior is required.
+            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
+
         """
-        super().__init__(model)
+        super().__init__(model=model)
+        MCSamplerMixin.__init__(self, sampler=sampler)
         self.set_X_pending(X_pending)
+        self.posterior_transform = posterior_transform

     @concatenate_pending_points
     @t_batch_mode_transform()
     def forward(self, X: Tensor) -> Tensor:
         r"""Evaluate qBayesianActiveLearningByDisagreement on the candidate set `X`.
+        A monte carlo-estimated information gain is computed over a Gaussian Mixture
+        marginal posterior, and the Gaussian conditional posterior to obtain the
+        qBayesianActiveLearningByDisagreement on the candidate set `X`.

         Args:
             X: `batch_shape x q x D`-dim Tensor of input points.

         Returns:
             A `batch_shape x num_models`-dim Tensor of BALD values.
         """
-        return self._compute_lower_bound_information_gain(X)
-
-    def _compute_lower_bound_information_gain(self, X: Tensor) -> Tensor:
-        r"""Evaluates the lower bounded information gain on the candidate set `X`.
-
-        Args:
-            X: `batch_shape x q x D`-dim Tensor of input points.
-
-        Returns:
-            A `batch_shape x num_models`-dim Tensor of information gains.
-        """
-        posterior = self.model.posterior(X, observation_noise=True)
-        marg_covar = posterior.mixture_covariance_matrix
-        cond_variances = posterior.variance
-
-        prev_entropy = torch.logdet(marg_covar).unsqueeze(-1)
-        # squeeze excess dim and mean over q-batch
-        post_ub_entropy = torch.log(cond_variances).squeeze(-1).mean(-1)
-
-        return prev_entropy - post_ub_entropy
+        posterior = self.model.posterior(
+            X, observation_noise=True, posterior_transform=self.posterior_transform
+        )
+        # draw samples from the mixture posterior.
+        # samples: num_samples x batch_shape x num_models x q x num_outputs
+        samples = self.get_posterior_samples(posterior=posterior)
+
+        # Estimate the entropy of 'num_samples' samples from 'num_models' models by
+        # evaluating the log_prob on each sample on the mixture posterior
+        # (which constitutes of M models). thus, order N*M^2 computations
+
+        # Make room and move the model dim to the front, squeeze the num_outputs dim.
+        # prev_samples: num_models x num_samples x batch_shape x 1 x q
+        prev_samples = samples.unsqueeze(0).transpose(0, MCMC_DIM).squeeze(-1)
+
+        # avg the probs over models in the mixture - dim (-2) will be broadcasted
+        # with the num_models of the posterior --> querying all samples on all models
+        # posterior.mvn takes q-dimensional input by default, which removes the q-dim
+        # component_sample_probs: num_models x num_samples x batch_shape x num_models
+        component_sample_probs = posterior.mvn.log_prob(prev_samples).exp()
+
+        # average over mixture components
+        mixture_sample_probs = component_sample_probs.mean(dim=-1)
+
+        # this is the average over the model and sample dim
+        prev_entropy = -mixture_sample_probs.log().mean(dim=[0, 1])
+
+        # the posterior entropy is an average entropy over gaussians, so no mixture
+        post_entropy = -posterior.mvn.log_prob(samples.squeeze(-1)).mean(0)
+        bald = prev_entropy.unsqueeze(-1) - post_entropy
+        return bald
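
The new forward pass is a plain Monte Carlo estimate of the BALD mutual information: the entropy of the Gaussian mixture marginal (mixing over the MCMC hyperparameter draws) minus the average entropy of the per-model Gaussian conditionals. Below is a self-contained toy sketch of that estimator, independent of BoTorch; the number of models, number of samples, and the 1d Gaussians are purely illustrative.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
num_models, num_samples = 3, 512

# Per-model predictive means and stds at one candidate point (illustrative values).
means = torch.randn(num_models)
stds = torch.rand(num_models) + 0.5

components = Normal(means, stds)
# One batch of draws per model; samples: num_samples x num_models
samples = components.rsample(torch.Size([num_samples]))

# Mixture entropy estimate: score every sample under every component,
# average the probabilities over components, then average -log prob
# over samples and over the model each sample came from.
# log_probs: num_samples x num_models (sample origin) x num_models (component)
log_probs = components.log_prob(samples.unsqueeze(-1))
mixture_probs = log_probs.exp().mean(dim=-1)
prev_entropy = -mixture_probs.log().mean()

# Conditional entropy estimate: each sample scored only under its own component.
post_entropy = -components.log_prob(samples).mean()

# BALD estimate; nonnegative in expectation (mixture entropy >= mean component entropy).
bald = prev_entropy - post_entropy
print(bald)
```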

botorch/acquisition/input_constructors.py (+4, -0)
@@ -1678,9 +1678,13 @@ def construct_inputs_qJES(
 def construct_inputs_BALD(
     model: Model,
     X_pending: Optional[Tensor] = None,
+    sampler: Optional[MCSampler] = None,
+    posterior_transform: Optional[PosteriorTransform] = None,
 ):
     inputs = {
         "model": model,
         "X_pending": X_pending,
+        "sampler": sampler,
+        "posterior_transform": posterior_transform,
     }
     return inputs
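
As an illustration (not part of the diff) of how this constructor is typically consumed, one can look it up from BoTorch's input constructor registry and pass the resulting kwargs to the acquisition function; here `model` is assumed to be a fitted SaasFullyBayesianSingleTaskGP as in the earlier sketch.

```python
import torch
from botorch.acquisition.bayesian_active_learning import (
    qBayesianActiveLearningByDisagreement,
)
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.sampling.normal import IIDNormalSampler

# Look up the registered constructor for the acquisition class and build kwargs.
input_constructor = get_acqf_input_constructor(qBayesianActiveLearningByDisagreement)
acqf_kwargs = input_constructor(
    model=model,  # assumed: a fitted SaasFullyBayesianSingleTaskGP
    sampler=IIDNormalSampler(sample_shape=torch.Size([64])),
    X_pending=None,
)
acqf = qBayesianActiveLearningByDisagreement(**acqf_kwargs)
```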

test/acquisition/test_bayesian_active_learning.py (+62, -26)
@@ -13,9 +13,32 @@
 from botorch.models import SingleTaskGP
 from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
 from botorch.models.transforms.outcome import Standardize
+from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import BotorchTestCase


+def get_model(
+    train_X,
+    train_Y,
+    standardize_model,
+    **tkwargs,
+):
+    num_objectives = train_Y.shape[-1]
+
+    if standardize_model:
+        outcome_transform = Standardize(m=num_objectives)
+    else:
+        outcome_transform = None
+
+    model = SingleTaskGP(
+        train_X=train_X,
+        train_Y=train_Y,
+        outcome_transform=outcome_transform,
+    )
+
+    return model
+
+
 def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):

     mcmc_samples = {

@@ -28,7 +51,7 @@ def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
     return mcmc_samples


-def get_model(
+def get_fully_bayesian_model(
     train_X,
     train_Y,
     num_models,

@@ -72,21 +95,26 @@ def test_q_bayesian_active_learning_by_disagreement(self):
         tkwargs = {"device": self.device}
         num_objectives = 1
         num_models = 3
+        input_dim = 2
+
+        X_pending_list = [None, torch.rand(2, input_dim)]
         for (
             dtype,
             standardize_model,
             infer_noise,
+            X_pending,
         ) in product(
             (torch.float, torch.double),
             (False, True),  # standardize_model
             (True,),  # infer_noise - only one option avail in PyroModels
+            X_pending_list,
         ):
+            X_pending = X_pending.to(**tkwargs) if X_pending is not None else None
             tkwargs["dtype"] = dtype
-            input_dim = 2
             train_X = torch.rand(4, input_dim, **tkwargs)
             train_Y = torch.rand(4, num_objectives, **tkwargs)

-            model = get_model(
+            model = get_fully_bayesian_model(
                 train_X,
                 train_Y,
                 num_models,

@@ -96,32 +124,40 @@ def test_q_bayesian_active_learning_by_disagreement(self):
             )

             # test acquisition
-            X_pending_list = [None, torch.rand(2, input_dim, **tkwargs)]
-            for i in range(len(X_pending_list)):
-                X_pending = X_pending_list[i]
-
-                acq = qBayesianActiveLearningByDisagreement(
-                    model=model,
-                    X_pending=X_pending,
-                )
-
-                test_Xs = [
-                    torch.rand(4, 1, input_dim, **tkwargs),
-                    torch.rand(4, 3, input_dim, **tkwargs),
-                    torch.rand(4, 5, 1, input_dim, **tkwargs),
-                    torch.rand(4, 5, 3, input_dim, **tkwargs),
-                ]
-
-                for j in range(len(test_Xs)):
-                    acq_X = acq.forward(test_Xs[j])
-                    acq_X = acq(test_Xs[j])
-                    # assess shape
-                    self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])
+            acq = qBayesianActiveLearningByDisagreement(
+                model=model,
+                X_pending=X_pending,
+            )
+
+            acq2 = qBayesianActiveLearningByDisagreement(
+                model=model, sampler=IIDNormalSampler(torch.Size([9]))
+            )
+            self.assertIsInstance(acq2.sampler, IIDNormalSampler)
+
+            test_Xs = [
+                torch.rand(4, 1, input_dim, **tkwargs),
+                torch.rand(4, 3, input_dim, **tkwargs),
+                torch.rand(4, 5, 1, input_dim, **tkwargs),
+                torch.rand(4, 5, 3, input_dim, **tkwargs),
+                torch.rand(5, 13, input_dim, **tkwargs),
+            ]
+
+            for j in range(len(test_Xs)):
+                acq_X = acq.forward(test_Xs[j])
+                acq_X = acq(test_Xs[j])
+                # assess shape
+                self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])
+
+                self.assertTrue(torch.all(acq_X > 0))

         # Support with non-fully bayesian models is not possible. Thus, we
         # throw an error.
-        non_fully_bayesian_model = SingleTaskGP(train_X, train_Y)
-        with self.assertRaises(ValueError):
+        non_fully_bayesian_model = get_model(train_X, train_Y, False)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Fully Bayesian acquisition functions require a "
+            "SaasFullyBayesianSingleTaskGP to run.",
+        ):
             acq = qBayesianActiveLearningByDisagreement(
                 model=non_fully_bayesian_model,
             )
