@@ -165,14 +165,23 @@ def _maybe_input_warp(self, X: Tensor, **tkwargs: Any) -> Tensor:
165165 return self .train_X
166166
167167 def set_inputs (
168- self , train_X : Tensor , train_Y : Tensor , train_Yvar : Tensor | None = None
168+ self ,
169+ train_X : Tensor ,
170+ train_Y : Tensor ,
171+ train_Yvar : Tensor | None = None ,
172+ task_feature : int | None = None ,
173+ task_rank : int | None = None ,
169174 ) -> None :
170175 """Set the training data.
171176
172177 Args:
173178 train_X: Training inputs (n x d)
174179 train_Y: Training targets (n x 1)
175180 train_Yvar: Observed noise variance (n x 1). Inferred if None.
181+ task_feature: The index of the task feature column. Used by
182+ multi-task mixins.
183+ task_rank: The number of learned task embeddings. Used by
184+ multi-task mixins.
176185 """
177186 self .train_X = train_X
178187 self .train_Y = train_Y
@@ -192,6 +201,19 @@ def postprocess_mcmc_samples(
192201 """Post-process the final MCMC samples."""
193202 pass # pragma: no cover
194203
@abstractmethod
def get_dummy_mcmc_samples(
    self,
    num_mcmc_samples: int,
    **tkwargs: Any,
) -> dict[str, Tensor]:
    """Build placeholder MCMC samples used while loading a state dict.

    Concrete subclasses return a dict of all-ones tensors whose keys and
    shapes match what ``load_mcmc_samples`` expects.

    Args:
        num_mcmc_samples: Number of MCMC samples to create placeholders for.
        **tkwargs: Tensor keyword arguments (e.g. ``dtype`` and ``device``).

    Returns:
        A dict mapping sample names to placeholder tensors.
    """
    pass  # pragma: no cover
216+
195217 @abstractmethod
196218 def load_mcmc_samples (
197219 self , mcmc_samples : dict [str , Tensor ]
@@ -245,6 +267,68 @@ def sample_concentrations(self, **tkwargs: Any) -> tuple[Tensor, Tensor]:
245267
246268 return c0 , c1
247269
270+ def _common_dummy_samples (
271+ self ,
272+ mcmc_samples : dict [str , Tensor ],
273+ num_mcmc_samples : int ,
274+ ** tkwargs : Any ,
275+ ) -> dict [str , Tensor ]:
276+ """Add noise and warping entries to ``mcmc_samples`` in-place."""
277+ dim = self .ard_num_dims
278+ if self .train_Yvar is None :
279+ mcmc_samples ["noise" ] = torch .ones (num_mcmc_samples , ** tkwargs )
280+ if self .use_input_warping :
281+ mcmc_samples ["c0" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
282+ mcmc_samples ["c1" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
283+ return mcmc_samples
284+
285+ def _prepare_features (self , X : Tensor , ** tkwargs : Any ) -> Tensor :
286+ """Select feature columns for kernel computation.
287+
288+ Overridden by multi-task mixins to strip the task column.
289+ """
290+ return X
291+
292+ def _maybe_multitask_transform (
293+ self , K_noiseless : Tensor , mean : Tensor , ** tkwargs : Any
294+ ) -> tuple [Tensor , Tensor ]:
295+ r"""Apply multi-task covariance transform to the kernel and mean.
296+
297+ No-op for single-task models. Overridden by multi-task mixins.
298+ """
299+ return K_noiseless , mean
300+
def _build_mean_module(
    self,
    mcmc_samples: dict[str, Tensor],
    batch_shape: torch.Size,
    **tkwargs: Any,
) -> Mean:
    """Create the mean module and load the sampled mean into it.

    Single-task models get a ``ConstantMean``; multi-task mixins override
    this method to return a ``MultitaskMean`` instead.

    Args:
        mcmc_samples: MCMC samples; only the ``"mean"`` entry is read here.
        batch_shape: Batch shape passed to ``ConstantMean``.
        **tkwargs: Keyword arguments forwarded to ``.to(...)`` on the module.

    Returns:
        The populated mean module.
    """
    constant_mean = ConstantMean(batch_shape=batch_shape)
    constant_mean = constant_mean.to(**tkwargs)
    # The existing constant is passed as ``target`` to the helper together
    # with the sampled mean; the helper's result replaces the constant data.
    sampled_constant = reshape_and_detach(
        target=constant_mean.constant.data,
        new_value=mcmc_samples["mean"],
    )
    constant_mean.constant.data = sampled_constant
    return constant_mean
318+
def _build_multitask_covariance(
    self,
    mcmc_samples: dict[str, Tensor],
    covar_module: Kernel,
    batch_shape: torch.Size,
    **tkwargs: Any,
) -> Kernel:
    """Hook for wrapping the covariance module with a task covariance.

    Single-task models leave the kernel untouched, so ``covar_module`` is
    returned as-is and the remaining arguments are unused. Multi-task mixins
    override this method.
    """
    wrapped = covar_module
    return wrapped
331+
248332 def sample_observations (
249333 self ,
250334 mean : Tensor ,
@@ -322,7 +406,11 @@ def sample(self) -> None:
322406 noise = self .sample_noise (** tkwargs )
323407 lengthscale = self .sample_lengthscale (dim = self .ard_num_dims , ** tkwargs )
324408 X_tf = self ._maybe_input_warp (self .train_X , ** tkwargs )
409+ X_tf = self ._prepare_features (X_tf , ** tkwargs )
325410 K_noiseless = outputscale * matern52_kernel (X = X_tf , lengthscale = lengthscale )
411+ K_noiseless , mean = self ._maybe_multitask_transform (
412+ K_noiseless , mean , ** tkwargs
413+ )
326414 self .sample_observations (
327415 mean = mean , K_noiseless = K_noiseless , noise = noise , ** tkwargs
328416 )
@@ -375,6 +463,18 @@ def postprocess_mcmc_samples(
375463 """
376464 return mcmc_samples
377465
def get_dummy_mcmc_samples(
    self,
    num_mcmc_samples: int,
    **tkwargs: Any,
) -> dict[str, Tensor]:
    """Build placeholder MCMC samples used while loading a state dict.

    Creates all-ones ``"mean"`` (``num_mcmc_samples``) and ``"lengthscale"``
    (``num_mcmc_samples x ard_num_dims``) entries, then delegates to
    ``_common_dummy_samples`` and returns its result.
    """
    dummy_mean = torch.ones(num_mcmc_samples, **tkwargs)
    dummy_lengthscale = torch.ones(num_mcmc_samples, self.ard_num_dims, **tkwargs)
    return self._common_dummy_samples(
        {"mean": dummy_mean, "lengthscale": dummy_lengthscale},
        num_mcmc_samples,
        **tkwargs,
    )
477+
378478 def _get_covar_module (
379479 self ,
380480 use_scale_kernel : bool ,
@@ -410,7 +510,9 @@ def load_mcmc_samples(
410510 num_mcmc_samples = len (mcmc_samples ["mean" ])
411511 batch_shape = torch .Size ([num_mcmc_samples ])
412512
413- mean_module = ConstantMean (batch_shape = batch_shape ).to (** tkwargs )
513+ mean_module = self ._build_mean_module (
514+ mcmc_samples = mcmc_samples , batch_shape = batch_shape , ** tkwargs
515+ )
414516 outputscale = mcmc_samples .get ("outputscale" )
415517 covar_module = self ._get_covar_module (
416518 use_scale_kernel = outputscale is not None , batch_shape = batch_shape , ** tkwargs
@@ -444,10 +546,6 @@ def load_mcmc_samples(
444546 target = base_kernel .lengthscale ,
445547 new_value = mcmc_samples ["lengthscale" ],
446548 )
447- mean_module .constant .data = reshape_and_detach (
448- target = mean_module .constant .data ,
449- new_value = mcmc_samples ["mean" ],
450- )
451549 if self .use_input_warping :
452550 indices = (
453551 list (range (self .ard_num_dims )) if self .indices is None else self .indices
@@ -470,6 +568,13 @@ def load_mcmc_samples(
470568 )
471569 else :
472570 warping_function = None
571+
572+ covar_module = self ._build_multitask_covariance (
573+ mcmc_samples = mcmc_samples ,
574+ covar_module = covar_module ,
575+ batch_shape = batch_shape ,
576+ ** tkwargs ,
577+ )
473578 return mean_module , covar_module , likelihood , warping_function
474579
475580
@@ -529,6 +634,18 @@ def postprocess_mcmc_samples(
529634 del mcmc_samples ["kernel_tausq" ], mcmc_samples ["_kernel_inv_length_sq" ]
530635 return mcmc_samples
531636
def get_dummy_mcmc_samples(
    self,
    num_mcmc_samples: int,
    **tkwargs: Any,
) -> dict[str, Tensor]:
    """Build placeholder MCMC samples used while loading a state dict.

    Starts from the parent class's placeholders and adds an all-ones
    ``"outputscale"`` entry with one value per MCMC sample.
    """
    dummy_samples = super().get_dummy_mcmc_samples(
        num_mcmc_samples=num_mcmc_samples, **tkwargs
    )
    dummy_outputscale = torch.ones(num_mcmc_samples, **tkwargs)
    dummy_samples["outputscale"] = dummy_outputscale
    return dummy_samples
648+
532649
533650class LinearPyroModel (PyroModel ):
534651 r"""Implementation of a Bayesian Linear pyro model.
@@ -546,9 +663,13 @@ def sample(self) -> None:
546663 mean = self .sample_mean (** tkwargs )
547664 weight_variance = self .sample_weight_variance (** tkwargs )
548665 X_tf = self ._maybe_input_warp (X = self .train_X , ** tkwargs )
666+ X_tf = self ._prepare_features (X_tf , ** tkwargs )
549667 X_tf = X_tf - 0.5 # center transformed data at 0 (for linear model)
550668 K_noiseless = linear_kernel (X = X_tf , weight_variance = weight_variance )
551669 noise = self .sample_noise (** tkwargs )
670+ K_noiseless , mean = self ._maybe_multitask_transform (
671+ K_noiseless , mean , ** tkwargs
672+ )
552673 self .sample_observations (
553674 mean = mean , K_noiseless = K_noiseless , noise = noise , ** tkwargs
554675 )
@@ -589,6 +710,20 @@ def postprocess_mcmc_samples(
589710 del mcmc_samples ["tau_sq" ], mcmc_samples ["_weight_variance_sq" ]
590711 return mcmc_samples
591712
def get_dummy_mcmc_samples(
    self,
    num_mcmc_samples: int,
    **tkwargs: Any,
) -> dict[str, Tensor]:
    """Build placeholder MCMC samples used while loading a state dict.

    Creates all-ones ``"mean"`` (``num_mcmc_samples``) and
    ``"weight_variance"`` (``num_mcmc_samples x ard_num_dims``) entries,
    then delegates to ``_common_dummy_samples`` and returns its result.
    """
    dummy_mean = torch.ones(num_mcmc_samples, **tkwargs)
    dummy_weight_variance = torch.ones(
        num_mcmc_samples, self.ard_num_dims, **tkwargs
    )
    return self._common_dummy_samples(
        {"mean": dummy_mean, "weight_variance": dummy_weight_variance},
        num_mcmc_samples,
        **tkwargs,
    )
726+
592727 def load_mcmc_samples (
593728 self , mcmc_samples : dict [str , Tensor ]
594729 ) -> tuple [Mean , Kernel , Likelihood , InputTransform ]:
@@ -975,27 +1110,6 @@ def median_lengthscale(self) -> Tensor:
9751110 lengthscale = base_kernel .lengthscale .clone ()
9761111 return lengthscale .median (0 ).values .squeeze (0 )
9771112
978- def _get_dummy_mcmc_samples (
979- self ,
980- num_mcmc_samples : int ,
981- dim : int ,
982- dtype : torch .dtype ,
983- device : torch .device ,
984- ) -> dict [str , Tensor ]:
985- # Load some dummy samples
986- tkwargs = {"dtype" : dtype , "device" : device }
987- mcmc_samples = {
988- "mean" : torch .ones (num_mcmc_samples , ** tkwargs ),
989- "lengthscale" : torch .ones (num_mcmc_samples , dim , ** tkwargs ),
990- }
991- if self .pyro_model .train_Yvar is None :
992- mcmc_samples ["noise" ] = torch .ones (num_mcmc_samples , ** tkwargs )
993-
994- if self .pyro_model .use_input_warping :
995- mcmc_samples ["c0" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
996- mcmc_samples ["c1" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
997- return mcmc_samples
998-
9991113 def load_state_dict (
10001114 self , state_dict : Mapping [str , Any ], strict : bool = True
10011115 ) -> None :
@@ -1012,11 +1126,10 @@ def load_state_dict(
10121126 the model construction logic into the Pyro model itself.
10131127 """
10141128 raw_mean = state_dict ["mean_module.raw_constant" ]
1015- mcmc_samples = self ._get_dummy_mcmc_samples (
1016- num_mcmc_samples = len (raw_mean ),
1017- dim = self .pyro_model .train_X .shape [- 1 ],
1018- dtype = raw_mean .dtype ,
1019- device = raw_mean .device ,
1129+ num_mcmc_samples = len (raw_mean )
1130+ tkwargs = {"dtype" : raw_mean .dtype , "device" : raw_mean .device }
1131+ mcmc_samples = self .pyro_model .get_dummy_mcmc_samples (
1132+ num_mcmc_samples = num_mcmc_samples , ** tkwargs
10201133 )
10211134 self .load_mcmc_samples (mcmc_samples = mcmc_samples )
10221135 # Load the actual samples from the state dict
@@ -1043,22 +1156,6 @@ class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
10431156
10441157 _pyro_model_class : type [PyroModel ] = SaasPyroModel
10451158
1046- def _get_dummy_mcmc_samples (
1047- self ,
1048- num_mcmc_samples : int ,
1049- dim : int ,
1050- dtype : torch .dtype ,
1051- device : torch .device ,
1052- ) -> dict [str , Tensor ]:
1053- mcmc_samples = super ()._get_dummy_mcmc_samples (
1054- num_mcmc_samples = num_mcmc_samples , dim = dim , dtype = dtype , device = device
1055- )
1056- # add outputscale
1057- mcmc_samples ["outputscale" ] = torch .ones (
1058- num_mcmc_samples , dtype = dtype , device = device
1059- )
1060- return mcmc_samples
1061-
10621159
10631160class FullyBayesianLinearSingleTaskGP (AbstractFullyBayesianSingleTaskGP ):
10641161 r"""A fully Bayesian single-task GP model with a linear kernel.
@@ -1102,19 +1199,10 @@ def load_state_dict(
11021199 """
11031200 weight_variance = state_dict ["covar_module.raw_variance" ]
11041201 num_mcmc_samples = len (weight_variance )
1105- dim = self .pyro_model .train_X .shape [- 1 ]
11061202 tkwargs = {"device" : weight_variance .device , "dtype" : weight_variance .dtype }
1107- # Load some dummy samples
1108- # deal with c0 c1
1109- mcmc_samples = {
1110- "mean" : torch .ones (num_mcmc_samples , ** tkwargs ),
1111- "weight_variance" : torch .ones (num_mcmc_samples , dim , ** tkwargs ),
1112- }
1113- if self .pyro_model .use_input_warping :
1114- mcmc_samples ["c0" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
1115- mcmc_samples ["c1" ] = torch .ones (num_mcmc_samples , dim , ** tkwargs )
1116- if self .pyro_model .train_Yvar is None :
1117- mcmc_samples ["noise" ] = torch .ones (num_mcmc_samples , ** tkwargs )
1203+ mcmc_samples = self .pyro_model .get_dummy_mcmc_samples (
1204+ num_mcmc_samples = num_mcmc_samples , ** tkwargs
1205+ )
11181206 self .load_mcmc_samples (mcmc_samples = mcmc_samples )
11191207 # Load the actual samples from the state dict
11201208 super ().load_state_dict (state_dict = state_dict , strict = strict )
0 commit comments