[Bug] load_state_dict doesn't invalidate cached transformed inputs #1471

Open
@mrcslws

🐛 Bug

BoTorch models do not properly support model.load_state_dict() when using an input_transform with trained parameters. After model.load_state_dict() is called, the model continues to use cached transformed inputs that were computed with the previous parameters.

GPyTorch's own caching does not have this bug: it intentionally clears all caches whenever a state dict is loaded.

Workaround: call model.train() before calling model.load_state_dict().
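The workaround can be wrapped in a small helper so that call sites do not forget the extra step. This is only a sketch: `load_state_dict_with_cache_reset` is a hypothetical name, not a BoTorch or GPyTorch function, and it assumes nothing beyond the standard `torch.nn.Module` interface (`training`, `train()`, `eval()`, `load_state_dict()`).

```python
def load_state_dict_with_cache_reset(model, state_dict):
    """Load ``state_dict`` into ``model`` after first switching to train mode.

    Hypothetical helper (not part of BoTorch). It relies only on the standard
    torch.nn.Module interface: ``training``, ``train()``, ``eval()``, and
    ``load_state_dict()``.
    """
    # Remember the current mode so it can be restored afterwards.
    was_training = model.training

    # Workaround from this issue: enter train mode before loading.
    model.train()
    result = model.load_state_dict(state_dict)

    if not was_training:
        # Return to eval mode only after the new parameters are in place.
        model.eval()
    return result
```

In the reproduction script below, this would stand in for the plain `model.load_state_dict(state_dict)` call. Restoring eval mode after loading mirrors the order used in that script, where `model.eval()` is called after the state dict has been loaded.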

To reproduce

import copy

import botorch
import gpytorch
import torch


train_X = torch.tensor([[0., 0.],
                        [0.1, 0.1],
                        [0.5, 0.5]])
train_Y = torch.tensor([[0.],
                        [0.5],
                        [1.0]])
test_X = torch.tensor([[0.3, 0.3]])

model = botorch.models.SingleTaskGP(
    train_X, train_Y,
    # This is one example input transform that stores trained parameters in the
    # state dict
    input_transform=botorch.models.transforms.Warp(indices=[0, 1]))

state_dict = copy.deepcopy(model.state_dict())

# Check initial behavior
model.eval()
print(f"Before: {model(test_X).mean.item()}")

# Train model, adjusting the Warp parameters and caching transformed inputs
botorch.fit_gpytorch_mll(
    gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
)

# Workaround: uncomment the following line
# model.train()

# Revert to original parameters
model.load_state_dict(state_dict)

# Verify that output matches original output
model.eval()
print(f"After: {model(test_X).mean.item()}")

Actual output

Before: 0.2642167806625366
After: 0.21983212232589722

Expected output

Before: 0.2642167806625366
After: 0.2642167806625366

System information

BoTorch 0.7.2
GPyTorch 1.9.0
PyTorch 1.13.0
macOS 13.0

Labels: WIP (Work in Progress), bug (Something isn't working)
