🐛 Bug
The inputs X are not transformed by the model's input transform when running GPyTorchModel.condition_on_observations(X, y). The outcomes Y, however, are transformed. As a result, the new inputs are appended to the fantasy model's train_inputs tensor in the raw input space, while the existing training inputs are stored in the normalized space.
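The scale mismatch can be shown without BoTorch. The minimal pure-PyTorch sketch below uses made-up data. `normalize` and `standardize` are hypothetical, simplified stand-ins for the min-max scaling learned by Normalize(1) and the mean/std scaling learned by Standardize(1); they are not the library's implementations.

```python
import torch

# Made-up training data standing in for the model's original X and y.
X_train = torch.tensor([[1.0], [4.0], [9.0]], dtype=torch.double)
Y_train = torch.tensor([[2.0], [0.0], [7.0]], dtype=torch.double)

# Simplified stand-ins for what Normalize(1) and Standardize(1) learn.
x_min = X_train.min(dim=0).values
x_range = X_train.max(dim=0).values - x_min
y_mean, y_std = Y_train.mean(dim=0), Y_train.std(dim=0)

def normalize(X):
    # Hypothetical min-max scaling into [0, 1] using the training bounds.
    return (X - x_min) / x_range

def standardize(Y):
    # Hypothetical zero-mean, unit-std scaling using the training statistics.
    return (Y - y_mean) / y_std

# The fitted model keeps its training data in the transformed space.
train_inputs = normalize(X_train)     # values in [0, 1]
train_targets = standardize(Y_train)  # zero mean, unit std

X_new = torch.tensor([[5.0]], dtype=torch.double)
Y_new = torch.tensor([[3.0]], dtype=torch.double)

# Reported behaviour: Y_new is standardized, but X_new is appended raw.
reported_inputs = torch.cat([train_inputs, X_new])
# Expected behaviour: X_new is mapped into the same space as train_inputs.
expected_inputs = torch.cat([train_inputs, normalize(X_new)])
new_targets = torch.cat([train_targets, standardize(Y_new)])
```

With these numbers the appended input is 5.0 in the reported case but 0.5 in the expected case, so the GP treats the new point as lying far outside the normalized training domain.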
To reproduce
Set up a model
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_mll
import matplotlib.pyplot as plt
_ = torch.manual_seed(0)
# Forrester function rescaled to the domain [0, 10]
def squashed_forrester(x):
    x = x / 10
    return (6*x - 2)**2 * torch.sin(12*x - 4)
X = 10*torch.rand(size=(5,1), dtype=torch.double)
y = squashed_forrester(X)
# set up the model and fit.
model = SingleTaskGP(X,
                     y,
                     input_transform=Normalize(1),
                     outcome_transform=Standardize(1))
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_mll(mll)
# predict using the original model
X_test = torch.linspace(0,10,50).unsqueeze(-1)
y_post = model.posterior(X_test)
y_mean = y_post.mean.detach().squeeze()
y_stddev = y_post.stddev.detach().squeeze()
# plot the original model predictions
plt.scatter(X, y, c='tomato', zorder=10)
plt.plot(X_test.squeeze(), y_mean)
plt.fill_between(X_test.squeeze(), y_mean - 1.96*y_stddev, y_mean + 1.96*y_stddev, alpha=0.2)
plt.show()
Condition on some data
# draw 3 new data points.
X_cond = 10*torch.rand(size=(3,1), dtype=torch.double)
y_cond = squashed_forrester(X_cond)
# condition on the new data
new_model = model.condition_on_observations(X_cond, y_cond)
# predict using the new model.
X_test = torch.linspace(0,10,50).unsqueeze(-1)
y_post = new_model.posterior(X_test)
y_mean = y_post.mean.detach().squeeze()
y_stddev = y_post.stddev.detach().squeeze()
# plot the predictions of the new model.
plt.scatter(X, y, c='tomato', zorder=10)
plt.scatter(X_cond, y_cond, c='goldenrod', zorder=10)
plt.plot(X_test.squeeze(), y_mean)
plt.fill_between(X_test.squeeze(), y_mean - 1.96*y_stddev, y_mean + 1.96*y_stddev, alpha=0.2)
plt.show()
Condition on some data, but transform the inputs first
# condition the model, but transform the inputs first
new_model = model.condition_on_observations(model.transform_inputs(X_cond), y_cond)
# predict using the new model.
X_test = torch.linspace(0,10,50).unsqueeze(-1)
y_post = new_model.posterior(X_test)
y_mean = y_post.mean.detach().squeeze()
y_stddev = y_post.stddev.detach().squeeze()
# plot the predictions of the new model.
plt.scatter(X, y, c='tomato', zorder=10)
plt.scatter(X_cond, y_cond, c='goldenrod', zorder=10)
plt.plot(X_test.squeeze(), y_mean)
plt.fill_between(X_test.squeeze(), y_mean - 1.96*y_stddev, y_mean + 1.96*y_stddev, alpha=0.2)
plt.show()
Expected Behavior
I would expect the conditioned inputs to be transformed with self.transform_inputs before they are added to train_inputs, as already happens when computing the posterior.
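Until the inputs are transformed internally, the workaround from the last reproduction step can be wrapped in a small helper. This is a sketch: `condition_on_transformed_inputs` is a hypothetical name, not part of BoTorch, and it only assumes that the model exposes `transform_inputs` and `condition_on_observations`, as the model in the reproduction does.

```python
from typing import Any

def condition_on_transformed_inputs(model: Any, X: Any, Y: Any, **kwargs: Any) -> Any:
    """Hypothetical helper: map X into the model's input space, then condition.

    Y is passed through unchanged because, per this report,
    condition_on_observations already applies the outcome transform.
    """
    X_transformed = model.transform_inputs(X)
    return model.condition_on_observations(X_transformed, Y, **kwargs)
```

With the model from the reproduction, `new_model = condition_on_transformed_inputs(model, X_cond, y_cond)` does the same thing as the third snippet. If a later BoTorch version starts transforming X inside condition_on_observations, this helper would apply the input transform twice, so it is only meant as a stopgap.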
System information
- BoTorch Version: 0.11.3
- GPyTorch Version: 1.12
- PyTorch Version: 2.4.2+cu121
- OS: Ubuntu 22.04.4