@@ -165,14 +165,23 @@ def _maybe_input_warp(self, X: Tensor, **tkwargs: Any) -> Tensor:
165165 return self .train_X
166166
167167 def set_inputs (
168- self , train_X : Tensor , train_Y : Tensor , train_Yvar : Tensor | None = None
168+ self ,
169+ train_X : Tensor ,
170+ train_Y : Tensor ,
171+ train_Yvar : Tensor | None = None ,
172+ task_feature : int | None = None ,
173+ task_rank : int | None = None ,
169174 ) -> None :
170175 """Set the training data.
171176
172177 Args:
173178 train_X: Training inputs (n x d)
174179 train_Y: Training targets (n x 1)
175180 train_Yvar: Observed noise variance (n x 1). Inferred if None.
181+ task_feature: The index of the task feature column. Used by
182+ multi-task mixins.
183+ task_rank: The number of learned task embeddings. Used by
184+ multi-task mixins.
176185 """
177186 self .train_X = train_X
178187 self .train_Y = train_Y
@@ -192,6 +201,19 @@ def postprocess_mcmc_samples(
192201 """Post-process the final MCMC samples."""
193202 pass # pragma: no cover
194203
204+ @abstractmethod
205+ def get_dummy_mcmc_samples (
206+ self ,
207+ num_mcmc_samples : int ,
208+ ** tkwargs : Any ,
209+ ) -> dict [str , Tensor ]:
210+ """Return dummy MCMC samples for state dict loading.
211+
212+ Each subclass returns a dict of ones with the keys and shapes that
213+ ``load_mcmc_samples`` expects.
214+ """
215+ pass # pragma: no cover
216+
195217 @abstractmethod
196218 def load_mcmc_samples (
197219 self , mcmc_samples : dict [str , Tensor ]
@@ -245,6 +267,60 @@ def sample_concentrations(self, **tkwargs: Any) -> tuple[Tensor, Tensor]:
245267
246268 return c0 , c1
247269
270+ def _common_dummy_samples (
271+ self ,
272+ mcmc_samples : dict [str , Tensor ],
273+ num_mcmc_samples : int ,
274+ ** tkwargs : Any ,
275+ ) -> dict [str , Tensor ]:
276+ """Add noise and warping entries to ``mcmc_samples`` in-place."""
277+ dim = self .ard_num_dims
278+ if self .train_Yvar is None :
279+ mcmc_samples ["noise" ] = torch .ones (num_mcmc_samples , ** tkwargs )
280+ if self .use_input_warping :
281+ mcmc_samples ["c0" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
282+ mcmc_samples ["c1" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
283+ return mcmc_samples
284+
285+ def _prepare_features (self , X : Tensor , ** tkwargs : Any ) -> Tensor :
286+ """Select feature columns for kernel computation.
287+
288+ Overridden by multi-task mixins to strip the task column.
289+ """
290+ return X
291+
292+ def _maybe_multitask_transform (
293+ self , K_noiseless : Tensor , mean : Tensor , ** tkwargs : Any
294+ ) -> tuple [Tensor , Tensor ]:
295+ r"""Apply multi-task covariance transform to the kernel and mean.
296+
297+ No-op for single-task models. Overridden by multi-task mixins.
298+ """
299+ return K_noiseless , mean
300+
301+ def _preprocess_mean_for_loading (
302+ self , mcmc_samples : dict [str , Tensor ]
303+ ) -> dict [str , Tensor ]:
304+ """Return mcmc_samples with a scalar mean for building covariance modules.
305+
306+ No-op for single-task models. Overridden by multi-task mixins.
307+ """
308+ return mcmc_samples
309+
310+ def _finalize_multitask_modules (
311+ self ,
312+ mcmc_samples : dict [str , Tensor ],
313+ mean_module : Mean ,
314+ covar_module : Kernel ,
315+ batch_shape : torch .Size ,
316+ ** tkwargs : Any ,
317+ ) -> tuple [Mean , Kernel ]:
318+ """Optionally replace mean/covar with multi-task versions.
319+
320+ No-op for single-task models. Overridden by multi-task mixins.
321+ """
322+ return mean_module , covar_module
323+
248324 def sample_observations (
249325 self ,
250326 mean : Tensor ,
@@ -322,7 +398,11 @@ def sample(self) -> None:
322398 noise = self .sample_noise (** tkwargs )
323399 lengthscale = self .sample_lengthscale (dim = self .ard_num_dims , ** tkwargs )
324400 X_tf = self ._maybe_input_warp (self .train_X , ** tkwargs )
401+ X_tf = self ._prepare_features (X_tf , ** tkwargs )
325402 K_noiseless = outputscale * matern52_kernel (X = X_tf , lengthscale = lengthscale )
403+ K_noiseless , mean = self ._maybe_multitask_transform (
404+ K_noiseless , mean , ** tkwargs
405+ )
326406 self .sample_observations (
327407 mean = mean , K_noiseless = K_noiseless , noise = noise , ** tkwargs
328408 )
@@ -375,6 +455,18 @@ def postprocess_mcmc_samples(
375455 """
376456 return mcmc_samples
377457
458+ def get_dummy_mcmc_samples (
459+ self ,
460+ num_mcmc_samples : int ,
461+ ** tkwargs : Any ,
462+ ) -> dict [str , Tensor ]:
463+ """Return dummy MCMC samples for state dict loading."""
464+ mcmc_samples = {
465+ "mean" : torch .ones (num_mcmc_samples , ** tkwargs ),
466+ "lengthscale" : torch .ones (num_mcmc_samples , self .ard_num_dims , ** tkwargs ),
467+ }
468+ return self ._common_dummy_samples (mcmc_samples , num_mcmc_samples , ** tkwargs )
469+
378470 def _get_covar_module (
379471 self ,
380472 use_scale_kernel : bool ,
@@ -410,8 +502,10 @@ def load_mcmc_samples(
410502 num_mcmc_samples = len (mcmc_samples ["mean" ])
411503 batch_shape = torch .Size ([num_mcmc_samples ])
412504
505+ mean_mcmc = self ._preprocess_mean_for_loading (mcmc_samples )
506+
413507 mean_module = ConstantMean (batch_shape = batch_shape ).to (** tkwargs )
414- outputscale = mcmc_samples .get ("outputscale" )
508+ outputscale = mean_mcmc .get ("outputscale" )
415509 covar_module = self ._get_covar_module (
416510 use_scale_kernel = outputscale is not None , batch_shape = batch_shape , ** tkwargs
417511 )
@@ -446,7 +540,7 @@ def load_mcmc_samples(
446540 )
447541 mean_module .constant .data = reshape_and_detach (
448542 target = mean_module .constant .data ,
449- new_value = mcmc_samples ["mean" ],
543+ new_value = mean_mcmc ["mean" ],
450544 )
451545 if self .use_input_warping :
452546 indices = (
@@ -470,6 +564,14 @@ def load_mcmc_samples(
470564 )
471565 else :
472566 warping_function = None
567+
568+ mean_module , covar_module = self ._finalize_multitask_modules (
569+ mcmc_samples = mcmc_samples ,
570+ mean_module = mean_module ,
571+ covar_module = covar_module ,
572+ batch_shape = batch_shape ,
573+ ** tkwargs ,
574+ )
473575 return mean_module , covar_module , likelihood , warping_function
474576
475577
@@ -529,6 +631,18 @@ def postprocess_mcmc_samples(
529631 del mcmc_samples ["kernel_tausq" ], mcmc_samples ["_kernel_inv_length_sq" ]
530632 return mcmc_samples
531633
634+ def get_dummy_mcmc_samples (
635+ self ,
636+ num_mcmc_samples : int ,
637+ ** tkwargs : Any ,
638+ ) -> dict [str , Tensor ]:
639+ """Return dummy MCMC samples for state dict loading."""
640+ mcmc_samples = super ().get_dummy_mcmc_samples (
641+ num_mcmc_samples = num_mcmc_samples , ** tkwargs
642+ )
643+ mcmc_samples ["outputscale" ] = torch .ones (num_mcmc_samples , ** tkwargs )
644+ return mcmc_samples
645+
532646
533647class LinearPyroModel (PyroModel ):
534648 r"""Implementation of a Bayesian Linear pyro model.
@@ -546,9 +660,13 @@ def sample(self) -> None:
546660 mean = self .sample_mean (** tkwargs )
547661 weight_variance = self .sample_weight_variance (** tkwargs )
548662 X_tf = self ._maybe_input_warp (X = self .train_X , ** tkwargs )
663+ X_tf = self ._prepare_features (X_tf , ** tkwargs )
549664 X_tf = X_tf - 0.5 # center transformed data at 0 (for linear model)
550665 K_noiseless = linear_kernel (X = X_tf , weight_variance = weight_variance )
551666 noise = self .sample_noise (** tkwargs )
667+ K_noiseless , mean = self ._maybe_multitask_transform (
668+ K_noiseless , mean , ** tkwargs
669+ )
552670 self .sample_observations (
553671 mean = mean , K_noiseless = K_noiseless , noise = noise , ** tkwargs
554672 )
@@ -589,6 +707,20 @@ def postprocess_mcmc_samples(
589707 del mcmc_samples ["tau_sq" ], mcmc_samples ["_weight_variance_sq" ]
590708 return mcmc_samples
591709
710+ def get_dummy_mcmc_samples (
711+ self ,
712+ num_mcmc_samples : int ,
713+ ** tkwargs : Any ,
714+ ) -> dict [str , Tensor ]:
715+ """Return dummy MCMC samples for state dict loading."""
716+ mcmc_samples = {
717+ "mean" : torch .ones (num_mcmc_samples , ** tkwargs ),
718+ "weight_variance" : torch .ones (
719+ num_mcmc_samples , self .ard_num_dims , ** tkwargs
720+ ),
721+ }
722+ return self ._common_dummy_samples (mcmc_samples , num_mcmc_samples , ** tkwargs )
723+
592724 def load_mcmc_samples (
593725 self , mcmc_samples : dict [str , Tensor ]
594726 ) -> tuple [Mean , Kernel , Likelihood , InputTransform ]:
@@ -975,27 +1107,6 @@ def median_lengthscale(self) -> Tensor:
9751107 lengthscale = base_kernel .lengthscale .clone ()
9761108 return lengthscale .median (0 ).values .squeeze (0 )
9771109
978- def _get_dummy_mcmc_samples (
979- self ,
980- num_mcmc_samples : int ,
981- dim : int ,
982- dtype : torch .dtype ,
983- device : torch .device ,
984- ) -> dict [str , Tensor ]:
985- # Load some dummy samples
986- tkwargs = {"dtype" : dtype , "device" : device }
987- mcmc_samples = {
988- "mean" : torch .ones (num_mcmc_samples , ** tkwargs ),
989- "lengthscale" : torch .ones (num_mcmc_samples , dim , ** tkwargs ),
990- }
991- if self .pyro_model .train_Yvar is None :
992- mcmc_samples ["noise" ] = torch .ones (num_mcmc_samples , ** tkwargs )
993-
994- if self .pyro_model .use_input_warping :
995- mcmc_samples ["c0" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
996- mcmc_samples ["c1" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
997- return mcmc_samples
998-
9991110 def load_state_dict (
10001111 self , state_dict : Mapping [str , Any ], strict : bool = True
10011112 ) -> None :
@@ -1012,11 +1123,10 @@ def load_state_dict(
10121123 the model construction logic into the Pyro model itself.
10131124 """
10141125 raw_mean = state_dict ["mean_module.raw_constant" ]
1015- mcmc_samples = self ._get_dummy_mcmc_samples (
1016- num_mcmc_samples = len (raw_mean ),
1017- dim = self .pyro_model .train_X .shape [- 1 ],
1018- dtype = raw_mean .dtype ,
1019- device = raw_mean .device ,
1126+ num_mcmc_samples = len (raw_mean )
1127+ tkwargs = {"dtype" : raw_mean .dtype , "device" : raw_mean .device }
1128+ mcmc_samples = self .pyro_model .get_dummy_mcmc_samples (
1129+ num_mcmc_samples = num_mcmc_samples , ** tkwargs
10201130 )
10211131 self .load_mcmc_samples (mcmc_samples = mcmc_samples )
10221132 # Load the actual samples from the state dict
@@ -1043,22 +1153,6 @@ class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
10431153
10441154 _pyro_model_class : type [PyroModel ] = SaasPyroModel
10451155
1046- def _get_dummy_mcmc_samples (
1047- self ,
1048- num_mcmc_samples : int ,
1049- dim : int ,
1050- dtype : torch .dtype ,
1051- device : torch .device ,
1052- ) -> dict [str , Tensor ]:
1053- mcmc_samples = super ()._get_dummy_mcmc_samples (
1054- num_mcmc_samples = num_mcmc_samples , dim = dim , dtype = dtype , device = device
1055- )
1056- # add outputscale
1057- mcmc_samples ["outputscale" ] = torch .ones (
1058- num_mcmc_samples , dtype = dtype , device = device
1059- )
1060- return mcmc_samples
1061-
10621156
10631157class FullyBayesianLinearSingleTaskGP (AbstractFullyBayesianSingleTaskGP ):
10641158 r"""A fully Bayesian single-task GP model with a linear kernel.
@@ -1102,19 +1196,10 @@ def load_state_dict(
11021196 """
11031197 weight_variance = state_dict ["covar_module.raw_variance" ]
11041198 num_mcmc_samples = len (weight_variance )
1105- dim = self .pyro_model .train_X .shape [- 1 ]
11061199 tkwargs = {"device" : weight_variance .device , "dtype" : weight_variance .dtype }
1107- # Load some dummy samples
1108- # deal with c0 c1
1109- mcmc_samples = {
1110- "mean" : torch .ones (num_mcmc_samples , ** tkwargs ),
1111- "weight_variance" : torch .ones (num_mcmc_samples , dim , ** tkwargs ),
1112- }
1113- if self .pyro_model .use_input_warping :
1114- mcmc_samples ["c0" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
1115- mcmc_samples ["c1" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
1116- if self .pyro_model .train_Yvar is None :
1117- mcmc_samples ["noise" ] = torch .ones (num_mcmc_samples , ** tkwargs )
1200+ mcmc_samples = self .pyro_model .get_dummy_mcmc_samples (
1201+ num_mcmc_samples = num_mcmc_samples , ** tkwargs
1202+ )
11181203 self .load_mcmc_samples (mcmc_samples = mcmc_samples )
11191204 # Load the actual samples from the state dict
11201205 super ().load_state_dict (state_dict = state_dict , strict = strict )
0 commit comments