[Bug] Sampling from priors doesn't match shape of hyperparameters #1317

Open
@mshvartsman

Description

🐛 Bug

I found some unexpected interactions between ard_num_dims and the shapes of kernel priors: in several settings, sampling from a hyperparameter prior with the hyperparameter's own shape returns a tensor whose shape differs from the hyperparameter's. I'm not sure whether all of these are intended, but it looks like a bug to me.

To reproduce

import torch
from gpytorch.priors import GammaPrior, NormalPrior
from gpytorch.kernels import RBFKernel

# make a kernel with a 2-dim prior (one rate per ARD dimension)
scales = torch.tensor([1.0, 1.0])
kernel = RBFKernel(
    ard_num_dims=2,
    lengthscale_prior=GammaPrior(3.0, 6.0 / scales),
)
new_lengthscale = kernel.lengthscale_prior.sample(kernel.lengthscale.shape)
print(kernel.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1,2,2; assigning it back to the lengthscale raises an error

# same with another prior
kernel2 = RBFKernel(
    ard_num_dims=2,
    lengthscale_prior=NormalPrior(loc=10, scale=scales)
)

new_lengthscale = kernel2.lengthscale_prior.sample(kernel2.lengthscale.shape)
print(kernel2.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1, 2, 2

# ard_num_dims is only 1 but we have a higher-dim prior. Is this behavior defined?
kernel3 = RBFKernel(
    ard_num_dims=1,
    lengthscale_prior=NormalPrior(loc=10, scale=scales)
)

new_lengthscale = kernel3.lengthscale_prior.sample(kernel3.lengthscale.shape)
print(kernel3.lengthscale.shape) # size 1, 1 -- but shouldn't we expect 1,2? 
print(new_lengthscale.shape) # size 1, 1, 2

# ok, ard_num_dims is 2 but my prior is 1d, now it works correctly
kernel4 = RBFKernel(
    ard_num_dims=2,
    lengthscale_prior=NormalPrior(loc=10, scale=1)
)

new_lengthscale = kernel4.lengthscale_prior.sample(kernel4.lengthscale.shape)
print(kernel4.lengthscale.shape) # size 1,2
print(new_lengthscale.shape) # size 1, 2
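
For reference, the shapes above are consistent with how torch.distributions builds sample shapes: sample(sample_shape) returns sample_shape + batch_shape + event_shape, so passing the full lengthscale shape to a prior whose batch_shape is already (2,) appends the 2 a second time. Below is a minimal pure-Python sketch of that arithmetic, assuming the GPyTorch priors inherit this rule from torch.distributions. The helper names sampled_shape and sample_shape_for are hypothetical and not part of GPyTorch or PyTorch.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def sampled_shape(sample_shape: Shape, batch_shape: Shape, event_shape: Shape = ()) -> Shape:
    """Shape of dist.sample(sample_shape) under the torch.distributions rule."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


def sample_shape_for(target_shape: Shape, batch_shape: Shape) -> Shape:
    """Hypothetical helper: the sample_shape that makes a sample match target_shape.

    Raises ValueError when the trailing dims of target_shape do not equal
    batch_shape, i.e. when no sample_shape can produce the target shape.
    """
    target, batch = tuple(target_shape), tuple(batch_shape)
    split = len(target) - len(batch)
    if split < 0 or target[split:] != batch:
        raise ValueError(
            f"prior batch_shape {batch} is incompatible with parameter shape {target}"
        )
    return target[:split]


# kernel / kernel2: lengthscale is (1, 2), prior batch_shape is (2,)
print(sampled_shape((1, 2), (2,)))            # what the report does
print(sample_shape_for((1, 2), (2,)))         # leading dims only
# kernel4: scalar prior, batch_shape is ()
print(sample_shape_for((1, 2), ()))
# kernel3: lengthscale is (1, 1) but the prior batch_shape is (2,)
try:
    sample_shape_for((1, 1), (2,))
except ValueError as err:
    print("mismatch:", err)
```

Under these assumptions, a possible workaround is prior.sample(sample_shape_for(kernel.lengthscale.shape, prior.batch_shape)), and the ValueError branch is the kind of early check requested below for the kernel3 case.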

Expected Behavior

It would be nice to get a warning or error earlier for undefined or unsupported combinations of ard_num_dims and prior shape, and for the sampled shape to match the hyperparameter shape in the supported cases.

System information

  • GPyTorch Version: 1.2.0
  • PyTorch Version: 1.6.0
  • Computer OS: verified on OSX and CentOS.
