[Bug] An exception should be raised when the training data has requires_grad=True #2253




🐛 Bug

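If the train_X and train_Y passed to SingleTaskGP carry autograd history — i.e. train_X has requires_grad=True, so any train_Y computed from it has a non-None grad_fn — then fit_gpytorch_mll fails with a RuntimeError. The error goes away when the training data has requires_grad=False (e.g. after calling .detach()).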

To reproduce

**Code snippet to reproduce**

# Minimal reproduction: training data that carries autograd history
# (requires_grad=True on train_x) makes fit_gpytorch_mll fail with
# "RuntimeError: Trying to backward through the graph a second time".
import botorch
import gpytorch
import torch
from botorch.models import FixedNoiseGP, ModelListGP, SingleTaskGP
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from botorch import fit_gpytorch_mll


train_x = torch.randn(30, 6, requires_grad=True)

# Both outcomes are computed from train_x, so they are non-leaf tensors whose
# grad_fn links back to train_x.
train_obj = 3 * train_x + train_x**2
train_con = 4 * train_x - train_x**3


model_obj = SingleTaskGP(train_x, train_obj).to(train_x)
model_con = SingleTaskGP(train_x, train_con).to(train_x)

model = ModelListGP(model_obj, model_con)
mll = SumMarginalLogLikelihood(model.likelihood, model)

# This call is what raises the RuntimeError shown in the stack trace below.
# Workaround: build the models from data without autograd history, e.g.
# SingleTaskGP(train_x.detach(), train_obj.detach()).
fit_gpytorch_mll(mll)


**Stack trace/error message**

RuntimeError                              Traceback (most recent call last)
Cell In[55], line 25
     22 model = ModelListGP(model_obj, model_con)
     23 mll = SumMarginalLogLikelihood(model.likelihood, model)
---> 25 fit_gpytorch_mll(mll)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/, in _fit_list(mll, _, __, **kwargs)
    303 mll.train()
    304 for sub_mll in mll.mlls:
--> 305     fit_gpytorch_mll(sub_mll, **kwargs)
    307 return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, warning_handler, caught_exception_types, **ignore)
    250 with catch_warnings(record=True) as warning_list, debug(True):
    251     simplefilter("always", category=OptimizationWarning)
--> 252     optimizer(mll, closure=closure, **optimizer_kwargs)
    254 # Resolved warnings and determine whether or not to retry
    255 done = True

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
     89 if closure_kwargs is not None:
     90     closure = partial(closure, **closure_kwargs)
---> 92 result = scipy_minimize(
     93     closure=closure,
     94     parameters=parameters,
     95     bounds=bounds,
     96     method=method,
     97     options=options,
     98     callback=callback,
     99     timeout_sec=timeout_sec,
    100 )
    101 if result.status != OptimizationStatus.SUCCESS:
    102     warn(
    103         f"`scipy_minimize` terminated with status {result.status}, displaying"
    104         f" original message from `scipy.optimize.minimize`: {result.message}",
    105         OptimizationWarning,
    106     )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
    101         result = OptimizationResult(
    102             step=next(call_counter),
    103             fval=float(wrapped_closure(x)[0]),
    104             status=OptimizationStatus.RUNNING,
    105             runtime=monotonic() - start_time,
    106         )
    107         return callback(parameters, result)  # pyre-ignore [29]
--> 109 raw = minimize_with_timeout(
    110     wrapped_closure,
    111     wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
    112     jac=True,
    113     bounds=bounds_np,
    114     method=method,
    115     options=options,
    116     callback=wrapped_callback,
    117     timeout_sec=timeout_sec,
    118 )
    120 # Post-processing and outcome handling
    121 wrapped_closure.state = asarray(raw.x)  # set parameter state to optimal values

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
     77     wrapped_callback = callback
     79 try:
---> 80     return optimize.minimize(
     81         fun=fun,
     82         x0=x0,
     83         args=args,
     84         method=method,
     85         jac=jac,
     86         hess=hess,
     87         hessp=hessp,
     88         bounds=bounds,
     89         constraints=constraints,
     90         tol=tol,
     91         callback=wrapped_callback,
     92         options=options,
     93     )
     94 except OptimizationTimeoutError as e:
     95     msg = f"Optimization timed out after {e.runtime} seconds."

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    707     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    708                              **options)
    709 elif meth == 'l-bfgs-b':
--> 710     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    711                            callback=callback, **options)
    712 elif meth == 'tnc':
    713     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    714                         **options)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    359 task_str = task.tobytes()
    360 if task_str.startswith(b'FG'):
    361     # The minimization routine wants f and g at the current x.
    362     # Note that interruptions due to maxfun are postponed
    363     # until the completion of the current minimization iteration.
    364     # Overwrite f and g:
--> 365     f, g = func_and_grad(x)
    366 elif task_str.startswith(b'NEW_X'):
    367     # new iteration
    368     n_iterations += 1

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/, in ScalarFunction.fun_and_grad(self, x)
    283 if not np.array_equal(x, self.x):
    284     self._update_x_impl(x)
--> 285 self._update_fun()
    286 self._update_grad()
    287 return self.f, self.g

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/, in ScalarFunction._update_fun(self)
    249 def _update_fun(self):
    250     if not self.f_updated:
--> 251         self._update_fun_impl()
    252         self.f_updated = True

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/, in ScalarFunction.__init__.<locals>.update_fun()
    154 def update_fun():
--> 155     self.f = fun_wrapped(self.x)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
    133 self.nfev += 1
    134 # Send a copy because the user may overwrite it.
    135 # Overwriting results in undefined behaviour because
    136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
    138 # Make sure the function returns a true scalar
    139 if not np.isscalar(fx):

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/, in MemoizeJac.__call__(self, x, *args)
     75 def __call__(self, x, *args):
     76     """ returns the function value """
---> 77     self._compute_if_needed(x, *args)
     78     return self._value

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/, in MemoizeJac._compute_if_needed(self, x, *args)
     69 if not np.all(x == self.x) or self._value is None or self.jac is None:
     70     self.x = np.asarray(x).copy()
---> 71     fg = self.fun(x, *args)
     72     self.jac = fg[1]
     73     self._value = fg[0]

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    158         index += size
    159 except RuntimeError as e:
--> 160     value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
    162 return value, grads

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/, in _handle_numerical_errors(error, x, dtype)
     50     _dtype = x.dtype if dtype is None else dtype
     51     return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
---> 52 raise error

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    147     self.state = state
    149 try:
--> 150     value_tensor, grad_tensors = self.closure(**kwargs)
    151     value = self.as_array(value_tensor)
    152     grads = self._get_gradient_ndarray(fill_value=self.fill_value)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/, in ForwardBackwardClosure.__call__(self, **kwargs)
     64 values = self.forward(**kwargs)
     65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
     68 grads = tuple(param.grad for param in self.parameters.values())
     69 if self.callback:

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    512 if has_torch_function_unary(self):
    513     return handle_torch_function(
    514         Tensor.backward,
    515         (self,),
    520         inputs=inputs,
    521     )
--> 522 torch.autograd.backward(
    523     self, gradient, retain_graph, create_graph, inputs=inputs
    524 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/autograd/, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    261     retain_graph = create_graph
    263 # The reason we repeat the same comment below is that
    264 # some Python versions print out the first line of a multi-line function
    265 # calls in the traceback and some print out the last line
--> 266 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    267     tensors,
    268     grad_tensors_,
    269     retain_graph,
    270     create_graph,
    271     inputs,
    272     allow_unreachable=True,
    273     accumulate_grad=True,
    274 )

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Expected Behavior

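An informative exception should be raised early (e.g. when constructing SingleTaskGP, or at the start of fit_gpytorch_mll) explaining that the training data must not require gradients and suggesting a fix such as calling .detach() on train_X and train_Y. Currently, the user instead gets only this opaque autograd error from deep inside the optimizer:

"Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward"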

System information

Please complete the following information:

  • BoTorch 0.10.1.dev16+g3e34a4fc
  • GPyTorch 1.12.dev28+g392dd41e
  • PyTorch 2.2.1
  • MacOS 14.4 (23E214) Sonoma on M3 Pro

