Issue description
I'm running `sample_optimal_points` with a model that has input and outcome transforms, and I get this warning:
RuntimeWarning: Could not update `train_inputs` with transformed inputs since GenericDeterministicModel does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.
Reading through the source code, it seems that the model is wrapped in a `GenericDeterministicModel`, which does not support transforms. What is the best way to account for transforms when running `sample_optimal_points`?
For context, I'm sampling optimal points to run q-multi-objective PES, MES, and JES. Thanks in advance for your help!
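To make "accounting for transforms" concrete: with `Normalize` and `Standardize`, the GP is fit in a transformed space, so a sample path drawn from it has to be composed with the input transform and the inverse outcome transform before it is evaluated on points from the raw `bounds`. The sketch below shows that composition in plain Python. `make_raw_space_fn`, `toy_sample` and their arguments are hypothetical names (not BoTorch API), and the per-coordinate min-max and mean/std arithmetic only mirrors, in simplified form, what `Normalize` (with bounds learned from the training inputs) and `Standardize` do.

```python
from typing import Callable, List, Sequence


def make_raw_space_fn(
    f_model_space: Callable[[List[float]], List[float]],
    x_min: Sequence[float],
    x_max: Sequence[float],
    y_mean: Sequence[float],
    y_std: Sequence[float],
) -> Callable[[Sequence[float]], List[float]]:
    """Hypothetical helper: return g(x_raw) = unstandardize(f(normalize(x_raw)))."""

    def g(x_raw: Sequence[float]) -> List[float]:
        # Input side: min-max scale each raw coordinate to [0, 1]
        # (roughly what Normalize does with bounds learned from train_X).
        x_unit = [(x - lo) / (hi - lo) for x, lo, hi in zip(x_raw, x_min, x_max)]
        # Evaluate the sample path in the model's transformed space.
        y_model = f_model_space(x_unit)
        # Output side: undo the per-outcome standardization (y * std + mean).
        return [y * s + m for y, s, m in zip(y_model, y_std, y_mean)]

    return g


# Toy two-outcome "sample path" defined on the unit square.
def toy_sample(u: List[float]) -> List[float]:
    return [u[0] + u[1], u[0] - u[1]]


g = make_raw_space_fn(
    toy_sample,
    x_min=[-2000.0, -2000.0],
    x_max=[-1980.0, -1980.0],
    y_mean=[10.0, 5.0],
    y_std=[2.0, 0.5],
)
print(g([-1990.0, -2000.0]))  # [11.0, 5.25]
```

If either side of this composition is skipped, the optimizer ends up comparing sample-path values at inputs the model never saw in its own coordinate system, which is the mismatch the warning alludes to.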
Code example
Below, I simulate inputs x and/or outputs y with extreme values and optimize PES repeatedly. As expected, `sample_optimal_points` eventually fails with `RuntimeError: Only found 1 optimal points instead of 20.`
```python
import torch
from botorch.test_functions.multi_objective import BraninCurrin
from botorch.acquisition.multi_objective.utils import (
    sample_optimal_points,
    random_search_optimizer,
)
from botorch.utils.sampling import draw_sobol_samples
from botorch.models.transforms import Standardize, Normalize
from botorch.models.gp_regression import SingleTaskGP
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_mll
from botorch.acquisition.multi_objective.predictive_entropy_search import qMultiObjectivePredictiveEntropySearch
from botorch.optim.optimize import optimize_acqf


def generate_data(n, seed):
    tkwargs = {"dtype": torch.double, "device": "cpu"}
    problem = BraninCurrin(negate=True)
    bounds = problem.bounds.to(**tkwargs)
    # print(f"problem bounds: {bounds}")
    # print(f"ref point: {problem.ref_point}")
    train_X = draw_sobol_samples(bounds=bounds, n=n, q=1, seed=seed).squeeze(-2)
    train_Y = problem(train_X)
    # Rescale to extreme values
    train_X = (train_X - 100.0) * 20.0
    # train_Y = (train_Y - 100.0) * 20.0
    return train_X, train_Y, bounds


def fit_model(train_X, train_Y):
    d = train_X.shape[-1]
    M = train_Y.shape[-1]
    model = SingleTaskGP(
        train_X, train_Y,
        input_transform=Normalize(d=d),
        outcome_transform=Standardize(m=M))
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)
    return model


if __name__ == "__main__":
    import botorch
    import gpytorch

    print(botorch.__version__)
    print(gpytorch.__version__)
    print(torch.__version__)

    d = 2
    M = 2
    n = 6
    init_X, init_Y, bounds = generate_data(n, seed=0)
    model = fit_model(init_X, init_Y)

    for trial_i in range(100):
        print(trial_i)
        ps, _ = sample_optimal_points(
            model=model,
            bounds=bounds,
            num_samples=20,
            num_points=20,
            optimizer=random_search_optimizer,
            optimizer_kwargs={'pop_size': 2000, 'max_tries': 10},
        )
        pes = qMultiObjectivePredictiveEntropySearch(model=model, pareto_sets=ps)
        new_x, _ = optimize_acqf(
            acq_function=pes,
            bounds=bounds,
            q=1,
            num_restarts=10,
            raw_samples=512,
            sequential=True,
        )
        print(new_x)
```
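The `RuntimeError` reports that only one non-dominated (Pareto-optimal) point was found among the candidates. The sketch below shows, in plain Python and assuming maximization, one generic way that can happen: when the sampled objectives show no trade-off, a single point is at least as good as every other point in all objectives, so Pareto filtering keeps just that one. `dominates` and `pareto_indices` are hypothetical helpers; this is an illustration of Pareto filtering, not the implementation of `random_search_optimizer`.

```python
from typing import List, Sequence


def dominates(a: Sequence[float], b: Sequence[float]) -> bool:
    """True if a is >= b in every objective and > b in at least one (maximization)."""
    return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))


def pareto_indices(Y: Sequence[Sequence[float]]) -> List[int]:
    """Indices of the non-dominated rows of Y (simple O(n^2) scan)."""
    return [
        i
        for i, yi in enumerate(Y)
        if not any(dominates(yj, yi) for j, yj in enumerate(Y) if j != i)
    ]


# Objective values with a genuine trade-off: four of five points survive.
spread = [[0.0, 3.0], [1.0, 2.0], [2.0, 1.0], [3.0, 0.0], [0.5, 0.5]]
# Nearly degenerate objective values: the first point beats all the others.
degenerate = [[1.0, 1.0], [0.9, 0.9], [0.8, 0.95], [0.99, 0.5]]

print(pareto_indices(spread))      # [0, 1, 2, 3]
print(pareto_indices(degenerate))  # [0]
```

Increasing the candidate population cannot help in the degenerate case, because every additional candidate is dominated by the same point; the set of survivors only grows when the objectives actually conflict over the searched region.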
System Info
Please provide information about your setup, including
- BoTorch Version (run `print(botorch.__version__)`): 0.9.5
- GPyTorch Version (run `print(gpytorch.__version__)`): 1.11
- PyTorch Version (run `print(torch.__version__)`): 2.0.0
- Computer OS: linux