
[Bug] Inputs are not transformed when using GPyTorchModel.condition_on_observations #2533

Open
@mikkelbue

Description


🐛 Bug

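The inputs X are not transformed by the model's input transform when calling GPyTorchModel.condition_on_observations(X, Y), whereas the outcomes Y are passed through the outcome transform. As a result, the new inputs are appended to the fantasy model's train_inputs tensor on their original, unscaled scale, while the existing training inputs in that tensor are stored in the transformed (normalized) space. The new observations therefore end up at the wrong locations.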
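To make the scale mismatch concrete, here is a minimal pure-Python sketch. It assumes a min-max normalization learned from the training inputs, in the spirit of Normalize with learned bounds; fit_min_max and normalize are hypothetical helpers, not BoTorch APIs.

```python
# Hypothetical helpers illustrating min-max input normalization (not BoTorch APIs).


def fit_min_max(xs):
    """Learn an offset and scale so that the training inputs map onto [0, 1]."""
    lo, hi = min(xs), max(xs)
    return lo, hi - lo


def normalize(xs, offset, scale):
    """Apply the learned min-max transform to a list of 1-d inputs."""
    return [(x - offset) / scale for x in xs]


train_X = [1.0, 3.0, 6.0, 9.0]
offset, scale = fit_min_max(train_X)
train_X_tf = normalize(train_X, offset, scale)  # stored training inputs, in [0, 1]

X_new = [5.0]
# Reported behavior: the raw input is appended next to normalized inputs.
appended_raw = train_X_tf + X_new
# Expected behavior: the new input is transformed with the same offset/scale first.
appended_tf = train_X_tf + normalize(X_new, offset, scale)

print(appended_raw)  # [0.0, 0.25, 0.625, 1.0, 5.0]
print(appended_tf)   # [0.0, 0.25, 0.625, 1.0, 0.5]
```

Under these assumptions the point 5.0 lies inside the training range, yet without the transform the model sees it at 5.0, far outside the unit interval the training inputs occupy.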

To reproduce

Set up a model

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_mll
import matplotlib.pyplot as plt

_ = torch.manual_seed(0)

# Forrester function rescaled to the domain [0, 10]
def squashed_forrester(x):
    x = x / 10
    return (6*x - 2)**2 * torch.sin(12*x - 4)

X = 10*torch.rand(size=(5,1), dtype=torch.double)
y = squashed_forrester(X)

# set up the model and fit.
model = SingleTaskGP(X,
                     y,
                     input_transform=Normalize(1),
                     outcome_transform=Standardize(1))
mll = ExactMarginalLogLikelihood(model.likelihood, model)

_ = fit_gpytorch_mll(mll)

# predict using the original model
X_test = torch.linspace(0,10,50, dtype=torch.double).unsqueeze(-1)
y_post = model.posterior(X_test)
y_mean = y_post.mean.detach().squeeze()
y_stddev = y_post.stddev.detach().squeeze()

# plot the original model predictions
plt.scatter(X, y, c='tomato', zorder=10)
plt.plot(X_test.squeeze(), y_mean)
plt.fill_between(X_test.squeeze(), y_mean - 1.96*y_stddev, y_mean + 1.96*y_stddev, alpha=0.2)
plt.show()
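The fill_between calls draw mean ± 1.96 standard deviations, which is an approximate 95% pointwise interval when the posterior at each test point is Gaussian. A small helper can avoid repeating that arithmetic in each plotting snippet; gaussian_band is a hypothetical name (not part of BoTorch or Matplotlib), and this sketch works on plain sequences such as y_mean.tolist() and y_stddev.tolist().

```python
from typing import List, Sequence, Tuple


def gaussian_band(
    mean: Sequence[float], stddev: Sequence[float], z: float = 1.96
) -> Tuple[List[float], List[float]]:
    """Return (lower, upper) = mean -/+ z * stddev, elementwise.

    With z = 1.96 this is an approximate 95% pointwise interval
    for a Gaussian posterior.
    """
    if len(mean) != len(stddev):
        raise ValueError("mean and stddev must have the same length")
    lower = [m - z * s for m, s in zip(mean, stddev)]
    upper = [m + z * s for m, s in zip(mean, stddev)]
    return lower, upper


lower, upper = gaussian_band([0.0, 1.0], [1.0, 0.5])
```

For example, plt.fill_between(X_test.squeeze(), *gaussian_band(y_mean.tolist(), y_stddev.tolist()), alpha=0.2) should draw the same band as the original calls.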


Condition on some data

# draw 3 new observations.
X_cond = 10*torch.rand(size=(3,1), dtype=torch.double)
y_cond = squashed_forrester(X_cond)

# condition on the new data
new_model = model.condition_on_observations(X_cond, y_cond)

# predict using the new model.
X_test = torch.linspace(0,10,50, dtype=torch.double).unsqueeze(-1)
y_post = new_model.posterior(X_test)
y_mean = y_post.mean.detach().squeeze()
y_stddev = y_post.stddev.detach().squeeze()

# plot the predictions of the new model.
plt.scatter(X, y, c='tomato', zorder=10)
plt.scatter(X_cond, y_cond, c='goldenrod', zorder=10)
plt.plot(X_test.squeeze(), y_mean)
plt.fill_between(X_test.squeeze(), y_mean - 1.96*y_stddev, y_mean + 1.96*y_stddev, alpha=0.2)
plt.show()


Condition on some data, but transform the inputs first

# condition the model, but transform the inputs first
new_model = model.condition_on_observations(model.transform_inputs(X_cond), y_cond)

# predict using the new model.
X_test = torch.linspace(0,10,50, dtype=torch.double).unsqueeze(-1)
y_post = new_model.posterior(X_test)
y_mean = y_post.mean.detach().squeeze()
y_stddev = y_post.stddev.detach().squeeze()

# plot the predictions of the new model.
plt.scatter(X, y, c='tomato', zorder=10)
plt.scatter(X_cond, y_cond, c='goldenrod', zorder=10)
plt.plot(X_test.squeeze(), y_mean)
plt.fill_between(X_test.squeeze(), y_mean - 1.96*y_stddev, y_mean + 1.96*y_stddev, alpha=0.2)
plt.show()


Expected Behavior

I would expect condition_on_observations to transform the new inputs with self.transform_inputs before adding them to train_inputs, just as the inputs are transformed when computing the posterior, and just as the outcomes Y are already passed through the outcome transform.
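Until that happens, the workaround from the third snippet can be wrapped in a small helper. This is a sketch, not a BoTorch API: condition_on_transformed_inputs is a hypothetical name, and it relies only on methods already used above, model.transform_inputs(X) and model.condition_on_observations(X, Y), plus the standard eval() of a torch module. It assumes that putting the model in eval mode makes a learned transform such as Normalize apply its stored bounds rather than refitting them on the new X.

```python
from typing import Any


def condition_on_transformed_inputs(model: Any, X: Any, Y: Any, **kwargs: Any) -> Any:
    """Hypothetical workaround: transform X before conditioning.

    Assumes `model` exposes BoTorch-style `eval()`, `transform_inputs(X)` and
    `condition_on_observations(X, Y, **kwargs)`. Y is passed through untouched,
    because condition_on_observations already applies the outcome transform.
    """
    # Eval mode: a learned input transform (e.g. Normalize) should use its
    # stored bounds instead of being refit on the new inputs.
    model.eval()
    X_tf = model.transform_inputs(X)
    return model.condition_on_observations(X_tf, Y, **kwargs)
```

With the model from the reproduction above, usage would be new_model = condition_on_transformed_inputs(model, X_cond, y_cond).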

System information


  • BoTorch Version: 0.11.3
  • GPyTorch Version: 1.12
  • PyTorch Version: 2.4.2+cu121
  • OS: Ubuntu 22.04.4
