33 | 33 |
34 | 34 | from __future__ import annotations |
35 | 35 |
36 | | -from typing import Optional |
| 36 | +import math |
| 37 | + |
| 38 | +from typing import Literal, Optional |
37 | 39 |
38 | 40 | import torch |
| 41 | +from botorch.acquisition.acquisition import MCSamplerMixin |
39 | 42 | from botorch.acquisition.bayesian_active_learning import ( |
40 | 43 | FullyBayesianAcquisitionFunction, |
| 44 | + qBayesianActiveLearningByDisagreement, |
41 | 45 | ) |
| 46 | +from botorch.acquisition.objective import PosteriorTransform |
42 | 47 | from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP |
| 48 | +from botorch.optim import optimize_acqf |
| 49 | +from botorch.sampling.base import MCSampler |
| 50 | +from botorch.utils.sampling import draw_sobol_samples |
43 | 51 | from botorch.utils.transforms import ( |
44 | 52 | average_over_ensemble_models, |
45 | 53 | concatenate_pending_points, |
51 | 59 |
52 | 60 |
53 | 61 | SAMPLE_DIM = -4 |
| 62 | +TWO_PI_E = 2 * math.pi * math.e |
54 | 63 | DISTANCE_METRICS = { |
55 | 64 | "hellinger": mvn_hellinger_distance, |
56 | 65 | "kl_divergence": mvn_kl_divergence, |
@@ -159,3 +168,189 @@ def forward(self, X: Tensor) -> Tensor: |
159 | 168 | # squeeze output dim - batch dim computed and reduced inside of dist |
160 | 169 | # MCMC dim is averaged in decorator |
161 | 170 | return dist.squeeze(-1) |
| 171 | + |
| 172 | + |
| 173 | +class qExpectedPredictiveInformationGain(FullyBayesianAcquisitionFunction): |
| 174 | + def __init__( |
| 175 | + self, |
| 176 | + model: SaasFullyBayesianSingleTaskGP, |
| 177 | + mc_points: Tensor, |
| 178 | + X_pending: Tensor | None = None, |
| 179 | + ) -> None: |
| 180 | + """Expected predictive information gain for active learning. |
| 181 | +
| 182 | + Computes the mutual information between candidate queries and a test set |
| 183 | + (typically MC samples over the design space). |
| 184 | +
| 185 | + Args: |
| 186 | +            model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
| 187 | + mc_points: A `N x d` tensor of points to use for MC-integrating the |
| 188 | + posterior entropy (test set). |
| 189 | + X_pending: A `m x d`-dim Tensor of `m` design points. |
| 190 | + """ |
| 191 | + super().__init__(model) |
| 192 | + if mc_points.ndim != 2: |
| 193 | + raise ValueError( |
| 194 | + f"mc_points must be a 2-dimensional tensor, but got shape " |
| 195 | + f"{mc_points.shape}" |
| 196 | + ) |
| 197 | + self.register_buffer("mc_points", mc_points) |
| 198 | + self.set_X_pending(X_pending) |
| 199 | + |
| 200 | + @concatenate_pending_points |
| 201 | + @t_batch_mode_transform() |
| 202 | + @average_over_ensemble_models |
| 203 | + def forward(self, X: Tensor) -> Tensor: |
| 204 | + """Evaluate test set information gain. |
| 205 | +
| 206 | + Args: |
| 207 | + X: A `batch_shape x q x d`-dim Tensor of input points. |
| 208 | +
| 209 | + Returns: |
| 210 | +            A `batch_shape`-dim Tensor of information gain values.
| 211 | + """ |
| 212 | + # Get the posterior for the candidate points |
| 213 | + posterior = self.model.posterior(X, observation_noise=True) |
| 214 | + noise = ( |
| 215 | + posterior.variance |
| 216 | + - self.model.posterior(X, observation_noise=False).variance |
| 217 | + ) |
| 218 | + cond_Y = posterior.mean |
| 219 | + |
| 220 | + # Condition the model on the candidate observations |
| 221 | +        cond_X = X.unsqueeze(-3).expand(cond_Y.shape[:-1] + X.shape[-1:])
| 222 | + conditional_model = self.model.condition_on_observations( |
| 223 | + X=cond_X, |
| 224 | + Y=cond_Y, |
| 225 | + noise=noise, |
| 226 | + ) |
| 227 | + |
| 228 | + # Evaluate posterior variance at test set with and without conditioning |
| 229 | + uncond_var = self.model.posterior( |
| 230 | + self.mc_points, observation_noise=True |
| 231 | + ).variance |
| 232 | + cond_var = conditional_model.posterior( |
| 233 | + self.mc_points, observation_noise=True |
| 234 | + ).variance |
| 235 | + |
| 236 | +        # Information gain = prior minus conditional entropy, H = log(2*pi*e*var) / 2
| 237 | + prev_entropy = torch.log(uncond_var * TWO_PI_E).sum(-1) / 2 |
| 238 | + post_entropy = torch.log(cond_var * TWO_PI_E).sum(-1) / 2 |
| 239 | + return (prev_entropy - post_entropy).mean(-1) |
| 240 | + |
| 241 | + |
| 242 | +class qHyperparameterInformedPredictiveExploration( |
| 243 | + FullyBayesianAcquisitionFunction, MCSamplerMixin |
| 244 | +): |
| 245 | + def __init__( |
| 246 | + self, |
| 247 | + model: SaasFullyBayesianSingleTaskGP, |
| 248 | + mc_points: Tensor, |
| 249 | + bounds: Tensor, |
| 250 | + sampler: MCSampler | None = None, |
| 251 | + posterior_transform: PosteriorTransform | None = None, |
| 252 | + X_pending: Tensor | None = None, |
| 253 | + num_samples: int = 512, |
| 254 | + beta: float | None = None, |
| 255 | + beta_tuning_method: Literal["sobol", "optimize"] = "sobol", |
| 256 | + ) -> None: |
| 257 | + """Hyperparameter-informed Predictive Exploration acquisition function. |
| 258 | +
| 259 | + This acquisition function combines the mutual information between the |
| 260 | + subsequent queries and a test set (predictive information gain) with the |
| 261 | +        subsequent queries and a test set (test-set information gain, TSIG) with the
| 262 | + by a tuning factor. This balances exploration of the design space with |
| 263 | + reduction of hyperparameter uncertainty. |
| 264 | +
| 265 | + The acquisition function is computed as: |
| 266 | + beta * BALD + TSIG |
| 267 | + where beta is either provided or automatically tuned. |
| 268 | +
| 269 | + Args: |
| 270 | +            model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
| 271 | + mc_points: A `N x d` tensor of points to use for MC-integrating the |
| 272 | + posterior entropy (test set). Usually, these are qMC samples on |
| 273 | + the whole design space. |
| 274 | + bounds: A `2 x d` tensor of bounds for the design space, used for |
| 275 | + beta tuning. |
| 276 | + sampler: The sampler used for drawing samples to approximate the entropy |
| 277 | + of the Gaussian Mixture posterior. If None, uses default sampler. |
| 278 | + X_pending: A `m x d`-dim Tensor of `m` design points that have been |
| 279 | + submitted for evaluation but have not yet been observed. |
| 280 | + num_samples: Number of samples to use for MC estimation of entropy. |
| 281 | + beta: Fixed tuning factor. If None, it will be automatically computed |
| 282 | + on the first forward pass based on the batch size q. |
| 283 | +            beta_tuning_method: How beta is tuned when it is None: "sobol" sets
| 284 | +                it to the TSIG / BALD ratio at a Sobol-sampled batch of size q,
| 285 | +                while "optimize" sets it to the optimized (maximal) BALD value.
| 286 | + """ |
| 287 | + super().__init__(model=model) |
| 288 | + MCSamplerMixin.__init__(self) |
| 289 | + if mc_points.ndim != 2: |
| 290 | + raise ValueError( |
| 291 | + f"mc_points must be a 2-dimensional tensor, but got shape " |
| 292 | + f"{mc_points.shape}" |
| 293 | + ) |
| 294 | + self.set_X_pending(X_pending) |
| 295 | + self.num_samples = num_samples |
| 296 | + self.beta_tuning_method = beta_tuning_method |
| 297 | + self.register_buffer("mc_points", mc_points) |
| 298 | + self.register_buffer("bounds", bounds) |
| 299 | + self.sampler = sampler |
| 300 | + self.posterior_transform = posterior_transform |
| 301 | + self._tuning_factor: float | None = beta |
| 302 | + self._tuning_factor_q: int | None = None |
| 303 | + |
| 304 | + def _compute_tuning_factor(self, q: int) -> None: |
| 305 | + """Compute the tuning factor beta for weighting BALD vs TSIG.""" |
| 306 | + if self.beta_tuning_method == "sobol": |
| 307 | + draws = draw_sobol_samples( |
| 308 | + bounds=self.bounds, |
| 309 | + q=q, |
| 310 | + n=1, |
| 311 | + ).squeeze(0) |
| 312 | +            # Set beta to the TSIG / BALD ratio at the Sobol-sampled batch
| 313 | + tsig_val = qExpectedPredictiveInformationGain.forward( |
| 314 | + self, |
| 315 | + draws, |
| 316 | + ) |
| 317 | + bald_val = qBayesianActiveLearningByDisagreement.forward(self, draws) |
| 318 | + self._tuning_factor = (tsig_val / (bald_val + 1e-8)).mean().item() |
| 319 | + elif self.beta_tuning_method == "optimize": |
| 320 | +            # Set beta to the optimized (maximal) BALD value over the design space
| 321 | + bald_acqf = qBayesianActiveLearningByDisagreement( |
| 322 | + model=self.model, |
| 323 | + sampler=self.sampler, |
| 324 | + ) |
| 325 | + _, bald_val = optimize_acqf( |
| 326 | + bald_acqf, |
| 327 | + bounds=self.bounds, |
| 328 | + q=q, |
| 329 | + num_restarts=1, |
| 330 | + raw_samples=128, |
| 331 | + options={"batch_limit": 16}, |
| 332 | + ) |
| 333 | + self._tuning_factor = bald_val.mean().item() |
| 334 | + self._tuning_factor_q = q |
| 335 | + |
| 336 | +    # Note: X_pending is concatenated by the decorated EPIG / BALD forwards below.
| 337 | + @t_batch_mode_transform() |
| 338 | + def forward(self, X: Tensor) -> Tensor: |
| 339 | + """Evaluate the acquisition function at X. |
| 340 | +
| 341 | + Args: |
| 342 | + X: A `batch_shape x q x d`-dim Tensor of input points. |
| 343 | +
| 344 | + Returns: |
| 345 | + A `batch_shape`-dim Tensor of acquisition values. |
| 346 | + """ |
| 347 | + q = X.shape[-2] |
| 348 | +        # Auto-compute beta if not provided; re-tune only if q changed since tuning
| 349 | +        if self._tuning_factor is None or self._tuning_factor_q not in (None, q):
| 350 | + self._compute_tuning_factor(q) |
| 351 | + |
| 352 | + tsig = qExpectedPredictiveInformationGain.forward(self, X) |
| 353 | + bald = qBayesianActiveLearningByDisagreement.forward(self, X) |
| 354 | + # Since both acquisition functions are averaged over the ensemble, |
| 355 | + # we do not average over the ensemble again here. |
| 356 | + return self._tuning_factor * bald + tsig |
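
Note on the entropy terms in both forward passes above: they rely on the closed-form differential entropy of a univariate Gaussian, H = log(2 * pi * e * var) / 2, which is what the TWO_PI_E constant encodes. Below is a minimal sanity-check sketch of that identity against a Monte Carlo estimate; the choice of sigma is arbitrary and the snippet is illustrative only, not part of the patch.

import math

import torch

sigma = 1.7
# Closed-form differential entropy of N(0, sigma^2), the identity used via TWO_PI_E.
closed_form = 0.5 * math.log(2 * math.pi * math.e * sigma**2)

# Monte Carlo estimate of the same quantity: E[-log p(x)] under N(0, sigma^2).
samples = torch.randn(1_000_000) * sigma
mc_estimate = -torch.distributions.Normal(0.0, sigma).log_prob(samples).mean().item()

print(closed_form, mc_estimate)  # both approximately 1.95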
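
For orientation, here is a hedged end-to-end sketch of how the two new acquisition functions might be exercised. It assumes qExpectedPredictiveInformationGain and qHyperparameterInformedPredictiveExploration are already in scope (e.g., the snippet lives in the same module as this diff; the patch does not show its module path, so no import line is written for them). The toy objective, MCMC settings, and optimizer options are illustrative choices, not recommendations.

import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.optim import optimize_acqf
from botorch.utils.sampling import draw_sobol_samples

torch.manual_seed(0)

# Toy 2d problem on the unit cube.
bounds = torch.stack([torch.zeros(2), torch.ones(2)])
train_X = draw_sobol_samples(bounds=bounds, n=8, q=1).squeeze(-2)
train_Y = torch.sin(6 * train_X).sum(dim=-1, keepdim=True)

# Fit a fully Bayesian SAAS GP with NUTS (tiny sample counts keep the sketch fast).
model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)

# qMC test set used to MC-integrate the posterior entropy over the design space.
mc_points = draw_sobol_samples(bounds=bounds, n=64, q=1).squeeze(-2)

# Pure test-set information gain.
epig = qExpectedPredictiveInformationGain(model=model, mc_points=mc_points)

# HIPE: test-set information gain plus beta-weighted BALD; beta is auto-tuned on
# the first forward pass because it is not passed here.
hipe = qHyperparameterInformedPredictiveExploration(
    model=model, mc_points=mc_points, bounds=bounds
)

candidates, value = optimize_acqf(
    hipe, bounds=bounds, q=2, num_restarts=4, raw_samples=64
)
print(candidates, value, epig(candidates))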