|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +import torch |
| 12 | +from botorch.models.fully_bayesian import MCMC_DIM |
| 13 | +from botorch.models.gp_regression import SingleTaskGP |
| 14 | +from botorch.models.gpytorch import GPyTorchModel |
| 15 | +from botorch.models.transforms.input import InputTransform |
| 16 | +from botorch.models.transforms.outcome import OutcomeTransform |
| 17 | +from botorch.posteriors.fully_bayesian import GaussianMixturePosterior |
| 18 | +from botorch.utils.datasets import SupervisedDataset |
| 19 | +from gpytorch.constraints import Interval |
| 20 | +from gpytorch.distributions.multivariate_normal import MultivariateNormal |
| 21 | +from gpytorch.means.mean import Mean |
| 22 | +from gpytorch.priors import HalfCauchyPrior, Prior |
| 23 | +from linear_operator.operators import PsdSumLinearOperator |
| 24 | +from torch import Tensor |
| 25 | +from torch.nn import Module, Parameter |
| 26 | +from torch.nn.modules import ModuleDict |
| 27 | + |
| 28 | +r""" |
| 29 | +Target-aware GP models. |
| 30 | +
|
| 31 | +References |
| 32 | +
|
| 33 | +.. [Feng2025experimenting] |
| 34 | + Q. Feng, S. Daulton, B. Letham, M. Balandat, and E. Bakshy. |
| 35 | + Experimenting, Fast and Slow: Bayesian Optimization of Long-term Outcomes |
| 36 | + with Online Experiments. Proceedings of the 31st ACM SIGKDD Conference on |
| 37 | + Knowledge Discovery and Data Mining, 2025. |
| 38 | +""" |
| 39 | + |
| 40 | +WEIGHT_THRESHOLD = 0.01 |
| 41 | + |
| 42 | + |
| 43 | +class TargetAwareEnsembleGP(SingleTaskGP): |
| 44 | + r"""A target-aware ensemble GP that takes a set of GPs pre-trained on the |
| 45 | + auxiliary data to improve prediction accuracy of the target task. |
| 46 | +
|
| 47 | + The target outputs are modeled as weighted sum of an offset function and |
| 48 | + functions of each auxiliary sources. The offset function is unknown. |
| 49 | + The ensemble model jointly optimizes the kernel hyperparameters of the |
| 50 | + target task together with the weights. |
| 51 | +
|
| 52 | + This model is described in [Feng2025experimenting]_. |
| 53 | + """ |
| 54 | + |
| 55 | + def __init__( |
| 56 | + self, |
| 57 | + train_X: Tensor, |
| 58 | + train_Y: Tensor, |
| 59 | + base_model_dict: dict[str, GPyTorchModel], |
| 60 | + train_Yvar: Tensor | None = None, |
| 61 | + covar_module: Module | None = None, |
| 62 | + mean_module: Mean | None = None, |
| 63 | + outcome_transform: OutcomeTransform | None = None, |
| 64 | + input_transform: InputTransform | None = None, |
| 65 | + ensemble_weight_prior: Prior | None = None, |
| 66 | + ) -> None: |
| 67 | + r"""A dictionary kernel with a scale kernel. |
| 68 | +
|
| 69 | + Args: |
| 70 | + train_X: (n x d) X training data of target task. |
| 71 | + train_Y: (n x 1) Y training data of target task. |
| 72 | + train_Yvar: (n x 1) Noise variances of each training Y. |
| 73 | + base_model_dict: Dict of GP models that each corresponds to a model trained |
| 74 | + on an auxiliary dataset. Keys are the name of auxiliary dataset. |
| 75 | + covar_module: The module computing the covariance (Kernel) matrix for the |
| 76 | + target data. If omitted, use a `MaternKernel`. |
| 77 | + mean_module: The mean function to be used for the target data. If omitted, |
| 78 | + use a `ConstantMean`. |
| 79 | + ensemble_weight_prior: The prior over the weights of the ensemble model. |
| 80 | + If omitted, use default HalfCauchyPrior with scale = 1.0. |
| 81 | + """ |
| 82 | + if ensemble_weight_prior is None: |
| 83 | + ensemble_weight_prior = HalfCauchyPrior(scale=1.0) |
| 84 | + |
| 85 | + super().__init__( |
| 86 | + train_X=train_X, |
| 87 | + train_Y=train_Y, |
| 88 | + train_Yvar=train_Yvar, |
| 89 | + covar_module=covar_module, |
| 90 | + mean_module=mean_module, |
| 91 | + outcome_transform=outcome_transform, |
| 92 | + input_transform=input_transform, |
| 93 | + ) |
| 94 | + |
| 95 | + self.base_model_dict = ModuleDict(base_model_dict) |
| 96 | + self.base_model_dict.eval() |
| 97 | + |
| 98 | + # register ensemble weights |
| 99 | + self.register_parameter( |
| 100 | + name="raw_weight", |
| 101 | + parameter=Parameter( |
| 102 | + torch.zeros(len(self.base_model_dict), device=train_X.device) |
| 103 | + ), |
| 104 | + ) |
| 105 | + self.register_constraint( |
| 106 | + param_name="raw_weight", |
| 107 | + constraint=Interval(-10, 10, initial_value=0.1), |
| 108 | + replace=True, |
| 109 | + ) |
| 110 | + # set prior on weights so that the unimportant auxiliary |
| 111 | + # sources can be shrunk to 0. |
| 112 | + self.register_prior( |
| 113 | + "weight_prior", |
| 114 | + ensemble_weight_prior.to(train_X), |
| 115 | + lambda m: m.weight**2, |
| 116 | + lambda m, v: m._set_weight(v), |
| 117 | + ) |
| 118 | + self.to(train_X) |
| 119 | + |
| 120 | + @property |
| 121 | + def weight(self) -> Tensor: |
| 122 | + return self.raw_weight_constraint.transform(self.raw_weight) |
| 123 | + |
| 124 | + @weight.setter |
| 125 | + def weight(self, value: Tensor) -> None: |
| 126 | + value = torch.as_tensor(value).to(self.raw_weight) |
| 127 | + self.initialize(raw_weight=self.raw_weight_constraint.inverse_transform(value)) |
| 128 | + |
| 129 | + def _set_weight(self, value: Tensor) -> None: |
| 130 | + # Prior closure: prior is registered on `weight ** 2`, so `value` here is a |
| 131 | + # sample of `weight ** 2`. Take the sqrt to recover `weight` before setting. |
| 132 | + self.weight = torch.as_tensor(value).to(self.raw_weight).sqrt() |
| 133 | + |
| 134 | + def train(self, mode: bool = True) -> None: |
| 135 | + r"""Puts the model in `train` mode.""" |
| 136 | + self.training = mode |
| 137 | + for module in self.children(): |
| 138 | + if module is self.base_model_dict: |
| 139 | + # set base_model module training always be False |
| 140 | + module.training = False |
| 141 | + module.requires_grad_(False) |
| 142 | + else: |
| 143 | + module.train(mode) |
| 144 | + |
| 145 | + def forward(self, x: Tensor) -> MultivariateNormal: |
| 146 | + if self.training: |
| 147 | + x = self.transform_inputs(x) |
| 148 | + weighted_means = [] |
| 149 | + weighted_covars = [] |
| 150 | + for i, m in enumerate(self.base_model_dict.values()): |
| 151 | + posterior = m.posterior(x) |
| 152 | + if abs(self.weight[i]) < WEIGHT_THRESHOLD: # Or some appropriate threshold |
| 153 | + continue |
| 154 | + if isinstance(posterior, GaussianMixturePosterior): |
| 155 | + mean = posterior.mixture_mean |
| 156 | + covar = posterior.mvn.covariance_matrix.mean(dim=MCMC_DIM) |
| 157 | + else: |
| 158 | + mean = posterior.mean |
| 159 | + covar = posterior.mvn.covariance_matrix |
| 160 | + weighted_means.append(self.weight[i] * mean) |
| 161 | + weighted_covars.append(covar * (self.weight[i] ** 2)) |
| 162 | + # obtain mean and covar from the offset function |
| 163 | + weighted_means.append(self.mean_module(x).unsqueeze(-1)) |
| 164 | + weighted_covars.append(self.covar_module(x)) |
| 165 | + # average across a list of posteriors |
| 166 | + mean_x = torch.stack(weighted_means).sum(dim=0).squeeze(-1) |
| 167 | + covar_x = PsdSumLinearOperator(*weighted_covars) |
| 168 | + return MultivariateNormal(mean_x, covar_x) |
| 169 | + |
| 170 | + @classmethod |
| 171 | + def construct_inputs( |
| 172 | + cls, |
| 173 | + training_data: SupervisedDataset, |
| 174 | + base_model_dict: dict[str, GPyTorchModel], |
| 175 | + covar_module: Module | None = None, |
| 176 | + mean_module: Mean | None = None, |
| 177 | + ensemble_weight_prior: Prior | None = None, |
| 178 | + ) -> dict[str, Any]: |
| 179 | + r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`. |
| 180 | +
|
| 181 | + Args: |
| 182 | + training_data: A `SupervisedDataset` containing the training data for the |
| 183 | + target task only. |
| 184 | + base_model_dict: Dict of GP models that each corresponds to a model trained |
| 185 | + on an auxiliary dataset. Keys are the name of auxiliary dataset. |
| 186 | + covar_module: The module computing the covariance (Kernel) matrix for the |
| 187 | + target data. If omitted, use a `MaternKernel`. |
| 188 | + mean_module: The mean function to be used for the target data. If omitted, |
| 189 | + use a `ConstantMean`. |
| 190 | + ensemble_weight_prior: The prior over the weights of the ensemble model. |
| 191 | + If omitted, use default HalfCauchyPrior with scale = 1.0. |
| 192 | + """ |
| 193 | + base_inputs = super().construct_inputs(training_data=training_data) |
| 194 | + return { |
| 195 | + **base_inputs, |
| 196 | + "base_model_dict": base_model_dict, |
| 197 | + "covar_module": covar_module, |
| 198 | + "mean_module": mean_module, |
| 199 | + "ensemble_weight_prior": ensemble_weight_prior, |
| 200 | + } |
0 commit comments