Skip to content

Variance inconsistency in HeteroskedasticSingleTaskGP #933

Closed
@mklpr

Description

Hi,
In HeteroskedasticSingleTaskGP, when I use different ways to compute the posterior with observation noise, I get different results that I can't explain, so I'm asking for help here.

I compute the posterior with noise in four ways:

  1. model_heter.posterior(scan_x, observation_noise=True)
  2. mll_heter.likelihood(model_heter.posterior(scan_x, observation_noise=False), scan_x)
  3. use model_heter.likelihood.noise_covar.noise_model.posterior(scan_x).mean as the noise variance, then add the variance from model_heter.posterior(scan_x, observation_noise=False) to get the total posterior variance
  4. use model_heter.likelihood.noise_covar.noise_model(scan_x).mean.exp() as the noise variance, then add the variance from model_heter.posterior(scan_x, observation_noise=False) to get the total posterior variance

Methods 1 and 2 give the same results, but methods 3 and 4 each differ from all the others. As far as I know, the total posterior variance equals the noise variance from the noise model plus the variance from the GP kernel, and I verified this for SingleTaskGP. So what is different in HeteroskedasticSingleTaskGP? Does it come from the log transform, and how does mll_heter.likelihood(model_heter.posterior(scan_x, observation_noise=False), scan_x) process it internally? Thanks.

test code

Refer to https://colab.research.google.com/drive/1dOUHQzl3aQ8hz6QUtwRrXlQBGqZadQgG#scrollTo=D0A4Cf0W_QkZ

import os
import torch
import matplotlib.pyplot as plt
import warnings
import numpy as np

# Plot styling. 'SimHei' is presumably chosen so CJK text renders; setting
# `axes.unicode_minus = False` is commonly paired with such fonts so that
# minus signs on the axes still display.
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 14

# NOTE(review): `device` is never used below, so every tensor stays on the
# CPU. Only `dtype` is used (for `scan_x`); the training tensors pass
# torch.double explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.double

# warnings.filterwarnings('ignore')

# Seed both NumPy (which draws the synthetic data below) and torch so the run
# is reproducible. Reordering the random draws below would change the data.
seed = 7433
torch.manual_seed(seed)
np.random.seed(seed)

# Synthetic 1-D regression data with input-dependent (heteroskedastic) noise:
# 200 inputs drawn uniformly from [0, pi] and sorted, then targets
#   y ~ Normal(mean = sin(2.5 x) * sin(1.5 x),
#              std  = 0.01 + 0.25 * (1 - sin(2.5 x))**2)
# (NumPy's `scale` argument is the standard deviation.)
W_x = np.random.uniform(0, np.pi, size=200)
W_x = np.sort(W_x)
W_y = np.random.normal(loc=(np.sin(2.5*W_x)*np.sin(1.5*W_x)),
                       scale=(0.01 + 0.25*(1-np.sin(2.5*W_x))**2),
                       size=200)

# Training tensors of shape (200, 1): one input dimension and one output.
X_train = torch.tensor(W_x.reshape(-1,1), dtype=torch.double)
y_train = torch.tensor(W_y.reshape(-1, 1), dtype=torch.double)

from botorch.models import SingleTaskGP
from gpytorch.constraints import GreaterThan
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_model

# Baseline: a homoskedastic SingleTaskGP (one learned noise level shared by
# all points), fitted by maximizing the exact marginal log likelihood.
model = SingleTaskGP(train_X=X_train, train_Y=y_train)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_model(mll)

# 500 evaluation points on [0, pi], reshaped to (500, 1, 1), i.e. 500
# independent batches of a single 1-D point. Only per-point (marginal)
# means and variances are used in the plots below.
scan_x = torch.linspace(0, np.pi, 500, dtype=dtype).reshape(-1,1,1)

with torch.no_grad():
    # Posterior of the latent function, without observation noise.
    scan_y = model.posterior(scan_x, observation_noise=False)
    plt.plot(scan_x.numpy().reshape(-1), scan_y.mean.reshape(-1))
    
    # confidence_region() gives mean -/+ 2 standard deviations.
    lower, upper = scan_y.mvn.confidence_region()
    plt.fill_between(scan_x.numpy().reshape(-1), lower.numpy().reshape(-1), upper.numpy().reshape(-1), alpha=0.2)
    
    # Predictive posterior: latent variance plus the likelihood's noise.
    scan_y_with_noise = model.posterior(scan_x, observation_noise=True)
    lower_with_noise, upper_with_noise = scan_y_with_noise.mvn.confidence_region()
    plt.fill_between(scan_x.numpy().reshape(-1), lower_with_noise.numpy().reshape(-1), upper_with_noise.numpy().reshape(-1), alpha=0.2)
    
    plt.scatter(X_train, y_train)
    
    # Labels are matched to artists positionally, in the order they were
    # added above (line, band, band, scatter).
    plt.legend(['posterior mean', 'posterior confidence', 'posterior confidence with noise', 'observed data'])

