Skip to content

Commit 88138ed

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Add multi-task mixin support to PyroModel hierarchy
Summary: Generalize PyroModel with dispatch methods and multi-task mixins so that any PyroModel subclass can be composed with multi-task capabilities. Reviewed By: saitcakmak, sdaulton. Differential Revision: D92844567
1 parent 7b5e35b commit 88138ed

4 files changed

Lines changed: 545 additions & 200 deletions

File tree

botorch/models/fully_bayesian.py

Lines changed: 148 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -165,14 +165,23 @@ def _maybe_input_warp(self, X: Tensor, **tkwargs: Any) -> Tensor:
165165
return self.train_X
166166

167167
def set_inputs(
168-
self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor | None = None
168+
self,
169+
train_X: Tensor,
170+
train_Y: Tensor,
171+
train_Yvar: Tensor | None = None,
172+
task_feature: int | None = None,
173+
task_rank: int | None = None,
169174
) -> None:
170175
"""Set the training data.
171176
172177
Args:
173178
train_X: Training inputs (n x d)
174179
train_Y: Training targets (n x 1)
175180
train_Yvar: Observed noise variance (n x 1). Inferred if None.
181+
task_feature: The index of the task feature column. Used by
182+
multi-task mixins.
183+
task_rank: The number of learned task embeddings. Used by
184+
multi-task mixins.
176185
"""
177186
self.train_X = train_X
178187
self.train_Y = train_Y
@@ -192,6 +201,19 @@ def postprocess_mcmc_samples(
192201
"""Post-process the final MCMC samples."""
193202
pass # pragma: no cover
194203

204+
@abstractmethod
205+
def get_dummy_mcmc_samples(
206+
self,
207+
num_mcmc_samples: int,
208+
**tkwargs: Any,
209+
) -> dict[str, Tensor]:
210+
"""Return dummy MCMC samples for state dict loading.
211+
212+
Each subclass returns a dict of ones with the keys and shapes that
213+
``load_mcmc_samples`` expects.
214+
"""
215+
pass # pragma: no cover
216+
195217
@abstractmethod
196218
def load_mcmc_samples(
197219
self, mcmc_samples: dict[str, Tensor]
@@ -245,6 +267,68 @@ def sample_concentrations(self, **tkwargs: Any) -> tuple[Tensor, Tensor]:
245267

246268
return c0, c1
247269

270+
def _common_dummy_samples(
271+
self,
272+
mcmc_samples: dict[str, Tensor],
273+
num_mcmc_samples: int,
274+
**tkwargs: Any,
275+
) -> dict[str, Tensor]:
276+
"""Add noise and warping entries to ``mcmc_samples`` in-place."""
277+
dim = self.ard_num_dims
278+
if self.train_Yvar is None:
279+
mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
280+
if self.use_input_warping:
281+
mcmc_samples["c0"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
282+
mcmc_samples["c1"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
283+
return mcmc_samples
284+
285+
def _prepare_features(self, X: Tensor, **tkwargs: Any) -> Tensor:
286+
"""Select feature columns for kernel computation.
287+
288+
Overridden by multi-task mixins to strip the task column.
289+
"""
290+
return X
291+
292+
def _maybe_multitask_transform(
293+
self, K_noiseless: Tensor, mean: Tensor, **tkwargs: Any
294+
) -> tuple[Tensor, Tensor]:
295+
r"""Apply multi-task covariance transform to the kernel and mean.
296+
297+
No-op for single-task models. Overridden by multi-task mixins.
298+
"""
299+
return K_noiseless, mean
300+
301+
def _build_mean_module(
302+
self,
303+
mcmc_samples: dict[str, Tensor],
304+
batch_shape: torch.Size,
305+
**tkwargs: Any,
306+
) -> Mean:
307+
"""Construct and populate the mean module from MCMC samples.
308+
309+
Returns a scalar ``ConstantMean`` for single-task models. Overridden by
310+
multi-task mixins to return a ``MultitaskMean``.
311+
"""
312+
mean_module = ConstantMean(batch_shape=batch_shape).to(**tkwargs)
313+
mean_module.constant.data = reshape_and_detach(
314+
target=mean_module.constant.data,
315+
new_value=mcmc_samples["mean"],
316+
)
317+
return mean_module
318+
319+
def _build_multitask_covariance(
320+
self,
321+
mcmc_samples: dict[str, Tensor],
322+
covar_module: Kernel,
323+
batch_shape: torch.Size,
324+
**tkwargs: Any,
325+
) -> Kernel:
326+
"""Optionally wrap covar_module with task covariance.
327+
328+
No-op for single-task models. Overridden by multi-task mixins.
329+
"""
330+
return covar_module
331+
248332
def sample_observations(
249333
self,
250334
mean: Tensor,
@@ -322,7 +406,11 @@ def sample(self) -> None:
322406
noise = self.sample_noise(**tkwargs)
323407
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
324408
X_tf = self._maybe_input_warp(self.train_X, **tkwargs)
409+
X_tf = self._prepare_features(X_tf, **tkwargs)
325410
K_noiseless = outputscale * matern52_kernel(X=X_tf, lengthscale=lengthscale)
411+
K_noiseless, mean = self._maybe_multitask_transform(
412+
K_noiseless, mean, **tkwargs
413+
)
326414
self.sample_observations(
327415
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
328416
)
@@ -375,6 +463,18 @@ def postprocess_mcmc_samples(
375463
"""
376464
return mcmc_samples
377465

466+
def get_dummy_mcmc_samples(
467+
self,
468+
num_mcmc_samples: int,
469+
**tkwargs: Any,
470+
) -> dict[str, Tensor]:
471+
"""Return dummy MCMC samples for state dict loading."""
472+
mcmc_samples = {
473+
"mean": torch.ones(num_mcmc_samples, **tkwargs),
474+
"lengthscale": torch.ones(num_mcmc_samples, self.ard_num_dims, **tkwargs),
475+
}
476+
return self._common_dummy_samples(mcmc_samples, num_mcmc_samples, **tkwargs)
477+
378478
def _get_covar_module(
379479
self,
380480
use_scale_kernel: bool,
@@ -410,7 +510,9 @@ def load_mcmc_samples(
410510
num_mcmc_samples = len(mcmc_samples["mean"])
411511
batch_shape = torch.Size([num_mcmc_samples])
412512

413-
mean_module = ConstantMean(batch_shape=batch_shape).to(**tkwargs)
513+
mean_module = self._build_mean_module(
514+
mcmc_samples=mcmc_samples, batch_shape=batch_shape, **tkwargs
515+
)
414516
outputscale = mcmc_samples.get("outputscale")
415517
covar_module = self._get_covar_module(
416518
use_scale_kernel=outputscale is not None, batch_shape=batch_shape, **tkwargs
@@ -444,10 +546,6 @@ def load_mcmc_samples(
444546
target=base_kernel.lengthscale,
445547
new_value=mcmc_samples["lengthscale"],
446548
)
447-
mean_module.constant.data = reshape_and_detach(
448-
target=mean_module.constant.data,
449-
new_value=mcmc_samples["mean"],
450-
)
451549
if self.use_input_warping:
452550
indices = (
453551
list(range(self.ard_num_dims)) if self.indices is None else self.indices
@@ -470,6 +568,13 @@ def load_mcmc_samples(
470568
)
471569
else:
472570
warping_function = None
571+
572+
covar_module = self._build_multitask_covariance(
573+
mcmc_samples=mcmc_samples,
574+
covar_module=covar_module,
575+
batch_shape=batch_shape,
576+
**tkwargs,
577+
)
473578
return mean_module, covar_module, likelihood, warping_function
474579

475580

@@ -529,6 +634,18 @@ def postprocess_mcmc_samples(
529634
del mcmc_samples["kernel_tausq"], mcmc_samples["_kernel_inv_length_sq"]
530635
return mcmc_samples
531636

637+
def get_dummy_mcmc_samples(
638+
self,
639+
num_mcmc_samples: int,
640+
**tkwargs: Any,
641+
) -> dict[str, Tensor]:
642+
"""Return dummy MCMC samples for state dict loading."""
643+
mcmc_samples = super().get_dummy_mcmc_samples(
644+
num_mcmc_samples=num_mcmc_samples, **tkwargs
645+
)
646+
mcmc_samples["outputscale"] = torch.ones(num_mcmc_samples, **tkwargs)
647+
return mcmc_samples
648+
532649

533650
class LinearPyroModel(PyroModel):
534651
r"""Implementation of a Bayesian Linear pyro model.
@@ -546,9 +663,13 @@ def sample(self) -> None:
546663
mean = self.sample_mean(**tkwargs)
547664
weight_variance = self.sample_weight_variance(**tkwargs)
548665
X_tf = self._maybe_input_warp(X=self.train_X, **tkwargs)
666+
X_tf = self._prepare_features(X_tf, **tkwargs)
549667
X_tf = X_tf - 0.5 # center transformed data at 0 (for linear model)
550668
K_noiseless = linear_kernel(X=X_tf, weight_variance=weight_variance)
551669
noise = self.sample_noise(**tkwargs)
670+
K_noiseless, mean = self._maybe_multitask_transform(
671+
K_noiseless, mean, **tkwargs
672+
)
552673
self.sample_observations(
553674
mean=mean, K_noiseless=K_noiseless, noise=noise, **tkwargs
554675
)
@@ -589,6 +710,20 @@ def postprocess_mcmc_samples(
589710
del mcmc_samples["tau_sq"], mcmc_samples["_weight_variance_sq"]
590711
return mcmc_samples
591712

713+
def get_dummy_mcmc_samples(
714+
self,
715+
num_mcmc_samples: int,
716+
**tkwargs: Any,
717+
) -> dict[str, Tensor]:
718+
"""Return dummy MCMC samples for state dict loading."""
719+
mcmc_samples = {
720+
"mean": torch.ones(num_mcmc_samples, **tkwargs),
721+
"weight_variance": torch.ones(
722+
num_mcmc_samples, self.ard_num_dims, **tkwargs
723+
),
724+
}
725+
return self._common_dummy_samples(mcmc_samples, num_mcmc_samples, **tkwargs)
726+
592727
def load_mcmc_samples(
593728
self, mcmc_samples: dict[str, Tensor]
594729
) -> tuple[Mean, Kernel, Likelihood, InputTransform]:
@@ -975,27 +1110,6 @@ def median_lengthscale(self) -> Tensor:
9751110
lengthscale = base_kernel.lengthscale.clone()
9761111
return lengthscale.median(0).values.squeeze(0)
9771112

978-
def _get_dummy_mcmc_samples(
979-
self,
980-
num_mcmc_samples: int,
981-
dim: int,
982-
dtype: torch.dtype,
983-
device: torch.device,
984-
) -> dict[str, Tensor]:
985-
# Load some dummy samples
986-
tkwargs = {"dtype": dtype, "device": device}
987-
mcmc_samples = {
988-
"mean": torch.ones(num_mcmc_samples, **tkwargs),
989-
"lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
990-
}
991-
if self.pyro_model.train_Yvar is None:
992-
mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
993-
994-
if self.pyro_model.use_input_warping:
995-
mcmc_samples["c0"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
996-
mcmc_samples["c1"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
997-
return mcmc_samples
998-
9991113
def load_state_dict(
10001114
self, state_dict: Mapping[str, Any], strict: bool = True
10011115
) -> None:
@@ -1012,11 +1126,10 @@ def load_state_dict(
10121126
the model construction logic into the Pyro model itself.
10131127
"""
10141128
raw_mean = state_dict["mean_module.raw_constant"]
1015-
mcmc_samples = self._get_dummy_mcmc_samples(
1016-
num_mcmc_samples=len(raw_mean),
1017-
dim=self.pyro_model.train_X.shape[-1],
1018-
dtype=raw_mean.dtype,
1019-
device=raw_mean.device,
1129+
num_mcmc_samples = len(raw_mean)
1130+
tkwargs = {"dtype": raw_mean.dtype, "device": raw_mean.device}
1131+
mcmc_samples = self.pyro_model.get_dummy_mcmc_samples(
1132+
num_mcmc_samples=num_mcmc_samples, **tkwargs
10201133
)
10211134
self.load_mcmc_samples(mcmc_samples=mcmc_samples)
10221135
# Load the actual samples from the state dict
@@ -1043,22 +1156,6 @@ class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
10431156

10441157
_pyro_model_class: type[PyroModel] = SaasPyroModel
10451158

1046-
def _get_dummy_mcmc_samples(
1047-
self,
1048-
num_mcmc_samples: int,
1049-
dim: int,
1050-
dtype: torch.dtype,
1051-
device: torch.device,
1052-
) -> dict[str, Tensor]:
1053-
mcmc_samples = super()._get_dummy_mcmc_samples(
1054-
num_mcmc_samples=num_mcmc_samples, dim=dim, dtype=dtype, device=device
1055-
)
1056-
# add outputscale
1057-
mcmc_samples["outputscale"] = torch.ones(
1058-
num_mcmc_samples, dtype=dtype, device=device
1059-
)
1060-
return mcmc_samples
1061-
10621159

10631160
class FullyBayesianLinearSingleTaskGP(AbstractFullyBayesianSingleTaskGP):
10641161
r"""A fully Bayesian single-task GP model with a linear kernel.
@@ -1102,19 +1199,10 @@ def load_state_dict(
11021199
"""
11031200
weight_variance = state_dict["covar_module.raw_variance"]
11041201
num_mcmc_samples = len(weight_variance)
1105-
dim = self.pyro_model.train_X.shape[-1]
11061202
tkwargs = {"device": weight_variance.device, "dtype": weight_variance.dtype}
1107-
# Load some dummy samples
1108-
# deal with c0 c1
1109-
mcmc_samples = {
1110-
"mean": torch.ones(num_mcmc_samples, **tkwargs),
1111-
"weight_variance": torch.ones(num_mcmc_samples, dim, **tkwargs),
1112-
}
1113-
if self.pyro_model.use_input_warping:
1114-
mcmc_samples["c0"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
1115-
mcmc_samples["c1"] = torch.ones(num_mcmc_samples, dim, **tkwargs)
1116-
if self.pyro_model.train_Yvar is None:
1117-
mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
1203+
mcmc_samples = self.pyro_model.get_dummy_mcmc_samples(
1204+
num_mcmc_samples=num_mcmc_samples, **tkwargs
1205+
)
11181206
self.load_mcmc_samples(mcmc_samples=mcmc_samples)
11191207
# Load the actual samples from the state dict
11201208
super().load_state_dict(state_dict=state_dict, strict=strict)

0 commit comments

Comments
 (0)