Description
🐛 Bug
BoTorch models do not properly support `model.load_state_dict` when using `input_transform`s with trained parameters. After performing `model.load_state_dict`, the model continues using cached transformed inputs that were computed with the previous parameters.

GPyTorch doesn't have this bug in its caching; it intentionally clears all caches whenever the state dict is loaded.

Workaround: call `model.train()` before calling `model.load_state_dict()`.
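The workaround can be wrapped in a small helper so callers don't have to remember the mode dance. This is a sketch, not a BoTorch API — `load_state_dict_fresh` is a hypothetical name, and the key assumption (per this report) is that entering train mode discards the cached transformed inputs so the reloaded parameters actually take effect. It works on any `torch.nn.Module`, shown here with a plain `Linear` for brevity:

```python
import copy

import torch


def load_state_dict_fresh(model: torch.nn.Module, state_dict) -> torch.nn.Module:
    # Hypothetical helper: switch to train mode first so any cached
    # transformed inputs are invalidated, then load and return to eval.
    model.train()
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Usage sketch on a plain module; with a BoTorch SingleTaskGP this is the
# same pattern as the workaround above.
m = torch.nn.Linear(2, 1)
saved = copy.deepcopy(m.state_dict())
with torch.no_grad():
    m.weight.add_(1.0)  # simulate training moving the parameters
load_state_dict_fresh(m, saved)
assert torch.equal(m.weight, saved["weight"])
```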
To reproduce
```python
import copy

import botorch
import gpytorch
import torch

train_X = torch.tensor([[0., 0.],
                        [0.1, 0.1],
                        [0.5, 0.5]])
train_Y = torch.tensor([[0.],
                        [0.5],
                        [1.0]])
test_X = torch.tensor([[0.3, 0.3]])

model = botorch.models.SingleTaskGP(
    train_X, train_Y,
    # This is one example input transform that stores trained parameters in the
    # state dict
    input_transform=botorch.models.transforms.Warp(indices=[0, 1]))
state_dict = copy.deepcopy(model.state_dict())

# Check initial behavior
model.eval()
print(f"Before: {model(test_X).mean.item()}")

# Train the model, adjusting the Warp parameters and caching transformed inputs
botorch.fit_gpytorch_mll(
    gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
)

# Workaround: uncomment the following line
# model.train()

# Revert to the original parameters
model.load_state_dict(state_dict)

# Verify that the output matches the original output
model.eval()
print(f"After: {model(test_X).mean.item()}")
```
Actual output
```
Before: 0.2642167806625366
After: 0.21983212232589722
```
Expected output
```
Before: 0.2642167806625366
After: 0.2642167806625366
```
System information
- BoTorch 0.7.2
- GPyTorch 1.9.0
- PyTorch 1.13.0
- macOS 13.0