# Squared residuals of the homoskedastic fit at the training inputs, shape
# (200, 1). They are used below as per-point noise-variance observations
# (`train_Yvar`) for the heteroskedastic model.
# NOTE(review): the heteroskedastic noise model presumably works on
# log(train_Yvar) (the issue text mentions a log transform). A residual of
# exactly 0 would then give -inf, and tiny residuals give large negative
# values. Consider a floor such as `.clamp_min(1e-6)`; TODO confirm a
# sensible value.
with torch.no_grad():
    observed_var = torch.pow(model.posterior(X_train).mean - y_train, 2)

from botorch.models import HeteroskedasticSingleTaskGP

# Heteroskedastic GP. The per-point noise variances `observed_var` are
# modelled by a separate noise model, reachable below as
# `model_heter.likelihood.noise_covar.noise_model`.
model_heter = HeteroskedasticSingleTaskGP(train_X=X_train, train_Y=y_train,
                                    train_Yvar=observed_var)
mll_heter = ExactMarginalLogLikelihood(model_heter.likelihood, model_heter)
_ = fit_gpytorch_model(mll_heter)

# Switch to prediction mode. The MLL module contains the likelihood and the
# model, so eval() on it propagates to both; the second call is redundant
# but harmless.
mll_heter.eval()
model_heter.eval()
with torch.no_grad():
    plt.figure()
    # Posterior of the latent function, without observation noise.
    scan_y = model_heter.posterior(scan_x, observation_noise=False)
    plt.plot(scan_x.numpy().reshape(-1), scan_y.mean.reshape(-1))
    
    lower, upper = scan_y.mvn.confidence_region()
    plt.fill_between(scan_x.numpy().reshape(-1), lower.numpy().reshape(-1), upper.numpy().reshape(-1), alpha=0.2)
    
    # Method 1: predictive posterior computed by the model itself.
    scan_y_with_noise = model_heter.posterior(scan_x, observation_noise=True)
    lower_with_noise, upper_with_noise = scan_y_with_noise.mvn.confidence_region()
    plt.fill_between(scan_x.numpy().reshape(-1), lower_with_noise.numpy().reshape(-1), upper_with_noise.numpy().reshape(-1), alpha=0.2)

    # Method 2: pass the latent MVN and the inputs through the heteroskedastic
    # likelihood, which evaluates its noise model at scan_x. Per the issue
    # text this agrees with method 1.
    # Assuming the likelihood adds its noise on the diagonal, the noise
    # variance it actually uses can be recovered as
    # `scan_y_with_noise2.variance - scan_y.variance`.
    scan_y_with_noise2 = mll_heter.likelihood(scan_y.mvn, scan_x)
    lower_with_noise2, upper_with_noise2 = scan_y_with_noise2.confidence_region()
    plt.fill_between(scan_x.numpy().reshape(-1), lower_with_noise2.numpy().reshape(-1), upper_with_noise2.numpy().reshape(-1), alpha=0.2)

    # Method 3: latent variance + noise_model.posterior(scan_x).mean, with a
    # +/- 2 std band to mirror confidence_region().
    # NOTE(review): posterior() presumably undoes the noise model's log
    # outcome transform. If so, its mean is the log-normal mean
    # exp(mu + sigma^2 / 2), not the point value the likelihood uses, so it
    # need not match methods 1/2. TODO confirm for the installed BoTorch.
    noise_var = model_heter.likelihood.noise_covar.noise_model.posterior(scan_x).mean
    std_with_noise = (scan_y.variance.reshape(-1) + noise_var.reshape(-1)).sqrt()
    plt.fill_between(scan_x.numpy().reshape(-1), (scan_y.mean.reshape(-1) - 2 * std_with_noise.reshape(-1)).numpy(),
                     (scan_y.mean.reshape(-1) + 2 * std_with_noise.reshape(-1)).numpy(), alpha=0.2)

    # Method 4: latent variance + exp(noise_model(scan_x).mean). Calling the
    # noise model directly returns its raw (presumably log-space) prediction,
    # and exp() is applied by hand here.
    # NOTE(review): the likelihood may not map that raw value through exp.
    # GPyTorch's HeteroskedasticNoise applies its noise constraint's
    # transform, which defaults to softplus for GreaterThan. That would
    # explain why this differs from methods 1/2; TODO confirm against the
    # installed GPyTorch/BoTorch versions.
    noise_var2 = model_heter.likelihood.noise_covar.noise_model(scan_x).mean.exp()
    std_with_noise2 = (scan_y.variance.reshape(-1) + noise_var2.reshape(-1)).sqrt()
    plt.fill_between(scan_x.numpy().reshape(-1), (scan_y.mean.reshape(-1) - 2 * std_with_noise2.reshape(-1)).numpy(),
                     (scan_y.mean.reshape(-1) + 2 * std_with_noise2.reshape(-1)).numpy(), alpha=0.2)
    
    plt.scatter(X_train, y_train)
    # Labels are matched positionally to the artists added above: mean line,
    # latent band, methods 1-4 bands (noise, noise2, noise3, noise4), scatter.
    plt.legend(['posterior mean', 'posterior confidence', 'posterior confidence with noise', 'posterior confidence with noise2',
                'posterior confidence with noise3', 'posterior confidence with noise4' , 'observed data'])

image

image

system info

  • botorch==0.5.0
  • gpytorch==1.5.0
  • torch==1.9.0

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions