
[Bug] Most Likely Heteroskedastic GP crashes because of NaN in noise_model #990

Open
@ArnoVel

Description

🐛 Bug

This problem occurred while I was trying to reuse a solution from pytorch/botorch#250, where more background can be found. I raised some questions there about the model itself and how it is implemented, but those are not the main point of this post.

The problem is twofold:

  1. An issue with the model itself: the MLHGP oscillates between different 'modes' for the noise distribution and never seems to settle on a fixed one, even though the posterior distribution looks quite stable (see the inspection sketch after this list).
  2. Beyond the oscillation, after some iterations the fit crashes because of NaN values in the likelihood.
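
To make point 1 concrete, here is how I inspect the learned noise between refits (a sketch: the attribute path follows the traceback below, but the helper name and variable names are mine):

```python
import torch

def learned_noise(model, train_X):
    # The heteroskedastic likelihood wraps the noise GP as
    # likelihood.noise_covar.noise_model (cf. noise_models.py in the trace).
    with torch.no_grad():
        return model.likelihood.noise_covar.noise_model.posterior(train_X).mean
```

Printing this after each refit shows the values jumping between distinct 'modes' rather than converging.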

To reproduce

A notebook showing that the error appears after a few iterations can be found here.

**Code snippet to reproduce**

One can also reproduce the bug by taking the code from pytorch/botorch#250 and making sure the observed variance is detached from the graph. The notebook linked above contains code that does essentially this.
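
For reference, here is a minimal sketch of the EM-like procedure (my reconstruction of the approach from pytorch/botorch#250; the helper name `mlhgp_em_loop`, the iteration count, and the squared-residual variance estimate are illustrative, not the exact code from my notebook):

```python
import torch
from botorch.models import SingleTaskGP, HeteroskedasticSingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood

def mlhgp_em_loop(train_X, train_Y, n_iters=10):
    # Step 0: plain homoskedastic fit to get an initial posterior mean.
    model = SingleTaskGP(train_X, train_Y)
    fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))
    for _ in range(n_iters):
        with torch.no_grad():
            post_mean = model.posterior(train_X).mean
        # Squared residuals serve as observed noise variances; .detach()
        # ensures refitting does not backprop through the previous model.
        observed_var = ((train_Y - post_mean) ** 2).detach()
        model = HeteroskedasticSingleTaskGP(train_X, train_Y, observed_var)
        hetero_mll = ExactMarginalLogLikelihood(model.likelihood, model)
        fit_gpytorch_model(hetero_mll)  # crashes with the NaN error below
    return model
```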

**Stack trace/error message**

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/SJTU/research_code/TCEP/GP_scoring/gpRegressors.py in _iterate_em_like_procedure(self)
    139             hetero_mll.train()
--> 140             self.fit_model(hetero_mll)
    141         except Exception as e:

~/SJTU/research_code/TCEP/GP_scoring/gpRegressors.py in fit_model(self, mll)
     90     def fit_model(self, mll):
---> 91         botorch.fit.fit_gpytorch_model(mll)
     92 

~/.local/lib/python3.6/site-packages/botorch/fit.py in fit_gpytorch_model(mll, optimizer, **kwargs)
     97                 sample_all_priors(mll.model)
---> 98             mll, _ = optimizer(mll, track_iterations=False, **kwargs)
     99             if not any(issubclass(w.category, OptimizationWarning) for w in ws):

~/.local/lib/python3.6/site-packages/botorch/optim/fit.py in fit_gpytorch_scipy(mll, bounds, method, options, track_iterations)
    209         options=options,
--> 210         callback=cb,
    211     )

~/.local/lib/python3.6/site-packages/scipy/optimize/_minimize.py in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    599         return _minimize_lbfgsb(fun, x0, args, jac, bounds,
--> 600                                 callback=callback, **options)
    601     elif meth == 'tnc':

~/.local/lib/python3.6/site-packages/scipy/optimize/lbfgsb.py in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, **unknown_options)
    334             # Overwrite f and g:
--> 335             f, g = func_and_grad(x)
    336         elif task_str.startswith(b'NEW_X'):

~/.local/lib/python3.6/site-packages/scipy/optimize/lbfgsb.py in func_and_grad(x)
    284         def func_and_grad(x):
--> 285             f = fun(x, *args)
    286             g = jac(x, *args)

~/.local/lib/python3.6/site-packages/scipy/optimize/optimize.py in function_wrapper(*wrapper_args)
    326         ncalls[0] += 1
--> 327         return function(*(wrapper_args + args))
    328 

~/.local/lib/python3.6/site-packages/scipy/optimize/optimize.py in __call__(self, x, *args)
     64         self.x = numpy.asarray(x).copy()
---> 65         fg = self.fun(x, *args)
     66         self.jac = fg[1]

~/.local/lib/python3.6/site-packages/botorch/optim/fit.py in _scipy_objective_and_grad(x, mll, property_dict)
    267         else:
--> 268             raise e  # pragma: nocover
    269     loss.backward()

~/.local/lib/python3.6/site-packages/botorch/optim/fit.py in _scipy_objective_and_grad(x, mll, property_dict)
    262         args = [output, train_targets] + _get_extra_mll_args(mll)
--> 263         loss = -mll(*args).sum()
    264     except RuntimeError as e:

~/.local/lib/python3.6/site-packages/gpytorch/module.py in __call__(self, *inputs, **kwargs)
     21     def __call__(self, *inputs, **kwargs):
---> 22         outputs = self.forward(*inputs, **kwargs)
     23         if isinstance(outputs, list):

~/.local/lib/python3.6/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py in forward(self, output, target, *params)
     25         # Get the log prob of the marginal distribution
---> 26         output = self.likelihood(output, *params)
     27         res = output.log_prob(target)

~/.local/lib/python3.6/site-packages/gpytorch/likelihoods/likelihood.py in __call__(self, input, *params, **kwargs)
    122         elif isinstance(input, MultivariateNormal):
--> 123             return self.marginal(input, *params, **kwargs)
    124         # Error

~/.local/lib/python3.6/site-packages/gpytorch/likelihoods/gaussian_likelihood.py in marginal(self, function_dist, *params, **kwargs)
     46         mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
---> 47         noise_covar = self._shaped_noise_covar(mean.shape, *params, **kwargs)
     48         full_covar = covar + noise_covar

~/.local/lib/python3.6/site-packages/gpytorch/likelihoods/gaussian_likelihood.py in _shaped_noise_covar(self, base_shape, *params, **kwargs)
     38             shape = base_shape
---> 39         return self.noise_covar(*params, shape=shape, **kwargs)
     40 

~/.local/lib/python3.6/site-packages/gpytorch/module.py in __call__(self, *inputs, **kwargs)
     21     def __call__(self, *inputs, **kwargs):
---> 22         outputs = self.forward(*inputs, **kwargs)
     23         if isinstance(outputs, list):

~/.local/lib/python3.6/site-packages/gpytorch/likelihoods/noise_models.py in forward(self, batch_shape, shape, noise, *params)
    150             else:
--> 151                 output = self.noise_model(*params)
    152         self.noise_model.train(training)

~/.local/lib/python3.6/site-packages/gpytorch/models/exact_gp.py in __call__(self, *args, **kwargs)
    290             with settings._use_eval_tolerance():
--> 291                 predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(full_mean, full_covar)
    292 

~/.local/lib/python3.6/site-packages/gpytorch/models/exact_prediction_strategies.py in exact_prediction(self, joint_mean, joint_covar)
    288         return (
--> 289             self.exact_predictive_mean(test_mean, test_train_covar),
    290             self.exact_predictive_covar(test_test_covar, test_train_covar),

~/.local/lib/python3.6/site-packages/gpytorch/models/exact_prediction_strategies.py in exact_predictive_mean(self, test_mean, test_train_covar)
    306         # GP, and using addmv requires you to delazify test_train_covar, which is obviously a huge no-no!
--> 307         res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
    308         res = res + test_mean

~/.local/lib/python3.6/site-packages/gpytorch/utils/memoize.py in g(self, *args, **kwargs)
     33         if not is_in_cache(self, cache_name):
---> 34             add_to_cache(self, cache_name, method(self, *args, **kwargs))
     35         return get_from_cache(self, cache_name)

~/.local/lib/python3.6/site-packages/gpytorch/models/exact_prediction_strategies.py in mean_cache(self)
    260         train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1)
--> 261         mean_cache = train_train_covar.inv_matmul(train_labels_offset).squeeze(-1)
    262 

~/.local/lib/python3.6/site-packages/gpytorch/lazy/lazy_tensor.py in inv_matmul(self, right_tensor, left_tensor)
    927                 right_tensor,
--> 928                 *self.representation(),
    929             )

~/.local/lib/python3.6/site-packages/gpytorch/functions/_inv_matmul.py in forward(ctx, representation_tree, has_left, *args)
     45         else:
---> 46             solves = _solve(lazy_tsr, right_tensor)
     47             res = solves

~/.local/lib/python3.6/site-packages/gpytorch/functions/_inv_matmul.py in _solve(lazy_tsr, rhs)
     13             preconditioner = lazy_tsr.detach()._inv_matmul_preconditioner()
---> 14         return lazy_tsr._solve(rhs, preconditioner)
     15 

~/.local/lib/python3.6/site-packages/gpytorch/lazy/lazy_tensor.py in _solve(self, rhs, preconditioner, num_tridiag)
    640             max_tridiag_iter=settings.max_lanczos_quadrature_iterations.value(),
--> 641             preconditioner=preconditioner,
    642         )

~/.local/lib/python3.6/site-packages/gpytorch/utils/linear_cg.py in linear_cg(matmul_closure, rhs, n_tridiag, tolerance, eps, stop_updating_after, max_iter, max_tridiag_iter, initial_guess, preconditioner)
    161     if not torch.equal(residual, residual):
--> 162         raise RuntimeError("NaNs encounterd when trying to perform matrix-vector multiplication")
    163 

RuntimeError: NaNs encounterd when trying to perform matrix-vector multiplication
```

Expected Behavior

Some oscillation is expected, but NaN values should not appear nearly as often as they do in my case.
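
One guard worth trying (my suggestion, not a confirmed fix): floor the residual-based variance estimates away from zero before constructing the heteroskedastic model, so the noise GP never trains on (near-)zero targets:

```python
import torch

def safe_observed_variance(train_Y, post_mean, min_var=1e-4):
    # min_var is an illustrative floor; variance estimates at or near zero
    # are a plausible source of the NaNs in the trace above.
    resid_sq = (train_Y - post_mean) ** 2
    return resid_sq.clamp_min(min_var).detach()
```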

System information

Please complete the following information:

  • botorch (0.1.4)
  • gpytorch (0.3.6)
  • torch (1.3.1)
  • Ubuntu 18.04
