[Bug]: Inconsistent results from LinearMCObjective
#2681
Open
Description
What happened?
I'm currently playing around with mechanisms for function minimization as suggested here and noticed some strange behavior. Possibly I'm overlooking a crucial aspect, or there is simply a flaw in my approach, but what I observe is that the acquisition values obtained with a LinearMCObjective do not match between maximizing a GP and minimizing an equivalent GP trained on the negated target values.
Please provide a minimal, reproducible example of the unexpected behavior.
Here is a minimal example inspired by the code from the landing page. With qProbabilityOfImprovement, the differences are even more striking (a variant sketch follows the example outputs below).
import random
import numpy as np
import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.acquisition.objective import LinearMCObjective
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood
# Toy training data: 5 random points in [0, 2]^2 with noisy distance-based targets.
train_X = torch.rand(5, 2, dtype=torch.double) * 2
Y = 1 - torch.linalg.norm(train_X - 0.5, dim=-1, keepdim=True)
Y = Y + 0.1 * torch.randn_like(Y)


def fix_seed(seed=0):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def run(minimize: bool):
    if minimize:
        # Fit on the negated targets and flip the sign back via the objective.
        train_Y = -Y
        obj = LinearMCObjective(weights=torch.tensor([-1.0]))
        best_f = -Y.min()  # == (-Y).max(), the best observed negated target
    else:
        train_Y = Y
        obj = LinearMCObjective(weights=torch.tensor([1.0]))
        best_f = Y.max()
    gp = SingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        input_transform=Normalize(d=2),
        outcome_transform=Standardize(m=1),
    )
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fix_seed(0)
    fit_gpytorch_mll(mll)
    posterior = gp(train_X)
    acqf = qLogExpectedImprovement(model=gp, best_f=best_f, objective=obj)
    fix_seed(0)
    acqf_value = acqf(train_X.unsqueeze(1))
    return posterior, acqf_value
posterior_max, acqf_max = run(False)
posterior_min, acqf_min = run(True)
# The predictions are equivalent
assert torch.equal(posterior_max.mean, -posterior_min.mean)
assert torch.equal(posterior_max.stddev, posterior_min.stddev)
# But the acquisition values do not match
print(acqf_max.tolist())
print(acqf_min.tolist())
Sometimes, the numbers "sort of" match up, like:
[-41.65546159698638, -44.19458355943439, -5.794877518131068, -3.9572762719339836, -43.939276772549725]
[-41.54524809215892, -44.165510927608004, -5.034962168263423, -3.5053637720690194, -43.90618926555327]
But oftentimes, there are very significant differences:
[-41.253846088273036, -42.4522252849054, -43.566397044168006, -4.586150277498735, -43.92969884779241]
[-5.776667325514115, -41.053594326714006, -42.89699038185324, -1.3871302702025305, -43.38882095231844]
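For reference, here is how the qProbabilityOfImprovement variant mentioned above could look. This is just a sketch assuming the same data and fix_seed() helper: only the acquisition function line inside run() changes, with gp, best_f, and obj constructed exactly as before.
from botorch.acquisition import qProbabilityOfImprovement

# Hypothetical drop-in replacement for the qLogExpectedImprovement line in run();
# gp, best_f, and obj come from the surrounding run() body above.
acqf = qProbabilityOfImprovement(model=gp, best_f=best_f, objective=obj)
acqf_value = acqf(train_X.unsqueeze(1))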
Is something wrong with my logic?
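One possible confounder I have not ruled out (this is an assumption on my part, not something I have verified): the global seeding in fix_seed() may not pin the MC base samples that the acquisition function draws internally. Here is a sketch of passing an explicitly seeded QMC sampler to both runs, which should at least make the two Monte Carlo estimates use the same base samples:
from botorch.sampling.normal import SobolQMCNormalSampler

# Hypothetical modification to run(): give both the "maximize" and "minimize"
# acquisition functions a sampler with a fixed seed so they share base samples.
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]), seed=0)
acqf = qLogExpectedImprovement(
    model=gp, best_f=best_f, objective=obj, sampler=sampler
)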
Please paste any relevant traceback/logs produced by the example provided.
BoTorch Version
0.12.0
Python Version
3.10.14
Operating System
macOS
Code of Conduct
- I agree to follow BoTorch's Code of Conduct