1616 matern52_kernel ,
1717 MCMC_DIM ,
1818 MIN_INFERRED_NOISE_LEVEL ,
19+ PyroModel ,
1920 reshape_and_detach ,
2021 SaasPyroModel ,
2122)
3940from typing_extensions import Self
4041
4142# Can replace with Self type once 3.11 is the minimum version
42- TSaasFullyBayesianMultiTaskGP = TypeVar (
43- "TSaasFullyBayesianMultiTaskGP " , bound = "SaasFullyBayesianMultiTaskGP "
43+ TFullyBayesianMultiTaskGP = TypeVar (
44+ "TFullyBayesianMultiTaskGP " , bound = "FullyBayesianMultiTaskGP "
4445)
4546
4647
@@ -260,19 +261,20 @@ def get_dummy_mcmc_samples(
260261
261262class MultitaskSaasPyroModel (LatentFeatureMultiTaskPyroMixin , SaasPyroModel ):
262263 r"""
263- Multi-task SAAS model. Backward-compatible subclass that composes
264- ``LatentFeatureMultiTaskPyroMixin`` with ``SaasPyroModel``.
264+ Multi-task SAAS model using latent task features. Backward-compatible
265+ subclass that composes ``LatentFeatureMultiTaskPyroMixin`` with
266+ ``SaasPyroModel``.
265267 """
266268
267269 pass
268270
269271
270- class SaasFullyBayesianMultiTaskGP (MultiTaskGP ):
271- r"""A fully Bayesian multi-task GP model with the SAAS prior.
272+ class FullyBayesianMultiTaskGP (MultiTaskGP ):
273+ r"""A fully Bayesian multi-task GP model.
274+
272275 This model assumes that the inputs have been normalized to [0, 1]^d and that the
273276 output has been stratified standardized to have zero mean and unit variance for
274- each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data
275- kernel by default.
277+ each task.
276278
277279 You are expected to use ``fit_fully_bayesian_model_nuts`` to fit this model as it
278280 isn't compatible with ``fit_gpytorch_mll``.
@@ -285,11 +287,12 @@ class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
285287 >>> ])
286288 >>> train_Y = torch.cat(f1(X1), f2(X2)).unsqueeze(-1)
287289 >>> train_Yvar = 0.01 * torch.ones_like(train_Y)
288- >>> mtsaas_gp = SaasFullyBayesianMultiTaskGP(
289- >>> train_X, train_Y, train_Yvar, task_feature=-1,
290+ >>> mt_gp = FullyBayesianMultiTaskGP(
291+ >>> train_X, train_Y, task_feature=-1,
292+ >>> pyro_model=MultitaskSaasPyroModel(),
290293 >>> )
291- >>> fit_fully_bayesian_model_nuts(mtsaas_gp )
292- >>> posterior = mtsaas_gp .posterior(test_X)
294+ >>> fit_fully_bayesian_model_nuts(mt_gp )
295+ >>> posterior = mt_gp .posterior(test_X)
293296 """
294297
295298 _is_fully_bayesian = True
@@ -306,7 +309,7 @@ def __init__(
306309 all_tasks : list [int ] | None = None ,
307310 outcome_transform : OutcomeTransform | None = None ,
308311 input_transform : InputTransform | None = None ,
309- pyro_model : MultitaskSaasPyroModel | None = None ,
312+ pyro_model : PyroModel | None = None ,
310313 validate_task_values : bool = True ,
311314 ) -> None :
312315 r"""Initialize the fully Bayesian multi-task GP model.
@@ -333,8 +336,7 @@ def __init__(
333336 instantiation of the model.
334337 input_transform: An input transform that is applied to the inputs ``X``
335338 in the model's forward pass.
336- pyro_model: Optional ``PyroModel`` that has the same signature as
337- ``MultitaskSaasPyroModel``. Defaults to ``MultitaskSaasPyroModel``.
339+ pyro_model: A ``PyroModel`` that inherits from ``MultiTaskPyroMixin``.
338340 validate_task_values: If True, validate that the task values supplied in the
339341 input are expected tasks values. If false, unexpected task values
340342 will be mapped to the first output_task if supplied.
@@ -384,7 +386,8 @@ def __init__(
384386 self .likelihood = None
385387 if pyro_model is None :
386388 pyro_model = MultitaskSaasPyroModel ()
387- # apply task_mapper
389+ if not isinstance (pyro_model , MultiTaskPyroMixin ):
390+ raise ValueError ("pyro_model must be a multi-task model." )
388391 x_before , task_idcs , x_after = self ._split_inputs (transformed_X )
389392 pyro_model .set_inputs (
390393 train_X = torch .cat ([x_before , task_idcs , x_after ], dim = - 1 ),
@@ -394,15 +397,13 @@ def __init__(
394397 task_rank = self ._rank ,
395398 all_tasks = all_tasks ,
396399 )
397- self .pyro_model : MultitaskSaasPyroModel = pyro_model
400+ self .pyro_model : PyroModel = pyro_model
398401 if outcome_transform is not None :
399402 self .outcome_transform = outcome_transform
400403 if input_transform is not None :
401404 self .input_transform = input_transform
402405
403- def train (
404- self , mode : bool = True , reset : bool = True
405- ) -> TSaasFullyBayesianMultiTaskGP :
406+ def train (self , mode : bool = True , reset : bool = True ) -> TFullyBayesianMultiTaskGP :
406407 r"""Puts the model in ``train`` mode.
407408
408409 Args:
@@ -436,7 +437,7 @@ def num_mcmc_samples(self) -> int:
436437 @property
437438 def batch_shape (self ) -> torch .Size :
438439 r"""Batch shape of the model, equal to the number of MCMC samples.
439- Note that ``SaasFullyBayesianMultiTaskGP `` does not support batching
440+ Note that ``FullyBayesianMultiTaskGP `` does not support batching
440441 over input data at this point.
441442 """
442443 self ._check_if_fitted ()
@@ -513,22 +514,14 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
513514 r"""Custom logic for loading the state dict.
514515
515516 The standard approach of calling ``load_state_dict`` currently doesn't
516- play well with the ``SaasFullyBayesianMultiTaskGP `` since the
517+ play well with the ``FullyBayesianMultiTaskGP `` since the
517518 ``mean_module``, ``covar_module`` and ``likelihood`` aren't initialized
518519 until the model has been fitted. The reason for this is that we don't
519520 know the number of MCMC samples until NUTS is called. Given the state
520521 dict, we can initialize a new model with some dummy samples and then
521- load the state dict into this model. This currently only works for a
522- ``MultitaskSaasPyroModel`` and supporting more Pyro models likely
523- requires moving the model construction logic into the Pyro model itself.
524-
525- TODO: If this were to inherit from ``SaasFullyBayesianSingleTaskGP``, we could
526- simplify this method and eliminate some others.
522+ load the state dict into this model. The dummy samples are obtained
523+ from ``pyro_model.get_dummy_mcmc_samples()``.
527524 """
528- if not isinstance (self .pyro_model , MultitaskSaasPyroModel ):
529- raise NotImplementedError ( # pragma: no cover
530- "load_state_dict only works for MultitaskSaasPyroModel"
531- )
532525 raw_mean = state_dict ["mean_module.base_means.0.raw_constant" ]
533526 num_mcmc_samples = len (raw_mean )
534527 tkwargs = {"device" : raw_mean .device , "dtype" : raw_mean .dtype }
@@ -572,3 +565,7 @@ def condition_on_observations(
572565 X = X .repeat (* (Y .shape [:- 2 ] + (1 , 1 )))
573566
574567 return super ().condition_on_observations (X , Y , ** kwargs )
568+
569+
570+ class SaasFullyBayesianMultiTaskGP (FullyBayesianMultiTaskGP ):
571+ r"""A fully Bayesian multi-task GP model with the SAAS prior by default."""
0 commit comments