Skip to content

Commit 309c649

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Generic FullyBayesianMultiTaskGP (#3180)
Summary: Creates a generic FullyBayesianMultiTaskGP implementation and refactors the existing SAAS model as a special case. Differential Revision: D93037430
1 parent 88138ed commit 309c649

5 files changed

Lines changed: 110 additions & 37 deletions

File tree

botorch/fit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from botorch.models import SingleTaskGP
2222
from botorch.models.approximate_gp import ApproximateGPyTorchModel
2323
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
24-
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
24+
from botorch.models.fully_bayesian_multitask import FullyBayesianMultiTaskGP
2525
from botorch.models.map_saas import get_map_saas_model
2626
from botorch.models.model_list_gp_regression import ModelListGP
2727
from botorch.models.transforms.input import InputTransform
@@ -334,7 +334,7 @@ def _fit_fallback_approximate(
334334

335335

336336
def fit_fully_bayesian_model_nuts(
337-
model: AbstractFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
337+
model: AbstractFullyBayesianSingleTaskGP | FullyBayesianMultiTaskGP,
338338
max_tree_depth: int = 6,
339339
warmup_steps: int = 512,
340340
num_samples: int = 256,

botorch/models/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
PosteriorMeanModel,
1616
)
1717
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
18-
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
18+
from botorch.models.fully_bayesian_multitask import (
19+
FullyBayesianMultiTaskGP,
20+
SaasFullyBayesianMultiTaskGP,
21+
)
1922
from botorch.models.gp_regression import SingleTaskGP
2023
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
2124
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
@@ -37,6 +40,7 @@
3740
"AffineFidelityCostModel",
3841
"ApproximateGPyTorchModel",
3942
"EnsembleMapSaasSingleTaskGP",
43+
"FullyBayesianMultiTaskGP",
4044
"GenericDeterministicModel",
4145
"HigherOrderGP",
4246
"KroneckerMultiTaskGP",

botorch/models/fully_bayesian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def set_inputs(
169169
train_X: Tensor,
170170
train_Y: Tensor,
171171
train_Yvar: Tensor | None = None,
172-
task_feature: int | None = None,
173-
task_rank: int | None = None,
172+
task_feature: int | None = None, # noqa: ARG002
173+
task_rank: int | None = None, # noqa: ARG002
174174
) -> None:
175175
"""Set the training data.
176176

botorch/models/fully_bayesian_multitask.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
matern52_kernel,
1717
MCMC_DIM,
1818
MIN_INFERRED_NOISE_LEVEL,
19+
PyroModel,
1920
reshape_and_detach,
2021
SaasPyroModel,
2122
)
@@ -39,8 +40,8 @@
3940
from typing_extensions import Self
4041

4142
# Can replace with Self type once 3.11 is the minimum version
42-
TSaasFullyBayesianMultiTaskGP = TypeVar(
43-
"TSaasFullyBayesianMultiTaskGP", bound="SaasFullyBayesianMultiTaskGP"
43+
TFullyBayesianMultiTaskGP = TypeVar(
44+
"TFullyBayesianMultiTaskGP", bound="FullyBayesianMultiTaskGP"
4445
)
4546

4647

@@ -261,19 +262,20 @@ def get_dummy_mcmc_samples(
261262

262263
class MultitaskSaasPyroModel(LatentFeatureMultiTaskPyroMixin, SaasPyroModel):
263264
r"""
264-
Multi-task SAAS model. Backward-compatible subclass that composes
265-
``LatentFeatureMultiTaskPyroMixin`` with ``SaasPyroModel``.
265+
Multi-task SAAS model using latent task features. Backward-compatible
266+
subclass that composes ``LatentFeatureMultiTaskPyroMixin`` with
267+
``SaasPyroModel``.
266268
"""
267269

268270
pass
269271

270272

271-
class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
272-
r"""A fully Bayesian multi-task GP model with the SAAS prior.
273+
class FullyBayesianMultiTaskGP(MultiTaskGP):
274+
r"""A fully Bayesian multi-task GP model.
275+
273276
This model assumes that the inputs have been normalized to [0, 1]^d and that the
274277
output has been stratified standardized to have zero mean and unit variance for
275-
each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data
276-
kernel by default.
278+
each task.
277279
278280
You are expected to use ``fit_fully_bayesian_model_nuts`` to fit this model as it
279281
isn't compatible with ``fit_gpytorch_mll``.
@@ -286,11 +288,12 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
286288
>>> ])
287289
>>> train_Y = torch.cat(f1(X1), f2(X2)).unsqueeze(-1)
288290
>>> train_Yvar = 0.01 * torch.ones_like(train_Y)
289-
>>> mtsaas_gp = SaasFullyBayesianMultiTaskGP(
290-
>>> train_X, train_Y, train_Yvar, task_feature=-1,
291+
>>> mt_gp = FullyBayesianMultiTaskGP(
292+
>>> train_X, train_Y, task_feature=-1,
293+
>>> pyro_model=MultitaskSaasPyroModel(),
291294
>>> )
292-
>>> fit_fully_bayesian_model_nuts(mtsaas_gp)
293-
>>> posterior = mtsaas_gp.posterior(test_X)
295+
>>> fit_fully_bayesian_model_nuts(mt_gp)
296+
>>> posterior = mt_gp.posterior(test_X)
294297
"""
295298

296299
_is_fully_bayesian = True
@@ -307,7 +310,7 @@ def __init__(
307310
all_tasks: list[int] | None = None,
308311
outcome_transform: OutcomeTransform | None = None,
309312
input_transform: InputTransform | None = None,
310-
pyro_model: MultitaskSaasPyroModel | None = None,
313+
pyro_model: PyroModel | None = None,
311314
validate_task_values: bool = True,
312315
) -> None:
313316
r"""Initialize the fully Bayesian multi-task GP model.
@@ -334,8 +337,7 @@ def __init__(
334337
instantiation of the model.
335338
input_transform: An input transform that is applied to the inputs ``X``
336339
in the model's forward pass.
337-
pyro_model: Optional ``PyroModel`` that has the same signature as
338-
``MultitaskSaasPyroModel``. Defaults to ``MultitaskSaasPyroModel``.
340+
pyro_model: A ``PyroModel`` that inherits from ``MultiTaskPyroMixin``.
339341
validate_task_values: If True, validate that the task values supplied in the
340342
input are expected tasks values. If false, unexpected task values
341343
will be mapped to the first output_task if supplied.
@@ -385,7 +387,8 @@ def __init__(
385387
self.likelihood = None
386388
if pyro_model is None:
387389
pyro_model = MultitaskSaasPyroModel()
388-
# apply task_mapper
390+
if not isinstance(pyro_model, MultiTaskPyroMixin):
391+
raise ValueError("pyro_model must be a multi-task model.")
389392
x_before, task_idcs, x_after = self._split_inputs(transformed_X)
390393
pyro_model.set_inputs(
391394
train_X=torch.cat([x_before, task_idcs, x_after], dim=-1),
@@ -395,15 +398,13 @@ def __init__(
395398
task_rank=self._rank,
396399
all_tasks=all_tasks,
397400
)
398-
self.pyro_model: MultitaskSaasPyroModel = pyro_model
401+
self.pyro_model: PyroModel = pyro_model
399402
if outcome_transform is not None:
400403
self.outcome_transform = outcome_transform
401404
if input_transform is not None:
402405
self.input_transform = input_transform
403406

404-
def train(
405-
self, mode: bool = True, reset: bool = True
406-
) -> TSaasFullyBayesianMultiTaskGP:
407+
def train(self, mode: bool = True, reset: bool = True) -> TFullyBayesianMultiTaskGP:
407408
r"""Puts the model in ``train`` mode.
408409
409410
Args:
@@ -437,7 +438,7 @@ def num_mcmc_samples(self) -> int:
437438
@property
438439
def batch_shape(self) -> torch.Size:
439440
r"""Batch shape of the model, equal to the number of MCMC samples.
440-
Note that ``SaasFullyBayesianMultiTaskGP`` does not support batching
441+
Note that ``FullyBayesianMultiTaskGP`` does not support batching
441442
over input data at this point.
442443
"""
443444
self._check_if_fitted()
@@ -514,22 +515,14 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
514515
r"""Custom logic for loading the state dict.
515516
516517
The standard approach of calling ``load_state_dict`` currently doesn't
517-
play well with the ``SaasFullyBayesianMultiTaskGP`` since the
518+
play well with the ``FullyBayesianMultiTaskGP`` since the
518519
``mean_module``, ``covar_module`` and ``likelihood`` aren't initialized
519520
until the model has been fitted. The reason for this is that we don't
520521
know the number of MCMC samples until NUTS is called. Given the state
521522
dict, we can initialize a new model with some dummy samples and then
522-
load the state dict into this model. This currently only works for a
523-
``MultitaskSaasPyroModel`` and supporting more Pyro models likely
524-
requires moving the model construction logic into the Pyro model itself.
525-
526-
TODO: If this were to inherit from ``SaasFullyBayesianSingleTaskGP``, we could
527-
simplify this method and eliminate some others.
523+
load the state dict into this model. The dummy samples are obtained
524+
from ``pyro_model.get_dummy_mcmc_samples()``.
528525
"""
529-
if not isinstance(self.pyro_model, MultitaskSaasPyroModel):
530-
raise NotImplementedError( # pragma: no cover
531-
"load_state_dict only works for MultitaskSaasPyroModel"
532-
)
533526
raw_mean = state_dict["mean_module.base_means.0.raw_constant"]
534527
num_mcmc_samples = len(raw_mean)
535528
tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
@@ -573,3 +566,7 @@ def condition_on_observations(
573566
X = X.repeat(*(Y.shape[:-2] + (1, 1)))
574567

575568
return super().condition_on_observations(X, Y, **kwargs)
569+
570+
571+
class SaasFullyBayesianMultiTaskGP(FullyBayesianMultiTaskGP):
572+
r"""A fully Bayesian multi-task GP model with the SAAS prior by default."""

test/models/test_fully_bayesian_multitask.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
SaasPyroModel,
4040
)
4141
from botorch.models.fully_bayesian_multitask import (
42+
FullyBayesianMultiTaskGP,
4243
LatentFeatureMultiTaskPyroMixin,
4344
MultiTaskPyroMixin,
4445
MultitaskSaasPyroModel,
@@ -915,6 +916,77 @@ def test_construct_inputs(self):
915916
self.assertEqual(model._task_feature, d)
916917
self.assertEqual(model.pyro_model.task_feature, d)
917918

919+
def test_constructor_validation_and_defaults(self):
920+
"""Test FullyBayesianMultiTaskGP constructor validation and defaults."""
921+
tkwargs = {"device": self.device, "dtype": torch.double}
922+
train_X, train_Y, train_Yvar = self._get_base_data(**tkwargs)
923+
924+
with self.subTest("rejects_single_task_pyro_model"):
925+
with self.assertRaisesRegex(
926+
ValueError, "pyro_model must be a multi-task model"
927+
):
928+
FullyBayesianMultiTaskGP(
929+
train_X=train_X,
930+
train_Y=train_Y,
931+
train_Yvar=train_Yvar,
932+
task_feature=4,
933+
pyro_model=MaternPyroModel(),
934+
)
935+
936+
with self.subTest("accepts_explicit_multitask_pyro_model"):
937+
pyro_model = MultitaskSaasPyroModel()
938+
model = FullyBayesianMultiTaskGP(
939+
train_X=train_X,
940+
train_Y=train_Y,
941+
train_Yvar=train_Yvar,
942+
task_feature=4,
943+
pyro_model=pyro_model,
944+
)
945+
self.assertIs(model.pyro_model, pyro_model)
946+
self.assertIsInstance(model.pyro_model, MultiTaskPyroMixin)
947+
948+
def test_non_saas_mt_model_load_state_dict(self):
949+
"""Test round-trip load_state_dict with a non-SAAS multi-task PyroModel."""
950+
tkwargs = {"device": self.device, "dtype": torch.double}
951+
952+
class MultitaskMaternPyroModel(
953+
LatentFeatureMultiTaskPyroMixin, MaternPyroModel
954+
):
955+
pass
956+
957+
train_X, train_Y, train_Yvar = self._get_base_data(**tkwargs)
958+
959+
pyro_model = MultitaskMaternPyroModel()
960+
model = FullyBayesianMultiTaskGP(
961+
train_X=train_X,
962+
train_Y=train_Y,
963+
train_Yvar=train_Yvar,
964+
task_feature=4,
965+
pyro_model=pyro_model,
966+
)
967+
self.assertIsInstance(model.pyro_model, MultiTaskPyroMixin)
968+
969+
fit_fully_bayesian_model_nuts(
970+
model, warmup_steps=8, num_samples=5, thinning=2, disable_progbar=True
971+
)
972+
state_dict = model.state_dict()
973+
test_X = torch.rand(3, 4, **tkwargs)
974+
preds1 = model.posterior(test_X)
975+
976+
pyro_model2 = MultitaskMaternPyroModel()
977+
m_new = FullyBayesianMultiTaskGP(
978+
train_X=train_X,
979+
train_Y=train_Y,
980+
train_Yvar=train_Yvar,
981+
task_feature=4,
982+
pyro_model=pyro_model2,
983+
)
984+
m_new.load_state_dict(state_dict)
985+
986+
preds2 = m_new.posterior(test_X)
987+
self.assertTrue(torch.equal(preds1.mean, preds2.mean))
988+
self.assertTrue(torch.equal(preds1.variance, preds2.variance))
989+
918990

919991
class TestPyroModelMultitaskMixin(BotorchTestCase):
920992
"""Tests for the MultiTaskPyroMixin and LatentFeatureMultiTaskPyroMixin classes."""

0 commit comments

Comments
 (0)