Skip to content

Commit cd70192

Browse files
saitcakmak and meta-codesync[bot]
authored and committed
Move Target Aware GP to OSS (#3013)
Summary: Pull Request resolved: #3013 Moves TargetAwareGP to OSS BoTorch and updates internal references. The model has been published in https://dl.acm.org/doi/10.1145/3690624.3709419 Differential Revision: D82453481
1 parent 75fe56b commit cd70192

4 files changed

Lines changed: 477 additions & 0 deletions

File tree

botorch/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from botorch.models.model_list_gp_regression import ModelListGP
3030
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
3131
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
32+
from botorch.models.target_aware_gp import TargetAwareEnsembleGP
3233

3334
__all__ = [
3435
"add_saas_prior",
@@ -52,4 +53,5 @@
5253
"SingleTaskGP",
5354
"SingleTaskMultiFidelityGP",
5455
"SingleTaskVariationalGP",
56+
"TargetAwareEnsembleGP",
5557
]

botorch/models/target_aware_gp.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from typing import Any
10+
11+
import torch
12+
from botorch.models.fully_bayesian import MCMC_DIM
13+
from botorch.models.gp_regression import SingleTaskGP
14+
from botorch.models.gpytorch import GPyTorchModel
15+
from botorch.models.transforms.input import InputTransform
16+
from botorch.models.transforms.outcome import OutcomeTransform
17+
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
18+
from botorch.utils.datasets import SupervisedDataset
19+
from gpytorch.constraints import Interval
20+
from gpytorch.distributions.multivariate_normal import MultivariateNormal
21+
from gpytorch.means.mean import Mean
22+
from gpytorch.priors import HalfCauchyPrior, Prior
23+
from linear_operator.operators import PsdSumLinearOperator
24+
from torch import Tensor
25+
from torch.nn import Module, Parameter
26+
from torch.nn.modules import ModuleDict
27+
28+
r"""
29+
Target-aware GP models.
30+
31+
References
32+
33+
.. [Feng2025experimenting]
34+
Q. Feng, S. Daulton, B. Letham, M. Balandat, and E. Bakshy.
35+
Experimenting, Fast and Slow: Bayesian Optimization of Long-term Outcomes
36+
with Online Experiments. Proceedings of the 31st ACM SIGKDD Conference on
37+
Knowledge Discovery and Data Mining, 2025.
38+
"""
39+
40+
# Ensemble weights whose absolute value falls below this threshold are treated
# as zero: `TargetAwareEnsembleGP.forward` skips the corresponding base model.
WEIGHT_THRESHOLD = 0.01
41+
42+
43+
class TargetAwareEnsembleGP(SingleTaskGP):
    r"""A target-aware ensemble GP that takes a set of GPs pre-trained on the
    auxiliary data to improve prediction accuracy of the target task.

    The target outputs are modeled as a weighted sum of an offset function and
    the functions of each auxiliary source. The offset function is unknown.
    The ensemble model jointly optimizes the kernel hyperparameters of the
    target task together with the weights.

    This model is described in [Feng2025experimenting]_.
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        base_model_dict: dict[str, GPyTorchModel],
        train_Yvar: Tensor | None = None,
        covar_module: Module | None = None,
        mean_module: Mean | None = None,
        outcome_transform: OutcomeTransform | None = None,
        input_transform: InputTransform | None = None,
        ensemble_weight_prior: Prior | None = None,
    ) -> None:
        r"""Build the ensemble from target-task data and pre-trained base models.

        Args:
            train_X: (n x d) X training data of target task.
            train_Y: (n x 1) Y training data of target task.
            base_model_dict: Dict of GP models that each corresponds to a model
                trained on an auxiliary dataset. Keys are the names of the
                auxiliary datasets. The models are put in eval mode and kept
                frozen (see `train`).
            train_Yvar: (n x 1) Noise variances of each training Y.
            covar_module: The module computing the covariance (Kernel) matrix for
                the offset function on the target data. Forwarded to
                `SingleTaskGP`; if omitted, the `SingleTaskGP` default is used.
            mean_module: The mean function of the offset function on the target
                data. Forwarded to `SingleTaskGP`; if omitted, the `SingleTaskGP`
                default is used.
            outcome_transform: Outcome transform, forwarded to `SingleTaskGP`.
            input_transform: Input transform, forwarded to `SingleTaskGP`.
            ensemble_weight_prior: The prior placed on the squared ensemble
                weights. If omitted, use `HalfCauchyPrior` with scale = 1.0.
        """
        if ensemble_weight_prior is None:
            ensemble_weight_prior = HalfCauchyPrior(scale=1.0)

        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            train_Yvar=train_Yvar,
            covar_module=covar_module,
            mean_module=mean_module,
            outcome_transform=outcome_transform,
            input_transform=input_transform,
        )

        # The base models are only ever used for prediction, so they are put in
        # eval mode once here; `train` keeps them from being switched back.
        self.base_model_dict = ModuleDict(base_model_dict)
        self.base_model_dict.eval()

        # register ensemble weights (one per base model)
        self.register_parameter(
            name="raw_weight",
            parameter=Parameter(
                torch.zeros(len(self.base_model_dict), device=train_X.device)
            ),
        )
        self.register_constraint(
            param_name="raw_weight",
            constraint=Interval(-10, 10, initial_value=0.1),
            replace=True,
        )
        # set prior on weights so that the unimportant auxiliary
        # sources can be shrunk to 0. The prior is evaluated on `weight ** 2`;
        # the setting closure maps a value of `weight ** 2` back to `weight`.
        self.register_prior(
            "weight_prior",
            ensemble_weight_prior.to(train_X),
            lambda m: m.weight**2,
            lambda m, v: m._set_weight(v),
        )
        # Match device and dtype of the training data (`raw_weight` was created
        # with the default dtype above).
        self.to(train_X)

    @property
    def weight(self) -> Tensor:
        r"""The constrained ensemble weights, one entry per base model."""
        return self.raw_weight_constraint.transform(self.raw_weight)

    @weight.setter
    def weight(self, value: Tensor) -> None:
        # Accept anything `torch.as_tensor` understands and store it in the
        # unconstrained (raw) parameterization.
        value = torch.as_tensor(value).to(self.raw_weight)
        self.initialize(raw_weight=self.raw_weight_constraint.inverse_transform(value))

    def _set_weight(self, value: Tensor) -> None:
        r"""Setting closure for the weight prior.

        Args:
            value: A value of `weight ** 2` (the quantity the prior is
                registered on).
        """
        # Prior closure: prior is registered on `weight ** 2`, so `value` here is a
        # sample of `weight ** 2`. Take the sqrt to recover `weight` before setting.
        self.weight = torch.as_tensor(value).to(self.raw_weight).sqrt()

    def train(self, mode: bool = True) -> TargetAwareEnsembleGP:
        r"""Set the training mode of the model while keeping base models frozen.

        Args:
            mode: If True, put the target-task modules in training mode;
                otherwise put them in eval mode.

        Returns:
            The model itself. `torch.nn.Module.eval()` returns the result of
            `train(False)`, so returning `self` keeps `model.eval()` and
            `model.train()` usable in assignments and call chains.
        """
        # NOTE(review): this override does not call `super().train()`, so any
        # bookkeeping the parent classes perform in `train` is skipped -- confirm
        # that this is intentional.
        self.training = mode
        for module in self.children():
            if module is self.base_model_dict:
                # Base models must never be trained: keep the container flagged
                # as eval (its members were put in eval mode in `__init__`) and
                # keep their parameters out of the autograd graph.
                module.training = False
                module.requires_grad_(False)
            else:
                module.train(mode)
        return self

    def forward(self, x: Tensor) -> MultivariateNormal:
        r"""Combine the base-model posteriors and the offset GP prior at `x`.

        Args:
            x: The input points. In training mode the input transform is applied
                here before evaluating any module.

        Returns:
            A `MultivariateNormal` whose mean is the offset mean plus the
            weighted base-model means, and whose covariance is the offset
            covariance plus the base-model covariances scaled by the squared
            weights.
        """
        if self.training:
            x = self.transform_inputs(x)
        # Apply the constraint transform once instead of on every access.
        weight = self.weight
        weighted_means = []
        weighted_covars = []
        for w, m in zip(weight, self.base_model_dict.values()):
            # Skip negligible sources before evaluating their posterior, so no
            # work is spent on a term that is dropped anyway.
            if abs(w) < WEIGHT_THRESHOLD:
                continue
            posterior = m.posterior(x)
            if isinstance(posterior, GaussianMixturePosterior):
                # Collapse the mixture components: use the mixture mean and the
                # covariance averaged over the MCMC dimension.
                mean = posterior.mixture_mean
                covar = posterior.mvn.covariance_matrix.mean(dim=MCMC_DIM)
            else:
                mean = posterior.mean
                covar = posterior.mvn.covariance_matrix
            weighted_means.append(w * mean)
            weighted_covars.append(covar * (w**2))
        # obtain mean and covar from the offset function; the trailing dim is
        # added so the offset mean stacks with the posterior means above.
        weighted_means.append(self.mean_module(x).unsqueeze(-1))
        weighted_covars.append(self.covar_module(x))
        # sum the weighted contributions of all sources and the offset
        mean_x = torch.stack(weighted_means).sum(dim=0).squeeze(-1)
        covar_x = PsdSumLinearOperator(*weighted_covars)
        return MultivariateNormal(mean_x, covar_x)

    @classmethod
    def construct_inputs(
        cls,
        training_data: SupervisedDataset,
        base_model_dict: dict[str, GPyTorchModel],
        covar_module: Module | None = None,
        mean_module: Mean | None = None,
        ensemble_weight_prior: Prior | None = None,
    ) -> dict[str, Any]:
        r"""Construct `Model` keyword arguments from a `SupervisedDataset`.

        Args:
            training_data: A `SupervisedDataset` containing the training data for the
                target task only.
            base_model_dict: Dict of GP models that each corresponds to a model trained
                on an auxiliary dataset. Keys are the names of the auxiliary datasets.
            covar_module: The module computing the covariance (Kernel) matrix for the
                target data. Passed through unchanged; may be None.
            mean_module: The mean function to be used for the target data. Passed
                through unchanged; may be None.
            ensemble_weight_prior: The prior over the weights of the ensemble model.
                Passed through unchanged; if None, the constructor falls back to
                `HalfCauchyPrior` with scale = 1.0.

        Returns:
            A dict of keyword arguments for the constructor: the inputs built by
            the parent class from `training_data`, plus the arguments above.
        """
        base_inputs = super().construct_inputs(training_data=training_data)
        return {
            **base_inputs,
            "base_model_dict": base_model_dict,
            "covar_module": covar_module,
            "mean_module": mean_module,
            "ensemble_weight_prior": ensemble_weight_prior,
        }

sphinx/source/models.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ Relevance Pursuit Models
114114
.. automodule:: botorch.models.relevance_pursuit
115115
:members:
116116

117+
Target-Aware GP Models
118+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
119+
.. automodule:: botorch.models.target_aware_gp
120+
:members:
121+
117122
Sparse Axis-Aligned Subspaces (SAAS) GP Models
118123
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
119124
.. automodule:: botorch.models.map_saas

0 commit comments

Comments
 (0)