🐛 Bug
I found some unexpected interactions between ard_num_dims and the shapes of kernel hyperparameter priors: in several settings, sampling from a hyperparameter prior does not return a tensor with the same shape as the hyperparameter. I'm not sure whether all of these are intended, but it looks like a bug to me.
To reproduce
import torch
from gpytorch.priors import GammaPrior, NormalPrior
from gpytorch.kernels import RBFKernel
# make a kernel
scales = torch.tensor([1.0, 1.0])
kernel = RBFKernel(
ard_num_dims=2,
lengthscale_prior=GammaPrior(3.0, 6.0 / scales),
)
new_lengthscale = kernel.lengthscale_prior.sample(kernel.lengthscale.shape)
print(kernel.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1,2,2; assigning it back to kernel.lengthscale raises an error
# same with another prior
kernel2 = RBFKernel(
ard_num_dims=2,
lengthscale_prior=NormalPrior(loc=10, scale=scales)
)
new_lengthscale = kernel2.lengthscale_prior.sample(kernel2.lengthscale.shape)
print(kernel2.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1, 2, 2
# ard_num_dims is only 1 but we have a higher-dim prior. Is this behavior defined?
kernel3 = RBFKernel(
ard_num_dims=1,
lengthscale_prior=NormalPrior(loc=10, scale=scales)
)
new_lengthscale = kernel3.lengthscale_prior.sample(kernel3.lengthscale.shape)
print(kernel3.lengthscale.shape) # size 1, 1 -- but shouldn't we expect 1,2?
print(new_lengthscale.shape) # size 1, 1, 2
# ok, ard_num_dims is 2 but my prior is scalar, and now the shapes match
kernel4 = RBFKernel(
ard_num_dims=2,
lengthscale_prior=NormalPrior(loc=10, scale=1)
)
new_lengthscale = kernel4.lengthscale_prior.sample(kernel4.lengthscale.shape)
print(kernel4.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1, 2
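All four results are consistent with torch's distribution sampling rule: sample(sample_shape) returns a tensor of shape sample_shape + batch_shape + event_shape, so passing the full lengthscale shape as sample_shape stacks the prior's own batch dimension on top. The sketch below reproduces the reported shapes with plain torch distributions, assuming GPyTorch's NormalPrior and GammaPrior inherit that sampling behavior from torch.distributions.Normal and torch.distributions.Gamma; predicted_shape is a hypothetical helper added for illustration.

```python
import torch
from torch.distributions import Gamma, Normal

# Stand-ins for the priors in the report (assumption: GPyTorch's NormalPrior and
# GammaPrior sample exactly like the torch distributions they build on).
scales = torch.tensor([1.0, 1.0])
vector_gamma = Gamma(torch.tensor(3.0), 6.0 / scales)          # batch_shape == (2,)
vector_normal = Normal(torch.tensor(10.0), scales)             # batch_shape == (2,)
scalar_normal = Normal(torch.tensor(10.0), torch.tensor(1.0))  # batch_shape == ()


def predicted_shape(dist, sample_shape):
    """Hypothetical helper: torch's rule sample_shape + batch_shape + event_shape."""
    return torch.Size(sample_shape) + dist.batch_shape + dist.event_shape


# Lengthscale shapes from the report: (1, 2) for ard_num_dims=2, (1, 1) for ard_num_dims=1.
cases = [
    ("kernel", vector_gamma, torch.Size([1, 2])),
    ("kernel2", vector_normal, torch.Size([1, 2])),
    ("kernel3", vector_normal, torch.Size([1, 1])),
    ("kernel4", scalar_normal, torch.Size([1, 2])),
]

results = {}
for name, dist, lengthscale_shape in cases:
    sample = dist.sample(lengthscale_shape)
    # The observed shape always follows the rule, whether or not it matches the lengthscale.
    assert sample.shape == predicted_shape(dist, lengthscale_shape)
    results[name] = tuple(sample.shape)
    print(name, tuple(lengthscale_shape), "->", tuple(sample.shape))
```

Only kernel4 comes out matching, because its prior has an empty batch_shape and contributes no extra dimension.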
Expected Behavior
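Ideally, a warning or error would be raised as soon as a kernel is constructed with a prior whose shape is unsupported for its hyperparameter (as in the kernel3 case), and in every supported case, sampling from the prior would return a tensor with the same shape as the hyperparameter.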
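As a minimal sketch of the kind of check described above, the code below validates a prior against a hyperparameter shape and, when they are compatible, picks the sample_shape that makes the draw match. It assumes priors expose torch's batch_shape and event_shape attributes; check_prior_shape and sample_like_parameter are hypothetical helpers, not part of GPyTorch's API, and plain torch Normal distributions stand in for NormalPrior.

```python
from typing import Sequence

import torch
from torch.distributions import Distribution, Normal


def check_prior_shape(prior: Distribution, param_shape: Sequence[int]) -> torch.Size:
    """Hypothetical helper: return the sample_shape that makes prior.sample(...)
    match param_shape, or raise ValueError if no such sample_shape exists."""
    param_shape = torch.Size(param_shape)
    prior_shape = prior.batch_shape + prior.event_shape
    n = len(prior_shape)
    # The prior's own shape must equal the trailing dims of the hyperparameter.
    if n > len(param_shape) or param_shape[len(param_shape) - n:] != prior_shape:
        raise ValueError(
            f"prior shape {tuple(prior_shape)} is incompatible with "
            f"hyperparameter shape {tuple(param_shape)}"
        )
    # The remaining leading dims of the hyperparameter become the sample_shape.
    return param_shape[: len(param_shape) - n]


def sample_like_parameter(prior: Distribution, param_shape: Sequence[int]) -> torch.Tensor:
    """Hypothetical helper: draw a sample whose shape equals param_shape."""
    return prior.sample(check_prior_shape(prior, param_shape))


scales = torch.tensor([1.0, 1.0])
vector_prior = Normal(torch.tensor(10.0), scales)               # like kernel2 / kernel3
scalar_prior = Normal(torch.tensor(10.0), torch.tensor(1.0))    # like kernel4

print(sample_like_parameter(vector_prior, (1, 2)).shape)  # torch.Size([1, 2])
print(sample_like_parameter(scalar_prior, (1, 2)).shape)  # torch.Size([1, 2])
try:
    sample_like_parameter(vector_prior, (1, 1))  # the kernel3 mismatch
except ValueError as err:
    print("rejected:", err)
```

Requiring the prior's shape to be an exact suffix of the hyperparameter shape is a deliberately strict choice for this sketch; a real implementation might also accept broadcastable size-1 dimensions, and could run the check when the kernel registers the prior rather than at sampling time.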
System information
- GPyTorch Version: 1.2.0
- PyTorch Version: 1.6.0
- Computer OS: verified on OSX and CentOS.