Skip to content

Commit 9c4a8a4

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Generic FullyBayesianMultiTaskGP (meta-pytorch#3180)
Summary: Creates a generic FullyBayesianMultiTaskGP implementation and refactors the existing SAAS model as a special case. Differential Revision: D93037430
1 parent 81ff5f9 commit 9c4a8a4

5 files changed

Lines changed: 110 additions & 37 deletions

File tree

botorch/fit.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from botorch.models import SingleTaskGP
2222
from botorch.models.approximate_gp import ApproximateGPyTorchModel
2323
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
24-
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
24+
from botorch.models.fully_bayesian_multitask import FullyBayesianMultiTaskGP
2525
from botorch.models.map_saas import get_map_saas_model
2626
from botorch.models.model_list_gp_regression import ModelListGP
2727
from botorch.models.transforms.input import InputTransform
@@ -334,7 +334,7 @@ def _fit_fallback_approximate(
334334

335335

336336
def fit_fully_bayesian_model_nuts(
337-
model: AbstractFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
337+
model: AbstractFullyBayesianSingleTaskGP | FullyBayesianMultiTaskGP,
338338
max_tree_depth: int = 6,
339339
warmup_steps: int = 512,
340340
num_samples: int = 256,

botorch/models/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,10 @@
1515
PosteriorMeanModel,
1616
)
1717
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
18-
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
18+
from botorch.models.fully_bayesian_multitask import (
19+
FullyBayesianMultiTaskGP,
20+
SaasFullyBayesianMultiTaskGP,
21+
)
1922
from botorch.models.gp_regression import SingleTaskGP
2023
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
2124
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
@@ -37,6 +40,7 @@
3740
"AffineFidelityCostModel",
3841
"ApproximateGPyTorchModel",
3942
"EnsembleMapSaasSingleTaskGP",
43+
"FullyBayesianMultiTaskGP",
4044
"GenericDeterministicModel",
4145
"HigherOrderGP",
4246
"KroneckerMultiTaskGP",

botorch/models/fully_bayesian.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -169,8 +169,8 @@ def set_inputs(
169169
train_X: Tensor,
170170
train_Y: Tensor,
171171
train_Yvar: Tensor | None = None,
172-
task_feature: int | None = None,
173-
task_rank: int | None = None,
172+
task_feature: int | None = None, # noqa: ARG002
173+
task_rank: int | None = None, # noqa: ARG002
174174
) -> None:
175175
"""Set the training data.
176176

botorch/models/fully_bayesian_multitask.py

Lines changed: 29 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
matern52_kernel,
1717
MCMC_DIM,
1818
MIN_INFERRED_NOISE_LEVEL,
19+
PyroModel,
1920
reshape_and_detach,
2021
SaasPyroModel,
2122
)
@@ -39,8 +40,8 @@
3940
from typing_extensions import Self
4041

4142
# Can replace with Self type once 3.11 is the minimum version
42-
TSaasFullyBayesianMultiTaskGP = TypeVar(
43-
"TSaasFullyBayesianMultiTaskGP", bound="SaasFullyBayesianMultiTaskGP"
43+
TFullyBayesianMultiTaskGP = TypeVar(
44+
"TFullyBayesianMultiTaskGP", bound="FullyBayesianMultiTaskGP"
4445
)
4546

4647

@@ -260,19 +261,20 @@ def get_dummy_mcmc_samples(
260261

261262
class MultitaskSaasPyroModel(LatentFeatureMultiTaskPyroMixin, SaasPyroModel):
262263
r"""
263-
Multi-task SAAS model. Backward-compatible subclass that composes
264-
``LatentFeatureMultiTaskPyroMixin`` with ``SaasPyroModel``.
264+
Multi-task SAAS model using latent task features. Backward-compatible
265+
subclass that composes ``LatentFeatureMultiTaskPyroMixin`` with
266+
``SaasPyroModel``.
265267
"""
266268

267269
pass
268270

269271

270-
class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
271-
r"""A fully Bayesian multi-task GP model with the SAAS prior.
272+
class FullyBayesianMultiTaskGP(MultiTaskGP):
273+
r"""A fully Bayesian multi-task GP model.
274+
272275
This model assumes that the inputs have been normalized to [0, 1]^d and that the
273276
output has been stratified standardized to have zero mean and unit variance for
274-
each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data
275-
kernel by default.
277+
each task.
276278
277279
You are expected to use ``fit_fully_bayesian_model_nuts`` to fit this model as it
278280
isn't compatible with ``fit_gpytorch_mll``.
@@ -285,11 +287,12 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
285287
>>> ])
286288
>>> train_Y = torch.cat(f1(X1), f2(X2)).unsqueeze(-1)
287289
>>> train_Yvar = 0.01 * torch.ones_like(train_Y)
288-
>>> mtsaas_gp = SaasFullyBayesianMultiTaskGP(
289-
>>> train_X, train_Y, train_Yvar, task_feature=-1,
290+
>>> mt_gp = FullyBayesianMultiTaskGP(
291+
>>> train_X, train_Y, task_feature=-1,
292+
>>> pyro_model=MultitaskSaasPyroModel(),
290293
>>> )
291-
>>> fit_fully_bayesian_model_nuts(mtsaas_gp)
292-
>>> posterior = mtsaas_gp.posterior(test_X)
294+
>>> fit_fully_bayesian_model_nuts(mt_gp)
295+
>>> posterior = mt_gp.posterior(test_X)
293296
"""
294297

295298
_is_fully_bayesian = True
@@ -306,7 +309,7 @@ def __init__(
306309
all_tasks: list[int] | None = None,
307310
outcome_transform: OutcomeTransform | None = None,
308311
input_transform: InputTransform | None = None,
309-
pyro_model: MultitaskSaasPyroModel | None = None,
312+
pyro_model: PyroModel | None = None,
310313
validate_task_values: bool = True,
311314
) -> None:
312315
r"""Initialize the fully Bayesian multi-task GP model.
@@ -333,8 +336,7 @@ def __init__(
333336
instantiation of the model.
334337
input_transform: An input transform that is applied to the inputs ``X``
335338
in the model's forward pass.
336-
pyro_model: Optional ``PyroModel`` that has the same signature as
337-
``MultitaskSaasPyroModel``. Defaults to ``MultitaskSaasPyroModel``.
339+
pyro_model: A ``PyroModel`` that inherits from ``MultiTaskPyroMixin``.
338340
validate_task_values: If True, validate that the task values supplied in the
339341
input are expected tasks values. If false, unexpected task values
340342
will be mapped to the first output_task if supplied.
@@ -384,7 +386,8 @@ def __init__(
384386
self.likelihood = None
385387
if pyro_model is None:
386388
pyro_model = MultitaskSaasPyroModel()
387-
# apply task_mapper
389+
if not isinstance(pyro_model, MultiTaskPyroMixin):
390+
raise ValueError("pyro_model must be a multi-task model.")
388391
x_before, task_idcs, x_after = self._split_inputs(transformed_X)
389392
pyro_model.set_inputs(
390393
train_X=torch.cat([x_before, task_idcs, x_after], dim=-1),
@@ -394,15 +397,13 @@ def __init__(
394397
task_rank=self._rank,
395398
all_tasks=all_tasks,
396399
)
397-
self.pyro_model: MultitaskSaasPyroModel = pyro_model
400+
self.pyro_model: PyroModel = pyro_model
398401
if outcome_transform is not None:
399402
self.outcome_transform = outcome_transform
400403
if input_transform is not None:
401404
self.input_transform = input_transform
402405

403-
def train(
404-
self, mode: bool = True, reset: bool = True
405-
) -> TSaasFullyBayesianMultiTaskGP:
406+
def train(self, mode: bool = True, reset: bool = True) -> TFullyBayesianMultiTaskGP:
406407
r"""Puts the model in ``train`` mode.
407408
408409
Args:
@@ -436,7 +437,7 @@ def num_mcmc_samples(self) -> int:
436437
@property
437438
def batch_shape(self) -> torch.Size:
438439
r"""Batch shape of the model, equal to the number of MCMC samples.
439-
Note that ``SaasFullyBayesianMultiTaskGP`` does not support batching
440+
Note that ``FullyBayesianMultiTaskGP`` does not support batching
440441
over input data at this point.
441442
"""
442443
self._check_if_fitted()
@@ -513,22 +514,14 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
513514
r"""Custom logic for loading the state dict.
514515
515516
The standard approach of calling ``load_state_dict`` currently doesn't
516-
play well with the ``SaasFullyBayesianMultiTaskGP`` since the
517+
play well with the ``FullyBayesianMultiTaskGP`` since the
517518
``mean_module``, ``covar_module`` and ``likelihood`` aren't initialized
518519
until the model has been fitted. The reason for this is that we don't
519520
know the number of MCMC samples until NUTS is called. Given the state
520521
dict, we can initialize a new model with some dummy samples and then
521-
load the state dict into this model. This currently only works for a
522-
``MultitaskSaasPyroModel`` and supporting more Pyro models likely
523-
requires moving the model construction logic into the Pyro model itself.
524-
525-
TODO: If this were to inherit from ``SaasFullyBayesianSingleTaskGP``, we could
526-
simplify this method and eliminate some others.
522+
load the state dict into this model. The dummy samples are obtained
523+
from ``pyro_model.get_dummy_mcmc_samples()``.
527524
"""
528-
if not isinstance(self.pyro_model, MultitaskSaasPyroModel):
529-
raise NotImplementedError( # pragma: no cover
530-
"load_state_dict only works for MultitaskSaasPyroModel"
531-
)
532525
raw_mean = state_dict["mean_module.base_means.0.raw_constant"]
533526
num_mcmc_samples = len(raw_mean)
534527
tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
@@ -572,3 +565,7 @@ def condition_on_observations(
572565
X = X.repeat(*(Y.shape[:-2] + (1, 1)))
573566

574567
return super().condition_on_observations(X, Y, **kwargs)
568+
569+
570+
class SaasFullyBayesianMultiTaskGP(FullyBayesianMultiTaskGP):
571+
r"""A fully Bayesian multi-task GP model with the SAAS prior by default."""

test/models/test_fully_bayesian_multitask.py

Lines changed: 72 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
SaasPyroModel,
4040
)
4141
from botorch.models.fully_bayesian_multitask import (
42+
FullyBayesianMultiTaskGP,
4243
LatentFeatureMultiTaskPyroMixin,
4344
MultiTaskPyroMixin,
4445
MultitaskSaasPyroModel,
@@ -915,6 +916,77 @@ def test_construct_inputs(self):
915916
self.assertEqual(model._task_feature, d)
916917
self.assertEqual(model.pyro_model.task_feature, d)
917918

919+
def test_constructor_validation_and_defaults(self):
920+
"""Test FullyBayesianMultiTaskGP constructor validation and defaults."""
921+
tkwargs = {"device": self.device, "dtype": torch.double}
922+
train_X, train_Y, train_Yvar = self._get_base_data(**tkwargs)
923+
924+
with self.subTest("rejects_single_task_pyro_model"):
925+
with self.assertRaisesRegex(
926+
ValueError, "pyro_model must be a multi-task model"
927+
):
928+
FullyBayesianMultiTaskGP(
929+
train_X=train_X,
930+
train_Y=train_Y,
931+
train_Yvar=train_Yvar,
932+
task_feature=4,
933+
pyro_model=MaternPyroModel(),
934+
)
935+
936+
with self.subTest("accepts_explicit_multitask_pyro_model"):
937+
pyro_model = MultitaskSaasPyroModel()
938+
model = FullyBayesianMultiTaskGP(
939+
train_X=train_X,
940+
train_Y=train_Y,
941+
train_Yvar=train_Yvar,
942+
task_feature=4,
943+
pyro_model=pyro_model,
944+
)
945+
self.assertIs(model.pyro_model, pyro_model)
946+
self.assertIsInstance(model.pyro_model, MultiTaskPyroMixin)
947+
948+
def test_non_saas_mt_model_load_state_dict(self):
949+
"""Test round-trip load_state_dict with a non-SAAS multi-task PyroModel."""
950+
tkwargs = {"device": self.device, "dtype": torch.double}
951+
952+
class MultitaskMaternPyroModel(
953+
LatentFeatureMultiTaskPyroMixin, MaternPyroModel
954+
):
955+
pass
956+
957+
train_X, train_Y, train_Yvar = self._get_base_data(**tkwargs)
958+
959+
pyro_model = MultitaskMaternPyroModel()
960+
model = FullyBayesianMultiTaskGP(
961+
train_X=train_X,
962+
train_Y=train_Y,
963+
train_Yvar=train_Yvar,
964+
task_feature=4,
965+
pyro_model=pyro_model,
966+
)
967+
self.assertIsInstance(model.pyro_model, MultiTaskPyroMixin)
968+
969+
fit_fully_bayesian_model_nuts(
970+
model, warmup_steps=8, num_samples=5, thinning=2, disable_progbar=True
971+
)
972+
state_dict = model.state_dict()
973+
test_X = torch.rand(3, 4, **tkwargs)
974+
preds1 = model.posterior(test_X)
975+
976+
pyro_model2 = MultitaskMaternPyroModel()
977+
m_new = FullyBayesianMultiTaskGP(
978+
train_X=train_X,
979+
train_Y=train_Y,
980+
train_Yvar=train_Yvar,
981+
task_feature=4,
982+
pyro_model=pyro_model2,
983+
)
984+
m_new.load_state_dict(state_dict)
985+
986+
preds2 = m_new.posterior(test_X)
987+
self.assertTrue(torch.equal(preds1.mean, preds2.mean))
988+
self.assertTrue(torch.equal(preds1.variance, preds2.variance))
989+
918990

919991
class TestPyroModelMultitaskMixin(BotorchTestCase):
920992
"""Tests for the MultiTaskPyroMixin and LatentFeatureMultiTaskPyroMixin classes."""

0 commit comments

Comments
 (0)