1616 matern52_kernel ,
1717 MCMC_DIM ,
1818 MIN_INFERRED_NOISE_LEVEL ,
19+ PyroModel ,
1920 reshape_and_detach ,
2021 SaasPyroModel ,
2122)
3940from typing_extensions import Self
4041
4142# Can replace with Self type once 3.11 is the minimum version
42- TSaasFullyBayesianMultiTaskGP = TypeVar (
43- "TSaasFullyBayesianMultiTaskGP " , bound = "SaasFullyBayesianMultiTaskGP "
43+ TFullyBayesianMultiTaskGP = TypeVar (
44+ "TFullyBayesianMultiTaskGP " , bound = "FullyBayesianMultiTaskGP "
4445)
4546
4647
@@ -261,19 +262,20 @@ def get_dummy_mcmc_samples(
261262
262263class MultitaskSaasPyroModel (LatentFeatureMultiTaskPyroMixin , SaasPyroModel ):
263264 r"""
264- Multi-task SAAS model. Backward-compatible subclass that composes
265- ``LatentFeatureMultiTaskPyroMixin`` with ``SaasPyroModel``.
265+ Multi-task SAAS model using latent task features. Backward-compatible
266+ subclass that composes ``LatentFeatureMultiTaskPyroMixin`` with
267+ ``SaasPyroModel``.
266268 """
267269
268270 pass
269271
270272
271- class SaasFullyBayesianMultiTaskGP (MultiTaskGP ):
272- r"""A fully Bayesian multi-task GP model with the SAAS prior.
273+ class FullyBayesianMultiTaskGP (MultiTaskGP ):
274+ r"""A fully Bayesian multi-task GP model.
275+
273276 This model assumes that the inputs have been normalized to [0, 1]^d and that the
274277 output has been stratified standardized to have zero mean and unit variance for
275- each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data
276- kernel by default.
278+ each task.
277279
278280 You are expected to use ``fit_fully_bayesian_model_nuts`` to fit this model as it
279281 isn't compatible with ``fit_gpytorch_mll``.
@@ -286,11 +288,12 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
286288 >>> ])
287289 >>> train_Y = torch.cat(f1(X1), f2(X2)).unsqueeze(-1)
288290 >>> train_Yvar = 0.01 * torch.ones_like(train_Y)
289- >>> mtsaas_gp = SaasFullyBayesianMultiTaskGP(
290- >>> train_X, train_Y, train_Yvar, task_feature=-1,
291+ >>> mt_gp = FullyBayesianMultiTaskGP(
292+ >>> train_X, train_Y, task_feature=-1,
293+ >>> pyro_model=MultitaskSaasPyroModel(),
291294 >>> )
292- >>> fit_fully_bayesian_model_nuts(mtsaas_gp )
293- >>> posterior = mtsaas_gp .posterior(test_X)
295+ >>> fit_fully_bayesian_model_nuts(mt_gp )
296+ >>> posterior = mt_gp .posterior(test_X)
294297 """
295298
296299 _is_fully_bayesian = True
@@ -307,7 +310,7 @@ def __init__(
307310 all_tasks : list [int ] | None = None ,
308311 outcome_transform : OutcomeTransform | None = None ,
309312 input_transform : InputTransform | None = None ,
310- pyro_model : MultitaskSaasPyroModel | None = None ,
313+ pyro_model : PyroModel | None = None ,
311314 validate_task_values : bool = True ,
312315 ) -> None :
313316 r"""Initialize the fully Bayesian multi-task GP model.
@@ -334,8 +337,7 @@ def __init__(
334337 instantiation of the model.
335338 input_transform: An input transform that is applied to the inputs ``X``
336339 in the model's forward pass.
337- pyro_model: Optional ``PyroModel`` that has the same signature as
338- ``MultitaskSaasPyroModel``. Defaults to ``MultitaskSaasPyroModel``.
340+ pyro_model: A ``PyroModel`` that inherits from ``MultiTaskPyroMixin``.
339341 validate_task_values: If True, validate that the task values supplied in the
340342 input are expected tasks values. If false, unexpected task values
341343 will be mapped to the first output_task if supplied.
@@ -385,7 +387,8 @@ def __init__(
385387 self .likelihood = None
386388 if pyro_model is None :
387389 pyro_model = MultitaskSaasPyroModel ()
388- # apply task_mapper
390+ if not isinstance (pyro_model , MultiTaskPyroMixin ):
391+ raise ValueError ("pyro_model must be a multi-task model." )
389392 x_before , task_idcs , x_after = self ._split_inputs (transformed_X )
390393 pyro_model .set_inputs (
391394 train_X = torch .cat ([x_before , task_idcs , x_after ], dim = - 1 ),
@@ -395,15 +398,13 @@ def __init__(
395398 task_rank = self ._rank ,
396399 all_tasks = all_tasks ,
397400 )
398- self .pyro_model : MultitaskSaasPyroModel = pyro_model
401+ self .pyro_model : PyroModel = pyro_model
399402 if outcome_transform is not None :
400403 self .outcome_transform = outcome_transform
401404 if input_transform is not None :
402405 self .input_transform = input_transform
403406
404- def train (
405- self , mode : bool = True , reset : bool = True
406- ) -> TSaasFullyBayesianMultiTaskGP :
407+ def train (self , mode : bool = True , reset : bool = True ) -> TFullyBayesianMultiTaskGP :
407408 r"""Puts the model in ``train`` mode.
408409
409410 Args:
@@ -437,7 +438,7 @@ def num_mcmc_samples(self) -> int:
437438 @property
438439 def batch_shape (self ) -> torch .Size :
439440 r"""Batch shape of the model, equal to the number of MCMC samples.
440- Note that ``SaasFullyBayesianMultiTaskGP `` does not support batching
441+ Note that ``FullyBayesianMultiTaskGP `` does not support batching
441442 over input data at this point.
442443 """
443444 self ._check_if_fitted ()
@@ -514,22 +515,14 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
514515 r"""Custom logic for loading the state dict.
515516
516517 The standard approach of calling ``load_state_dict`` currently doesn't
517- play well with the ``SaasFullyBayesianMultiTaskGP `` since the
518+ play well with the ``FullyBayesianMultiTaskGP `` since the
518519 ``mean_module``, ``covar_module`` and ``likelihood`` aren't initialized
519520 until the model has been fitted. The reason for this is that we don't
520521 know the number of MCMC samples until NUTS is called. Given the state
521522 dict, we can initialize a new model with some dummy samples and then
522- load the state dict into this model. This currently only works for a
523- ``MultitaskSaasPyroModel`` and supporting more Pyro models likely
524- requires moving the model construction logic into the Pyro model itself.
525-
526- TODO: If this were to inherit from ``SaasFullyBayesianSingleTaskGP``, we could
527- simplify this method and eliminate some others.
523+ load the state dict into this model. The dummy samples are obtained
524+ from ``pyro_model.get_dummy_mcmc_samples()``.
528525 """
529- if not isinstance (self .pyro_model , MultitaskSaasPyroModel ):
530- raise NotImplementedError ( # pragma: no cover
531- "load_state_dict only works for MultitaskSaasPyroModel"
532- )
533526 raw_mean = state_dict ["mean_module.base_means.0.raw_constant" ]
534527 num_mcmc_samples = len (raw_mean )
535528 tkwargs = {"device" : raw_mean .device , "dtype" : raw_mean .dtype }
@@ -573,3 +566,7 @@ def condition_on_observations(
573566 X = X .repeat (* (Y .shape [:- 2 ] + (1 , 1 )))
574567
575568 return super ().condition_on_observations (X , Y , ** kwargs )
569+
570+
571+ class SaasFullyBayesianMultiTaskGP (FullyBayesianMultiTaskGP ):
572+ r"""A fully Bayesian multi-task GP model with the SAAS prior by default."""
0 commit comments