input_transform Normalize does not seem to work properly with condition_on_observations #1435
Open
Description
🐛 Bug
After running model.condition_on_observations(new_x, new_y) on a model that was instantiated with Normalize(d) as its input_transform, retraining the conditioned model fails. I believe this is a bug, but I'm honestly not sure.
To reproduce
Step 1: initialize dummy data
import botorch
import gpytorch
import numpy as np
import torch
from botorch.models.transforms.input import Normalize
np.random.seed(123)
torch.manual_seed(123)
# use regular spaced points on the interval [0, 1]
train_x = torch.linspace(0, 1, 15)
# training data needs to be explicitly multi-dimensional
train_x = train_x.unsqueeze(1)
# sample observed values and add some synthetic noise
train_y = torch.sin(train_x * (2 * np.pi)) + 0.15 * torch.randn_like(train_x)
Step 2: initialization/training, works just fine
model = botorch.models.SingleTaskGP(
train_X=train_x, train_Y=train_y, input_transform=Normalize(1, transform_on_eval=True)
)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)
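As a sanity check (not part of the original repro), posterior evaluation on the fitted model behaves as expected at this point:
# sanity check only: query the fitted model before conditioning
test_x = torch.linspace(0, 1, 5).unsqueeze(1)
with torch.no_grad():
    posterior = model.posterior(test_x)
print(posterior.mean.shape)  # torch.Size([5, 1])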
Step 3: condition
new_x = torch.FloatTensor(np.array([1.25, 1.5]).reshape(-1, 1))
new_y = torch.FloatTensor(np.array([-1.0, -2.0]).reshape(-1, 1))
model = model.condition_on_observations(new_x, new_y)
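A quick look at the conditioned model (diagnostic only) seems consistent with the size mismatch in the error below, where a 17-element target is compared against a 15-point model output:
# diagnostic only: compare the conditioned model's stored targets and inputs
print(model.train_targets.shape)    # 15 original + 2 conditioned points -> 17
print(model.train_inputs[0].shape)  # the error below suggests this stays at 15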
Step 4: attempt retraining to further tune hyperparameters (length scales, etc.)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll) # fails
Stack trace/error message
MDNotImplementedError Traceback (most recent call last)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/utils/dispatcher.py:88, in Dispatcher.__call__(self, *args, **kwargs)
87 try:
---> 88 return func(*args, **kwargs)
89 except MDNotImplementedError:
90 # Traverses registered methods in order, yields whenever a match is found
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:320, in _fit_multioutput_independent(mll, _, __, sequential, **kwargs)
315 if ( # incompatible models
316 not sequential
317 or mll.model.num_outputs == 1
318 or mll.likelihood is not getattr(mll.model, "likelihood", None)
319 ):
--> 320 raise MDNotImplementedError # defer to generic
322 # TODO: Unpacking of OutcomeTransforms not yet supported. Targets are often
323 # pre-transformed in __init__, so try fitting with outcome_transform hidden
MDNotImplementedError:
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Input In [20], in <cell line: 2>()
1 mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
----> 2 botorch.fit.fit_gpytorch_mll(mll)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:114, in fit_gpytorch_mll(mll, optimizer, optimizer_kwargs, **kwargs)
111 if optimizer is not None: # defer to per-method defaults
112 kwargs["optimizer"] = optimizer
--> 114 return dispatcher(
115 mll,
116 type(mll.likelihood),
117 type(mll.model),
118 optimizer_kwargs=optimizer_kwargs,
119 **kwargs,
120 )
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/utils/dispatcher.py:95, in Dispatcher.__call__(self, *args, **kwargs)
93 for func in funcs:
94 try:
---> 95 return func(*args, **kwargs)
96 except MDNotImplementedError:
97 pass
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:240, in _fit_fallback(mll, _, __, optimizer, optimizer_kwargs, max_attempts, warning_filter, caught_exception_types, **ignore)
238 with catch_warnings(record=True) as warning_list, debug(True):
239 simplefilter("always", category=OptimizationWarning)
--> 240 mll, _ = optimizer(mll, **optimizer_kwargs)
242 # Resolve warning messages and determine whether or not to retry
243 done = True
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/fit.py:142, in fit_gpytorch_scipy(mll, bounds, method, options, track_iterations, approx_mll, scipy_objective, module_to_array_func, module_from_array_func)
140 cb = store_iteration if track_iterations else None
141 with gpt_settings.fast_computations(log_prob=approx_mll):
--> 142 res = minimize(
143 scipy_objective,
144 x0,
145 args=(mll, property_dict),
146 bounds=bounds,
147 method=method,
148 jac=True,
149 options=options,
150 callback=cb,
151 )
152 iterations = []
153 if track_iterations:
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_minimize.py:692, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
689 res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
690 **options)
691 elif meth == 'l-bfgs-b':
--> 692 res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
693 callback=callback, **options)
694 elif meth == 'tnc':
695 res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
696 **options)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_lbfgsb_py.py:308, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
305 else:
306 iprint = disp
--> 308 sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
309 bounds=new_bounds,
310 finite_diff_rel_step=finite_diff_rel_step)
312 func_and_grad = sf.fun_and_grad
314 fortran_int = _lbfgsb.types.intvar.dtype
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:263, in _prepare_scalar_function(fun, x0, jac, args, bounds, epsilon, finite_diff_rel_step, hess)
259 bounds = (-np.inf, np.inf)
261 # ScalarFunction caches. Reuse of fun(x) during grad
262 # calculation reduces overall function evaluations.
--> 263 sf = ScalarFunction(fun, x0, args, grad, hess,
264 finite_diff_rel_step, bounds, epsilon=epsilon)
266 return sf
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:158, in ScalarFunction.__init__(self, fun, x0, args, grad, hess, finite_diff_rel_step, finite_diff_bounds, epsilon)
155 self.f = fun_wrapped(self.x)
157 self._update_fun_impl = update_fun
--> 158 self._update_fun()
160 # Gradient evaluation
161 if callable(grad):
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
249 def _update_fun(self):
250 if not self.f_updated:
--> 251 self._update_fun_impl()
252 self.f_updated = True
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
154 def update_fun():
--> 155 self.f = fun_wrapped(self.x)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
133 self.nfev += 1
134 # Send a copy because the user may overwrite it.
135 # Overwriting results in undefined behaviour because
136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
138 # Make sure the function returns a true scalar
139 if not np.isscalar(fx):
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:76, in MemoizeJac.__call__(self, x, *args)
74 def __call__(self, x, *args):
75 """ returns the the function value """
---> 76 self._compute_if_needed(x, *args)
77 return self._value
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:70, in MemoizeJac._compute_if_needed(self, x, *args)
68 if not np.all(x == self.x) or self._value is None or self.jac is None:
69 self.x = np.asarray(x).copy()
---> 70 fg = self.fun(x, *args)
71 self.jac = fg[1]
72 self._value = fg[0]
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:227, in _scipy_objective_and_grad(x, mll, property_dict)
225 loss = -mll(*args).sum()
226 except RuntimeError as e:
--> 227 return _handle_numerical_errors(error=e, x=x)
228 loss.backward()
230 i = 0
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:256, in _handle_numerical_errors(error, x)
250 if (
251 isinstance(error, NanError)
252 or "singular" in error_message # old pytorch message
253 or "input is not positive-definite" in error_message # since pytorch #63864
254 ):
255 return float("nan"), np.full_like(x, "nan")
--> 256 raise error
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:225, in _scipy_objective_and_grad(x, mll, property_dict)
223 output = mll.model(*train_inputs)
224 args = [output, train_targets] + _get_extra_mll_args(mll)
--> 225 loss = -mll(*args).sum()
226 except RuntimeError as e:
227 return _handle_numerical_errors(error=e, x=x)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/module.py:30, in Module.__call__(self, *inputs, **kwargs)
29 def __call__(self, *inputs, **kwargs):
---> 30 outputs = self.forward(*inputs, **kwargs)
31 if isinstance(outputs, list):
32 return [_validate_module_outputs(output) for output in outputs]
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
62 # Get the log prob of the marginal distribution
63 output = self.likelihood(function_dist, *params)
---> 64 res = output.log_prob(target)
65 res = self._add_other_terms(res, params)
67 # Scale by the amount of data we have
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:147, in MultivariateNormal.log_prob(self, value)
145 def log_prob(self, value):
146 if settings.fast_computations.log_prob.off():
--> 147 return super().log_prob(value)
149 if self._validate_args:
150 self._validate_sample(value)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/torch/distributions/multivariate_normal.py:211, in MultivariateNormal.log_prob(self, value)
209 if self._validate_args:
210 self._validate_sample(value)
--> 211 diff = value - self.loc
212 M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
213 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
RuntimeError: The size of tensor a (17) must match the size of tensor b (15) at non-singleton dimension 0
Expected Behavior
I think the second training run is supposed to work, right? It would seem sensible for Normalize to be updated with the new training data passed during conditioning.
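For what it's worth, simply rebuilding the model from the concatenated data (instead of conditioning) and refitting works, so that's the workaround I'm using for now:
# workaround: re-instantiate from the full data set rather than conditioning
full_x = torch.cat([train_x, new_x], dim=0)
full_y = torch.cat([train_y, new_y], dim=0)
model = botorch.models.SingleTaskGP(
    train_X=full_x, train_Y=full_y, input_transform=Normalize(1, transform_on_eval=True)
)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)  # trains without error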
System information
BoTorch version: 0.7.2
GPyTorch version: 1.9.0
Torch version: 1.12.0
OS: macOS 12.5.1 (Apple M1 Max)
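Additional context
If it helps with triage: running the exact same steps with the input_transform omitted is how I'd check whether Normalize is really the culprit (sketch below, identical to the repro except for the transform):
# same repro without an input transform, to isolate Normalize
model = botorch.models.SingleTaskGP(train_X=train_x, train_Y=train_y)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)
model = model.condition_on_observations(new_x, new_y)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)  # refit after conditioning, no Normalize involved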