
input_transform Normalize does not seem to work properly with condition_on_observations #1435

@matthewcarbone

Description

🐛 Bug

After running model.condition_on_observations(new_x, new_y) on a model that was instantiated with input_transform=Normalize(d), retraining the conditioned model fails. I believe this is a bug, but I'm honestly not sure.

To reproduce

Step 1: initialize dummy data

import botorch
import gpytorch
import numpy as np
import torch
from botorch.models.transforms.input import Normalize

np.random.seed(123)
torch.manual_seed(123)

# use regular spaced points on the interval [0, 1]
train_x = torch.linspace(0, 1, 15)

# training data needs to be explicitly multi-dimensional
train_x = train_x.unsqueeze(1)

# sample observed values and add some synthetic noise
train_y = torch.sin(train_x * (2 * np.pi)) + 0.15 * torch.randn_like(train_x)

Step 2: initialization/training, works just fine

model = botorch.models.SingleTaskGP(
    train_X=train_x, train_Y=train_y, input_transform=Normalize(1, transform_on_eval=True)
)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)

Step 3: condition

new_x = torch.FloatTensor(np.array([1.25, 1.5]).reshape(-1, 1))
new_y = torch.FloatTensor(np.array([-1.0, -2.0]).reshape(-1, 1))
model = model.condition_on_observations(new_x, new_y)
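
For reference, the conditioned model can be inspected at this point (a diagnostic sketch only, assuming the conditioned model still exposes the standard GPyTorch attributes train_inputs/train_targets and the input_transform passed at construction):

# Diagnostic only (not part of the repro): check what the conditioned model
# is carrying around before attempting to refit it.
print(model.train_inputs[0].shape)   # stored training inputs after conditioning
print(model.train_targets.shape)     # stored training targets after conditioning
print(model.input_transform)         # the Normalize transform attached to the model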

Step 4: attempt retraining to further tune the hyperparameters (length scales, etc.)

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)  # fails

Stack trace/error message

MDNotImplementedError                     Traceback (most recent call last)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/utils/dispatcher.py:88, in Dispatcher.__call__(self, *args, **kwargs)
     87 try:
---> 88     return func(*args, **kwargs)
     89 except MDNotImplementedError:
     90     # Traverses registered methods in order, yields whenever a match is found

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:320, in _fit_multioutput_independent(mll, _, __, sequential, **kwargs)
    315 if (  # incompatible models
    316     not sequential
    317     or mll.model.num_outputs == 1
    318     or mll.likelihood is not getattr(mll.model, "likelihood", None)
    319 ):
--> 320     raise MDNotImplementedError  # defer to generic
    322 # TODO: Unpacking of OutcomeTransforms not yet supported. Targets are often
    323 # pre-transformed in __init__, so try fitting with outcome_transform hidden

MDNotImplementedError: 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Input In [20], in <cell line: 2>()
      1 mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
----> 2 botorch.fit.fit_gpytorch_mll(mll)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:114, in fit_gpytorch_mll(mll, optimizer, optimizer_kwargs, **kwargs)
    111 if optimizer is not None:  # defer to per-method defaults
    112     kwargs["optimizer"] = optimizer
--> 114 return dispatcher(
    115     mll,
    116     type(mll.likelihood),
    117     type(mll.model),
    118     optimizer_kwargs=optimizer_kwargs,
    119     **kwargs,
    120 )

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/utils/dispatcher.py:95, in Dispatcher.__call__(self, *args, **kwargs)
     93 for func in funcs:
     94     try:
---> 95         return func(*args, **kwargs)
     96     except MDNotImplementedError:
     97         pass

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:240, in _fit_fallback(mll, _, __, optimizer, optimizer_kwargs, max_attempts, warning_filter, caught_exception_types, **ignore)
    238 with catch_warnings(record=True) as warning_list, debug(True):
    239     simplefilter("always", category=OptimizationWarning)
--> 240     mll, _ = optimizer(mll, **optimizer_kwargs)
    242 # Resolve warning messages and determine whether or not to retry
    243 done = True

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/fit.py:142, in fit_gpytorch_scipy(mll, bounds, method, options, track_iterations, approx_mll, scipy_objective, module_to_array_func, module_from_array_func)
    140 cb = store_iteration if track_iterations else None
    141 with gpt_settings.fast_computations(log_prob=approx_mll):
--> 142     res = minimize(
    143         scipy_objective,
    144         x0,
    145         args=(mll, property_dict),
    146         bounds=bounds,
    147         method=method,
    148         jac=True,
    149         options=options,
    150         callback=cb,
    151     )
    152     iterations = []
    153     if track_iterations:

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_minimize.py:692, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    689     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    690                              **options)
    691 elif meth == 'l-bfgs-b':
--> 692     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    693                            callback=callback, **options)
    694 elif meth == 'tnc':
    695     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    696                         **options)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_lbfgsb_py.py:308, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    305     else:
    306         iprint = disp
--> 308 sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
    309                               bounds=new_bounds,
    310                               finite_diff_rel_step=finite_diff_rel_step)
    312 func_and_grad = sf.fun_and_grad
    314 fortran_int = _lbfgsb.types.intvar.dtype

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:263, in _prepare_scalar_function(fun, x0, jac, args, bounds, epsilon, finite_diff_rel_step, hess)
    259     bounds = (-np.inf, np.inf)
    261 # ScalarFunction caches. Reuse of fun(x) during grad
    262 # calculation reduces overall function evaluations.
--> 263 sf = ScalarFunction(fun, x0, args, grad, hess,
    264                     finite_diff_rel_step, bounds, epsilon=epsilon)
    266 return sf

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:158, in ScalarFunction.__init__(self, fun, x0, args, grad, hess, finite_diff_rel_step, finite_diff_bounds, epsilon)
    155     self.f = fun_wrapped(self.x)
    157 self._update_fun_impl = update_fun
--> 158 self._update_fun()
    160 # Gradient evaluation
    161 if callable(grad):

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
    249 def _update_fun(self):
    250     if not self.f_updated:
--> 251         self._update_fun_impl()
    252         self.f_updated = True

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
    154 def update_fun():
--> 155     self.f = fun_wrapped(self.x)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
    133 self.nfev += 1
    134 # Send a copy because the user may overwrite it.
    135 # Overwriting results in undefined behaviour because
    136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
    138 # Make sure the function returns a true scalar
    139 if not np.isscalar(fx):

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:76, in MemoizeJac.__call__(self, x, *args)
     74 def __call__(self, x, *args):
     75     """ returns the the function value """
---> 76     self._compute_if_needed(x, *args)
     77     return self._value

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:70, in MemoizeJac._compute_if_needed(self, x, *args)
     68 if not np.all(x == self.x) or self._value is None or self.jac is None:
     69     self.x = np.asarray(x).copy()
---> 70     fg = self.fun(x, *args)
     71     self.jac = fg[1]
     72     self._value = fg[0]

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:227, in _scipy_objective_and_grad(x, mll, property_dict)
    225     loss = -mll(*args).sum()
    226 except RuntimeError as e:
--> 227     return _handle_numerical_errors(error=e, x=x)
    228 loss.backward()
    230 i = 0

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:256, in _handle_numerical_errors(error, x)
    250 if (
    251     isinstance(error, NanError)
    252     or "singular" in error_message  # old pytorch message
    253     or "input is not positive-definite" in error_message  # since pytorch #63864
    254 ):
    255     return float("nan"), np.full_like(x, "nan")
--> 256 raise error

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:225, in _scipy_objective_and_grad(x, mll, property_dict)
    223     output = mll.model(*train_inputs)
    224     args = [output, train_targets] + _get_extra_mll_args(mll)
--> 225     loss = -mll(*args).sum()
    226 except RuntimeError as e:
    227     return _handle_numerical_errors(error=e, x=x)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/module.py:30, in Module.__call__(self, *inputs, **kwargs)
     29 def __call__(self, *inputs, **kwargs):
---> 30     outputs = self.forward(*inputs, **kwargs)
     31     if isinstance(outputs, list):
     32         return [_validate_module_outputs(output) for output in outputs]

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
     62 # Get the log prob of the marginal distribution
     63 output = self.likelihood(function_dist, *params)
---> 64 res = output.log_prob(target)
     65 res = self._add_other_terms(res, params)
     67 # Scale by the amount of data we have

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:147, in MultivariateNormal.log_prob(self, value)
    145 def log_prob(self, value):
    146     if settings.fast_computations.log_prob.off():
--> 147         return super().log_prob(value)
    149     if self._validate_args:
    150         self._validate_sample(value)

File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/torch/distributions/multivariate_normal.py:211, in MultivariateNormal.log_prob(self, value)
    209 if self._validate_args:
    210     self._validate_sample(value)
--> 211 diff = value - self.loc
    212 M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
    213 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)

RuntimeError: The size of tensor a (17) must match the size of tensor b (15) at non-singleton dimension 0

Expected Behavior

I would expect the second training procedure to work. It seems sensible that Normalize would be updated with the new training data passed in during conditioning.
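
As a workaround in the meantime (a sketch only, not necessarily the intended fix), rebuilding the model from the concatenated data and fitting it from scratch sidesteps condition_on_observations entirely, so Normalize sees all 17 points at construction time:

# Workaround sketch: refit a fresh model on the full data set instead of
# conditioning the existing one.
full_x = torch.cat([train_x, new_x], dim=0)
full_y = torch.cat([train_y, new_y], dim=0)

model = botorch.models.SingleTaskGP(
    train_X=full_x, train_Y=full_y, input_transform=Normalize(1)
)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)  # same call as in Step 2, now on all 17 points

This loses the convenience of conditioning in place, but since it is just Step 2 repeated on the full data it avoids the shape mismatch.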

System information

BoTorch version: 0.7.2
GPyTorch version: 1.9.0
Torch version: 1.12.0
OS: macOS 12.5.1 (Mac M1 Max)
