[Bug] sample_all_priors doesn't work with KroneckerMultiTaskGP #1860

Open
@esantorella

Description

🐛 Bug

Moving this from #1323.

To reproduce

**Code snippet to reproduce**

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.multitask import KroneckerMultiTaskGP
from botorch.models.transforms.outcome import Standardize
from botorch.optim.utils import sample_all_priors
from botorch.utils.transforms import normalize
from gpytorch.mlls import ExactMarginalLogLikelihood

tkwargs = {
    "dtype": torch.double,
    "device": "cpu",
}

# A single training point with 3 input dimensions and 2 outputs (tasks).
train_x = torch.rand(1, 3, **tkwargs)
train_obj = torch.rand(1, 2, **tkwargs)

# Normalize inputs to the unit cube (a no-op here, since the bounds are [0, 1]^3).
train_x = normalize(
    train_x, torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], **tkwargs)
)

model = KroneckerMultiTaskGP(train_x, train_obj, outcome_transform=Standardize(m=2))
mll = ExactMarginalLogLikelihood(model.likelihood, model)

fit_gpytorch_mll(mll)
# Raises the RuntimeError shown in the stack trace below.
sample_all_priors(mll.model)

**Stack trace/error message**

  File "/Users/santorella/issue_repros/botorch_1323.py", line 26, in <module>
    sample_all_priors(mll.model)
  File "/Users/santorella/repos/botorch/botorch/optim/utils/model_utils.py", line 195, in sample_all_priors
    raise RuntimeError(
RuntimeError: Must provide inverse transform to be able to sample from prior.
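
The error comes from a check inside sample_all_priors: it walks over the model's registered priors and raises this RuntimeError as soon as it finds one with no setting closure (the "inverse transform" in the message). Below is a minimal diagnostic sketch that lists the offending priors up front instead of failing on the first one. The helper name priors_missing_setting_closure is hypothetical, and the sketch assumes model.named_priors() yields (name, module, prior, closure, setting_closure) tuples, as GPyTorch modules do in recent versions.

```python
from typing import Any, List


def priors_missing_setting_closure(model: Any) -> List[str]:
    """Hypothetical helper: names of registered priors with no setting closure.

    Assumes `model.named_priors()` yields 5-tuples
    (name, module, prior, closure, setting_closure), as GPyTorch modules do.
    Any prior returned here would make `sample_all_priors` raise.
    """
    missing = []
    for name, _module, _prior, _closure, setting_closure in model.named_priors():
        if setting_closure is None:
            missing.append(name)
    return missing
```

Calling priors_missing_setting_closure(model) on the model from the reproduction should list whichever priors lack a closure. The exact names depend on the module hierarchy, so none are assumed here.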

Expected Behavior

This should work, because fit_gpytorch_mll calls sample_all_priors to re-sample the hyperparameters when a model-fitting attempt fails.
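
To illustrate why the missing closure matters for fitting, here is a simplified sketch of a retry loop in the spirit of that fallback. It is not BoTorch's implementation; fit_with_prior_resampling, fit_attempt, and resample_priors are hypothetical names. In this sketch the resampling step runs outside the try block, so a RuntimeError raised by it ends the loop immediately instead of triggering another attempt.

```python
from typing import Callable, Optional


def fit_with_prior_resampling(
    fit_attempt: Callable[[], None],
    resample_priors: Callable[[], None],
    max_attempts: int = 3,
) -> int:
    """Hypothetical retry loop: run `fit_attempt` up to `max_attempts` times,
    calling `resample_priors` before every retry to re-initialize the
    hyperparameters.

    Returns the 1-based index of the first successful attempt.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    last_error: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        if attempt > 1:
            # Outside the try block: an error here (such as the RuntimeError
            # in this issue) propagates and aborts all remaining attempts.
            resample_priors()  # e.g. sample_all_priors(mll.model)
        try:
            fit_attempt()
            return attempt
        except Exception as err:  # treat any fitting error as a failed attempt
            last_error = err
    raise RuntimeError("All fitting attempts failed.") from last_error
```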

System information

  • BoTorch Version: 0.8.6.dev9+g0ad879cc
  • GPyTorch Version: 1.11
  • PyTorch Version: 1.13.0
  • OS: macOS

Additional context

See comments on #1323 for more info.

From @saitcakmak: "Looks like both MultitaskGaussianLikelihood and IndexKernel are missing a setting_closure."

From @Balandat:

This is most likely the LKJCovariancePrior over the intra-task correlation matrix, defined by default here:

task_covar_prior = LKJCovariancePrior(

If you trace this down, it is registered here: https://github.com/cornellius-gp/gpytorch/blob/d171863c50ab16b5bfb7035e579dcbe53169e703/gpytorch/kernels/index_kernel.py#L71

Basically this would need a setting_closure. In this case we're passing a covariance matrix Sigma in, so what we'd have to do here is define the closure to take in Sigma, factor it into a correlation matrix C and the variances var, perform a root decomposition of C and then set the covar_factor and var attributes of the IndexKernel.
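
The factorization steps in that comment can be sketched in plain PyTorch. split_task_covariance is a hypothetical helper, not part of GPyTorch or BoTorch. It assumes Sigma is symmetric positive definite and that the kernel's rank equals the number of tasks, so a full Cholesky factor matches the covar_factor shape. It only computes the pieces and does not touch an IndexKernel.

```python
import torch


def split_task_covariance(sigma: torch.Tensor):
    """Hypothetical helper: split a task covariance matrix Sigma into
    (covar_factor, var).

    var is diag(Sigma), the per-task variances. covar_factor is a Cholesky
    root of the correlation matrix C = D^{-1/2} Sigma D^{-1/2}, so that
    covar_factor @ covar_factor.T == C. Assumes Sigma is symmetric positive
    definite and rank == num_tasks. Works on batches of matrices too.
    """
    var = torch.diagonal(sigma, dim1=-2, dim2=-1)
    std = var.sqrt()
    # Divide each entry Sigma_ij by std_i * std_j to get the correlation matrix C.
    corr = sigma / (std.unsqueeze(-1) * std.unsqueeze(-2))
    # Root decomposition of C: a lower-triangular Cholesky factor.
    covar_factor = torch.linalg.cholesky(corr)
    return covar_factor, var
```

In an actual setting closure, these results would then be assigned to the IndexKernel's covar_factor and var attributes, as described above. For reference, GPyTorch's IndexKernel forms its task covariance as covar_factor @ covar_factor.T + diag(var), which is the parameterization such a closure would target.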

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests

    Issue actions