Skip to content

Commit 81ff5f9

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Add multi-task mixin support to PyroModel hierarchy
Summary: Generalize PyroModel with dispatch methods and multi-task mixins so that any PyroModel subclass can be composed with multi-task capabilities. Differential Revision: D92844567
1 parent 7b5e35b commit 81ff5f9

4 files changed

Lines changed: 538 additions & 197 deletions

File tree

botorch/models/fully_bayesian.py

Lines changed: 142 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -165,14 +165,23 @@ def _maybe_input_warp(self, X: Tensor, **tkwargs: Any) -> Tensor:
165165
return self.train_X
166166

167167
def set_inputs(
168-
self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor | None = None
168+
self,
169+
train_X: Tensor,
170+
train_Y: Tensor,
171+
train_Yvar: Tensor | None = None,
172+
task_feature: int | None = None,
173+
task_rank: int | None = None,
169174
) -> None:
170175
"""Set the training data.
171176
172177
Args:
173178
train_X: Training inputs (n x d)
174179
train_Y: Training targets (n x 1)
175180
train_Yvar: Observed noise variance (n x 1). Inferred if None.
181+
task_feature: The index of the task feature column. Used by
182+
multi-task mixins.
183+
task_rank: The number of learned task embeddings. Used by
184+
multi-task mixins.
176185
"""
177186
self.train_X = train_X
178187
self.train_Y = train_Y
@@ -192,6 +201,19 @@ def postprocess_mcmc_samples(
192201
"""Post-process the final MCMC samples."""
193202
pass # pragma: no cover
194203

204+
@abstractmethod
205+
def get_dummy_mcmc_samples(
206+
self,
207+
num_mcmc_samples: int,
208+
**tkwargs: Any,
209+
) -> dict[str, Tensor]:
210+
"""Return dummy MCMC samples for state dict loading.
211+
212+
Each subclass returns a dict of ones with the keys and shapes that
213+
``load_mcmc_samples`` expects.
214+
"""
215+
pass # pragma: no cover
216+
195217
@abstractmethod
196218
def load_mcmc_samples(
197219
self, mcmc_samples: dict[str, Tensor]
@@ -245,6 +267,60 @@ def sample_concentrations(self, **tkwargs: Any) -> tuple[Tensor, Tensor]:
245267

246268
return c0, c1
247269

270+
def _common_dummy_samples(
271+
self,
272+
mcmc_samples: dict[str, Tensor],
273+
num_mcmc_samples: int,
274+
**tkwargs: Any,
275+
) -> dict[str, Tensor]:
276+
"""Add noise and warping entries to ``mcmc_samples`` in-place."""
277+
dim = self.ard_num_dims
278+
if self.train_Yvar is None:
279+
mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
280+
if self.use_input_warping:
281+
mcmc_samples["c0"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
282+
mcmc_samples["c1"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
283+
return mcmc_samples
284+
285+
def _prepare_features(self, X: Tensor, **tkwargs: Any) -> Tensor:
286+
"""Select feature columns for kernel computation.
287+
288+
Overridden by multi-task mixins to strip the task column.
289+
"""
290+
return X
291+
292+
def _maybe_multitask_transform(
293+
self, K_noiseless: Tensor, mean: Tensor, **tkwargs: Any
294+
) -> tuple[Tensor, Tensor]:
295+
r"""Apply multi-task covariance transform to the kernel and mean.
296+
297+
No-op for single-task models. Overridden by multi-task mixins.
298+
"""
299+
return K_noiseless, mean
300+
301+
def _preprocess_mean_for_loading(
302+
self, mcmc_samples: dict[str, Tensor]
303+
) -> dict[str, Tensor]:
304+
"""Return mcmc_samples with a scalar mean for building covariance modules.
305+
306+
No-op for single-task models. Overridden by multi-task mixins.
307+
"""
308+
return mcmc_samples
309+
310+
def _finalize_multitask_modules(
311+
self,
312+
mcmc_samples: dict[str, Tensor],
313+
mean_module: Mean,
314+
covar_module: Kernel,
315+
batch_shape: torch.Size,
316+
**tkwargs: Any,
317+
) -> tuple[Mean, Kernel]:
318+
"""Optionally replace mean/covar with multi-task versions.
319+
320+
No-op for single-task models. Overridden by multi-task mixins.
321+
"""
322+
return mean_module, covar_module
323+
248324
def sample_observations(
249325
self,
250326
mean: Tensor,
@@ -322,7 +398,11 @@ def sample(self) -> None:
322398
noise = self.sample_noise(**tkwargs)
323399
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
324400
X_tf = self._maybe_input_warp(self.train_X, **tkwargs)
401+
X_tf = self._prepare_features(X_tf, **tkwargs)
325402
K_noiseless = outputscale * matern52_kernel(X=X_tf, lengthscale=lengthscale)
403+
K_noiseless, mean = self._maybe_multitask_transform(
404+
K_noiseless, mean, **tkwargs
405+
)
326406
self.sample_observations(
327407
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
328408
)
@@ -375,6 +455,18 @@ def postprocess_mcmc_samples(
375455
"""
376456
return mcmc_samples
377457

458+
def get_dummy_mcmc_samples(
459+
self,
460+
num_mcmc_samples: int,
461+
**tkwargs: Any,
462+
) -> dict[str, Tensor]:
463+
"""Return dummy MCMC samples for state dict loading."""
464+
mcmc_samples = {
465+
"mean": torch.ones(num_mcmc_samples, **tkwargs),
466+
"lengthscale": torch.ones(num_mcmc_samples, self.ard_num_dims, **tkwargs),
467+
}
468+
return self._common_dummy_samples(mcmc_samples, num_mcmc_samples, **tkwargs)
469+
378470
def _get_covar_module(
379471
self,
380472
use_scale_kernel: bool,
@@ -410,8 +502,10 @@ def load_mcmc_samples(
410502
num_mcmc_samples = len(mcmc_samples["mean"])
411503
batch_shape = torch.Size([num_mcmc_samples])
412504

505+
mean_mcmc = self._preprocess_mean_for_loading(mcmc_samples)
506+
413507
mean_module = ConstantMean(batch_shape=batch_shape).to(**tkwargs)
414-
outputscale = mcmc_samples.get("outputscale")
508+
outputscale = mean_mcmc.get("outputscale")
415509
covar_module = self._get_covar_module(
416510
use_scale_kernel=outputscale is not None, batch_shape=batch_shape, **tkwargs
417511
)
@@ -446,7 +540,7 @@ def load_mcmc_samples(
446540
)
447541
mean_module.constant.data = reshape_and_detach(
448542
target=mean_module.constant.data,
449-
new_value=mcmc_samples["mean"],
543+
new_value=mean_mcmc["mean"],
450544
)
451545
if self.use_input_warping:
452546
indices = (
@@ -470,6 +564,14 @@ def load_mcmc_samples(
470564
)
471565
else:
472566
warping_function = None
567+
568+
mean_module, covar_module = self._finalize_multitask_modules(
569+
mcmc_samples=mcmc_samples,
570+
mean_module=mean_module,
571+
covar_module=covar_module,
572+
batch_shape=batch_shape,
573+
**tkwargs,
574+
)
473575
return mean_module, covar_module, likelihood, warping_function
474576

475577

@@ -529,6 +631,18 @@ def postprocess_mcmc_samples(
529631
del mcmc_samples["kernel_tausq"], mcmc_samples["_kernel_inv_length_sq"]
530632
return mcmc_samples
531633

634+
def get_dummy_mcmc_samples(
635+
self,
636+
num_mcmc_samples: int,
637+
**tkwargs: Any,
638+
) -> dict[str, Tensor]:
639+
"""Return dummy MCMC samples for state dict loading."""
640+
mcmc_samples = super().get_dummy_mcmc_samples(
641+
num_mcmc_samples=num_mcmc_samples, **tkwargs
642+
)
643+
mcmc_samples["outputscale"] = torch.ones(num_mcmc_samples, **tkwargs)
644+
return mcmc_samples
645+
532646

533647
class LinearPyroModel(PyroModel):
534648
r"""Implementation of a Bayesian Linear pyro model.
@@ -546,9 +660,13 @@ def sample(self) -> None:
546660
mean = self.sample_mean(**tkwargs)
547661
weight_variance = self.sample_weight_variance(**tkwargs)
548662
X_tf = self._maybe_input_warp(X=self.train_X, **tkwargs)
663+
X_tf = self._prepare_features(X_tf, **tkwargs)
549664
X_tf = X_tf - 0.5 # center transformed data at 0 (for linear model)
550665
K_noiseless = linear_kernel(X=X_tf, weight_variance=weight_variance)
551666
noise = self.sample_noise(**tkwargs)
667+
K_noiseless, mean = self._maybe_multitask_transform(
668+
K_noiseless, mean, **tkwargs
669+
)
552670
self.sample_observations(
553671
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
554672
)
@@ -589,6 +707,20 @@ def postprocess_mcmc_samples(
589707
del mcmc_samples["tau_sq"], mcmc_samples["_weight_variance_sq"]
590708
return mcmc_samples
591709

710+
def get_dummy_mcmc_samples(
711+
self,
712+
num_mcmc_samples: int,
713+
**tkwargs: Any,
714+
) -> dict[str, Tensor]:
715+
"""Return dummy MCMC samples for state dict loading."""
716+
mcmc_samples = {
717+
"mean": torch.ones(num_mcmc_samples, **tkwargs),
718+
"weight_variance": torch.ones(
719+
num_mcmc_samples, self.ard_num_dims, **tkwargs
720+
),
721+
}
722+
return self._common_dummy_samples(mcmc_samples, num_mcmc_samples, **tkwargs)
723+
592724
def load_mcmc_samples(
593725
self, mcmc_samples: dict[str, Tensor]
594726
) -> tuple[Mean, Kernel, Likelihood, InputTransform]:
@@ -975,27 +1107,6 @@ def median_lengthscale(self) -> Tensor:
9751107
lengthscale = base_kernel.lengthscale.clone()
9761108
return lengthscale.median(0).values.squeeze(0)
9771109

978-
def _get_dummy_mcmc_samples(
979-
self,
980-
num_mcmc_samples: int,
981-
dim: int,
982-
dtype: torch.dtype,
983-
device: torch.device,
984-
) -> dict[str, Tensor]:
985-
# Load some dummy samples
986-
tkwargs = {"dtype": dtype, "device": device}
987-
mcmc_samples = {
988-
"mean": torch.ones(num_mcmc_samples, **tkwargs),
989-
"lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
990-
}
991-
if self.pyro_model.train_Yvar is None:
992-
mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
993-
994-
if self.pyro_model.use_input_warping:
995-
mcmc_samples["c0"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
996-
mcmc_samples["c1"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
997-
return mcmc_samples
998-
9991110
def load_state_dict(
10001111
self, state_dict: Mapping[str, Any], strict: bool = True
10011112
) -> None:
@@ -1012,11 +1123,10 @@ def load_state_dict(
10121123
the model construction logic into the Pyro model itself.
10131124
"""
10141125
raw_mean = state_dict["mean_module.raw_constant"]
1015-
mcmc_samples = self._get_dummy_mcmc_samples(
1016-
num_mcmc_samples=len(raw_mean),
1017-
dim=self.pyro_model.train_X.shape[-1],
1018-
dtype=raw_mean.dtype,
1019-
device=raw_mean.device,
1126+
num_mcmc_samples = len(raw_mean)
1127+
tkwargs = {"dtype": raw_mean.dtype, "device": raw_mean.device}
1128+
mcmc_samples = self.pyro_model.get_dummy_mcmc_samples(
1129+
num_mcmc_samples=num_mcmc_samples, **tkwargs
10201130
)
10211131
self.load_mcmc_samples(mcmc_samples=mcmc_samples)
10221132
# Load the actual samples from the state dict
@@ -1043,22 +1153,6 @@ class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
10431153

10441154
_pyro_model_class: type[PyroModel] = SaasPyroModel
10451155

1046-
def _get_dummy_mcmc_samples(
1047-
self,
1048-
num_mcmc_samples: int,
1049-
dim: int,
1050-
dtype: torch.dtype,
1051-
device: torch.device,
1052-
) -> dict[str, Tensor]:
1053-
mcmc_samples = super()._get_dummy_mcmc_samples(
1054-
num_mcmc_samples=num_mcmc_samples, dim=dim, dtype=dtype, device=device
1055-
)
1056-
# add outputscale
1057-
mcmc_samples["outputscale"] = torch.ones(
1058-
num_mcmc_samples, dtype=dtype, device=device
1059-
)
1060-
return mcmc_samples
1061-
10621156

10631157
class FullyBayesianLinearSingleTaskGP(AbstractFullyBayesianSingleTaskGP):
10641158
r"""A fully Bayesian single-task GP model with a linear kernel.
@@ -1102,19 +1196,10 @@ def load_state_dict(
11021196
"""
11031197
weight_variance = state_dict["covar_module.raw_variance"]
11041198
num_mcmc_samples = len(weight_variance)
1105-
dim = self.pyro_model.train_X.shape[-1]
11061199
tkwargs = {"device": weight_variance.device, "dtype": weight_variance.dtype}
1107-
# Load some dummy samples
1108-
# deal with c0 c1
1109-
mcmc_samples = {
1110-
"mean": torch.ones(num_mcmc_samples, **tkwargs),
1111-
"weight_variance": torch.ones(num_mcmc_samples, dim, **tkwargs),
1112-
}
1113-
if self.pyro_model.use_input_warping:
1114-
mcmc_samples["c0"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
1115-
mcmc_samples["c1"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
1116-
if self.pyro_model.train_Yvar is None:
1117-
mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
1200+
mcmc_samples = self.pyro_model.get_dummy_mcmc_samples(
1201+
num_mcmc_samples=num_mcmc_samples, **tkwargs
1202+
)
11181203
self.load_mcmc_samples(mcmc_samples=mcmc_samples)
11191204
# Load the actual samples from the state dict
11201205
super().load_state_dict(state_dict=state_dict, strict=strict)

0 commit comments

Comments
 (0)