Skip to content

Commit 8634545

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Split into FullyBayesianMultiTaskGP + SaasFullyBayesianMultiTaskGP, generalize to accept any PyroModel (meta-pytorch#3180)
Summary: Corresponds to D92844565. This diff also absorbs the 'generalize GP to accept any PyroModel' work from the now-dropped D92844566.

(1) Rename the current `SaasFullyBayesianMultiTaskGP` → `FullyBayesianMultiTaskGP` (public name from the start, no private-then-rename dance).
(2) Create `SaasFullyBayesianMultiTaskGP(FullyBayesianMultiTaskGP)` that defaults `pyro_model=MultitaskSaasPyroModel()`.
(3) Widen `pyro_model` type to `PyroModel`, validate `is_multitask` at construction time: `if not pyro_model.is_multitask: raise ValueError(...)` — the one appropriate use of `is_multitask` (type guard at construction time). [Reviewer note: the diff below implements this guard as `if not isinstance(pyro_model, MultiTaskPyroMixin): raise ValueError("pyro_model must be a multi-task model.")` rather than reading `is_multitask`.]
(4) Replace hardcoded dummy samples in `load_state_dict` with `pyro_model.get_dummy_mcmc_samples()` — zero hardcoded keys.
(5) Base class raises `ValueError` if `pyro_model is None`. `SaasFullyBayesianMultiTaskGP` sets the default. [Reviewer note: in the hunks shown below, the base-class `__init__` still falls back to `MultitaskSaasPyroModel()` when `pyro_model is None`, and `SaasFullyBayesianMultiTaskGP` adds only a docstring.]

Files to modify:
- `fbcode/pytorch/botorch/botorch/models/fully_bayesian_multitask.py` — split class, widen types, delegate load_state_dict
- `fbcode/pytorch/botorch/test/models/test_fully_bayesian_multitask.py` — test base with explicit PyroModel, test is_multitask validation, test MaternPyroModel-based MT model

MANDATORY CODING STANDARDS (apply to every line of code in this diff):
1. SELF-CONTAINED DIFFS: Every diff must be entirely self-contained. Every line of production code introduced in this diff MUST be covered by tests that are introduced in the SAME diff. Never create code without its corresponding tests, and never create tests in a separate diff from the code they cover.
2. NO TRY/EXCEPT: Do NOT use try/except statements anywhere, under any circumstances. Handle errors through proper control flow, type checking, validation, and return values instead. If existing code uses try/except, that is okay - it does not need to be changed.
3. NO DUPLICATE CODE: Aggressively eliminate duplicate code. Extract shared logic into helper functions, base classes, or utilities. When you believe this item is complete, launch a subagent whose sole job is to audit the diff for duplicate code and refactor it away before finalizing.
4. SIMPLEST REUSABLE SOLUTION: Always choose the simplest, most reusable solution. Prefer composition over inheritance, small focused functions over large ones, and generic utilities over one-off implementations. However, do not excessively use thin wrapper functions, as these unnecessarily increase the number of lines of code.
5. CONCISE DIFF SUMMARY: The diff summary must explain WHY the change is implemented, not enumerate methods added. The summary must be at most three precise sentences describing the functionality and motivation. If it is longer than three sentences, you must shorten the summary.

Differential Revision: D93037430
1 parent 6360ee6 commit 8634545

4 files changed

Lines changed: 105 additions & 33 deletions

File tree

botorch/fit.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from botorch.models import SingleTaskGP
2222
from botorch.models.approximate_gp import ApproximateGPyTorchModel
2323
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
24-
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
24+
from botorch.models.fully_bayesian_multitask import FullyBayesianMultiTaskGP
2525
from botorch.models.map_saas import get_map_saas_model
2626
from botorch.models.model_list_gp_regression import ModelListGP
2727
from botorch.models.transforms.input import InputTransform
@@ -334,7 +334,7 @@ def _fit_fallback_approximate(
334334

335335

336336
def fit_fully_bayesian_model_nuts(
337-
model: AbstractFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
337+
model: AbstractFullyBayesianSingleTaskGP | FullyBayesianMultiTaskGP,
338338
max_tree_depth: int = 6,
339339
warmup_steps: int = 512,
340340
num_samples: int = 256,

botorch/models/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,10 @@
1515
PosteriorMeanModel,
1616
)
1717
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
18-
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
18+
from botorch.models.fully_bayesian_multitask import (
19+
FullyBayesianMultiTaskGP,
20+
SaasFullyBayesianMultiTaskGP,
21+
)
1922
from botorch.models.gp_regression import SingleTaskGP
2023
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
2124
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
@@ -37,6 +40,7 @@
3740
"AffineFidelityCostModel",
3841
"ApproximateGPyTorchModel",
3942
"EnsembleMapSaasSingleTaskGP",
43+
"FullyBayesianMultiTaskGP",
4044
"GenericDeterministicModel",
4145
"HigherOrderGP",
4246
"KroneckerMultiTaskGP",

botorch/models/fully_bayesian_multitask.py

Lines changed: 26 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
matern52_kernel,
1717
MCMC_DIM,
1818
MIN_INFERRED_NOISE_LEVEL,
19+
PyroModel,
1920
reshape_and_detach,
2021
SaasPyroModel,
2122
)
@@ -39,8 +40,8 @@
3940
from typing_extensions import Self
4041

4142
# Can replace with Self type once 3.11 is the minimum version
42-
TSaasFullyBayesianMultiTaskGP = TypeVar(
43-
"TSaasFullyBayesianMultiTaskGP", bound="SaasFullyBayesianMultiTaskGP"
43+
TFullyBayesianMultiTaskGP = TypeVar(
44+
"TFullyBayesianMultiTaskGP", bound="FullyBayesianMultiTaskGP"
4445
)
4546

4647

@@ -267,12 +268,12 @@ class MultitaskSaasPyroModel(LatentFeatureMultiTaskPyroMixin, SaasPyroModel):
267268
pass
268269

269270

270-
class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
271-
r"""A fully Bayesian multi-task GP model with the SAAS prior.
271+
class FullyBayesianMultiTaskGP(MultiTaskGP):
272+
r"""A fully Bayesian multi-task GP model.
273+
272274
This model assumes that the inputs have been normalized to [0, 1]^d and that the
273275
output has been stratified standardized to have zero mean and unit variance for
274-
each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data
275-
kernel by default.
276+
each task.
276277
277278
You are expected to use ``fit_fully_bayesian_model_nuts`` to fit this model as it
278279
isn't compatible with ``fit_gpytorch_mll``.
@@ -285,11 +286,12 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
285286
>>> ])
286287
>>> train_Y = torch.cat(f1(X1), f2(X2)).unsqueeze(-1)
287288
>>> train_Yvar = 0.01 * torch.ones_like(train_Y)
288-
>>> mtsaas_gp = SaasFullyBayesianMultiTaskGP(
289-
>>> train_X, train_Y, train_Yvar, task_feature=-1,
289+
>>> mt_gp = FullyBayesianMultiTaskGP(
290+
>>> train_X, train_Y, task_feature=-1,
291+
>>> pyro_model=MultitaskSaasPyroModel(),
290292
>>> )
291-
>>> fit_fully_bayesian_model_nuts(mtsaas_gp)
292-
>>> posterior = mtsaas_gp.posterior(test_X)
293+
>>> fit_fully_bayesian_model_nuts(mt_gp)
294+
>>> posterior = mt_gp.posterior(test_X)
293295
"""
294296

295297
_is_fully_bayesian = True
@@ -306,7 +308,7 @@ def __init__(
306308
all_tasks: list[int] | None = None,
307309
outcome_transform: OutcomeTransform | None = None,
308310
input_transform: InputTransform | None = None,
309-
pyro_model: MultitaskSaasPyroModel | None = None,
311+
pyro_model: PyroModel | None = None,
310312
validate_task_values: bool = True,
311313
) -> None:
312314
r"""Initialize the fully Bayesian multi-task GP model.
@@ -333,8 +335,7 @@ def __init__(
333335
instantiation of the model.
334336
input_transform: An input transform that is applied to the inputs ``X``
335337
in the model's forward pass.
336-
pyro_model: Optional ``PyroModel`` that has the same signature as
337-
``MultitaskSaasPyroModel``. Defaults to ``MultitaskSaasPyroModel``.
338+
pyro_model: A ``PyroModel`` that inherits from ``MultiTaskPyroMixin``.
338339
validate_task_values: If True, validate that the task values supplied in the
339340
input are expected tasks values. If false, unexpected task values
340341
will be mapped to the first output_task if supplied.
@@ -384,7 +385,8 @@ def __init__(
384385
self.likelihood = None
385386
if pyro_model is None:
386387
pyro_model = MultitaskSaasPyroModel()
387-
# apply task_mapper
388+
if not isinstance(pyro_model, MultiTaskPyroMixin):
389+
raise ValueError("pyro_model must be a multi-task model.")
388390
x_before, task_idcs, x_after = self._split_inputs(transformed_X)
389391
pyro_model.set_inputs(
390392
train_X=torch.cat([x_before, task_idcs, x_after], dim=-1),
@@ -394,15 +396,13 @@ def __init__(
394396
task_rank=self._rank,
395397
all_tasks=all_tasks,
396398
)
397-
self.pyro_model: MultitaskSaasPyroModel = pyro_model
399+
self.pyro_model: PyroModel = pyro_model
398400
if outcome_transform is not None:
399401
self.outcome_transform = outcome_transform
400402
if input_transform is not None:
401403
self.input_transform = input_transform
402404

403-
def train(
404-
self, mode: bool = True, reset: bool = True
405-
) -> TSaasFullyBayesianMultiTaskGP:
405+
def train(self, mode: bool = True, reset: bool = True) -> TFullyBayesianMultiTaskGP:
406406
r"""Puts the model in ``train`` mode.
407407
408408
Args:
@@ -436,7 +436,7 @@ def num_mcmc_samples(self) -> int:
436436
@property
437437
def batch_shape(self) -> torch.Size:
438438
r"""Batch shape of the model, equal to the number of MCMC samples.
439-
Note that ``SaasFullyBayesianMultiTaskGP`` does not support batching
439+
Note that ``FullyBayesianMultiTaskGP`` does not support batching
440440
over input data at this point.
441441
"""
442442
self._check_if_fitted()
@@ -513,22 +513,14 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
513513
r"""Custom logic for loading the state dict.
514514
515515
The standard approach of calling ``load_state_dict`` currently doesn't
516-
play well with the ``SaasFullyBayesianMultiTaskGP`` since the
516+
play well with the ``FullyBayesianMultiTaskGP`` since the
517517
``mean_module``, ``covar_module`` and ``likelihood`` aren't initialized
518518
until the model has been fitted. The reason for this is that we don't
519519
know the number of MCMC samples until NUTS is called. Given the state
520520
dict, we can initialize a new model with some dummy samples and then
521-
load the state dict into this model. This currently only works for a
522-
``MultitaskSaasPyroModel`` and supporting more Pyro models likely
523-
requires moving the model construction logic into the Pyro model itself.
524-
525-
TODO: If this were to inherit from ``SaasFullyBayesianSingleTaskGP``, we could
526-
simplify this method and eliminate some others.
521+
load the state dict into this model. The dummy samples are obtained
522+
from ``pyro_model.get_dummy_mcmc_samples()``.
527523
"""
528-
if not isinstance(self.pyro_model, MultitaskSaasPyroModel):
529-
raise NotImplementedError( # pragma: no cover
530-
"load_state_dict only works for MultitaskSaasPyroModel"
531-
)
532524
raw_mean = state_dict["mean_module.base_means.0.raw_constant"]
533525
num_mcmc_samples = len(raw_mean)
534526
tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
@@ -572,3 +564,7 @@ def condition_on_observations(
572564
X = X.repeat(*(Y.shape[:-2] + (1, 1)))
573565

574566
return super().condition_on_observations(X, Y, **kwargs)
567+
568+
569+
class SaasFullyBayesianMultiTaskGP(FullyBayesianMultiTaskGP):
570+
r"""A fully Bayesian multi-task GP model with the SAAS prior by default."""

test/models/test_fully_bayesian_multitask.py

Lines changed: 72 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
SaasPyroModel,
4040
)
4141
from botorch.models.fully_bayesian_multitask import (
42+
FullyBayesianMultiTaskGP,
4243
LatentFeatureMultiTaskPyroMixin,
4344
MultiTaskPyroMixin,
4445
MultitaskSaasPyroModel,
@@ -915,6 +916,77 @@ def test_construct_inputs(self):
915916
self.assertEqual(model._task_feature, d)
916917
self.assertEqual(model.pyro_model.task_feature, d)
917918

919+
def test_constructor_validation_and_defaults(self):
920+
"""Test FullyBayesianMultiTaskGP constructor validation and defaults."""
921+
tkwargs = {"device": self.device, "dtype": torch.double}
922+
train_X, train_Y, train_Yvar = self._get_base_data(**tkwargs)
923+
924+
with self.subTest("rejects_single_task_pyro_model"):
925+
with self.assertRaisesRegex(
926+
ValueError, "pyro_model must be a multi-task model"
927+
):
928+
FullyBayesianMultiTaskGP(
929+
train_X=train_X,
930+
train_Y=train_Y,
931+
train_Yvar=train_Yvar,
932+
task_feature=4,
933+
pyro_model=MaternPyroModel(),
934+
)
935+
936+
with self.subTest("accepts_explicit_multitask_pyro_model"):
937+
pyro_model = MultitaskSaasPyroModel()
938+
model = FullyBayesianMultiTaskGP(
939+
train_X=train_X,
940+
train_Y=train_Y,
941+
train_Yvar=train_Yvar,
942+
task_feature=4,
943+
pyro_model=pyro_model,
944+
)
945+
self.assertIs(model.pyro_model, pyro_model)
946+
self.assertIsInstance(model.pyro_model, MultiTaskPyroMixin)
947+
948+
def test_non_saas_mt_model_load_state_dict(self):
949+
"""Test round-trip load_state_dict with a non-SAAS multi-task PyroModel."""
950+
tkwargs = {"device": self.device, "dtype": torch.double}
951+
952+
class MultitaskMaternPyroModel(
953+
LatentFeatureMultiTaskPyroMixin, MaternPyroModel
954+
):
955+
pass
956+
957+
train_X, train_Y, train_Yvar = self._get_base_data(**tkwargs)
958+
959+
pyro_model = MultitaskMaternPyroModel()
960+
model = FullyBayesianMultiTaskGP(
961+
train_X=train_X,
962+
train_Y=train_Y,
963+
train_Yvar=train_Yvar,
964+
task_feature=4,
965+
pyro_model=pyro_model,
966+
)
967+
self.assertIsInstance(model.pyro_model, MultiTaskPyroMixin)
968+
969+
fit_fully_bayesian_model_nuts(
970+
model, warmup_steps=8, num_samples=5, thinning=2, disable_progbar=True
971+
)
972+
state_dict = model.state_dict()
973+
test_X = torch.rand(3, 4, **tkwargs)
974+
preds1 = model.posterior(test_X)
975+
976+
pyro_model2 = MultitaskMaternPyroModel()
977+
m_new = FullyBayesianMultiTaskGP(
978+
train_X=train_X,
979+
train_Y=train_Y,
980+
train_Yvar=train_Yvar,
981+
task_feature=4,
982+
pyro_model=pyro_model2,
983+
)
984+
m_new.load_state_dict(state_dict)
985+
986+
preds2 = m_new.posterior(test_X)
987+
self.assertTrue(torch.equal(preds1.mean, preds2.mean))
988+
self.assertTrue(torch.equal(preds1.variance, preds2.variance))
989+
918990

919991
class TestPyroModelMultitaskMixin(BotorchTestCase):
920992
"""Tests for the MultiTaskPyroMixin and LatentFeatureMultiTaskPyroMixin classes."""

0 commit comments

Comments
 (0)