[Bug] An exception should be raised when the training data has requires_grad=True #2253

Open
@rexxy-sasori

Description

🐛 Bug

The train_X and train_Y passed to SingleTaskGP cause the subsequent fit_gpytorch_mll call to fail if they have requires_grad=True, i.e. their grad_fn is not None. The error goes away when requires_grad=False.
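
As a workaround (my own sketch based on the observation above, not an official fix), detaching the training data before constructing the model avoids the failure:

import torch
from botorch.models import SingleTaskGP

train_x = torch.randn(30, 6, requires_grad=True)
# Detach both inputs and targets so they are no longer attached to an
# autograd graph (workaround sketch, not an official recommendation).
train_obj = (3 * train_x + train_x**2).detach()
train_x = train_x.detach()

model = SingleTaskGP(train_x, train_obj)  # fit_gpytorch_mll now succeeds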

To reproduce

**Code snippet to reproduce**

import botorch, gpytorch, torch
from botorch.models import ModelListGP, SingleTaskGP
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from botorch import fit_gpytorch_mll

print(botorch.__version__)
print(gpytorch.__version__)
print(torch.__version__)

# Training inputs carry requires_grad=True; this is what triggers the failure
train_x = torch.randn(30, 6, requires_grad=True)

# The targets inherit the autograd graph from train_x, so grad_fn is not None
train_obj = 3 * train_x + train_x**2
train_con = 4 * train_x - train_x**3

print(train_obj.grad_fn)  # not None
print(train_con.grad_fn)  # not None

model_obj = SingleTaskGP(train_x, train_obj).to(train_x)
model_con = SingleTaskGP(train_x, train_con).to(train_x)

model = ModelListGP(model_obj, model_con)
mll = SumMarginalLogLikelihood(model.likelihood, model)

fit_gpytorch_mll(mll)  # raises the RuntimeError below

**Stack trace/error message**

RuntimeError                              Traceback (most recent call last)
Cell In[55], line 25
     22 model = ModelListGP(model_obj, model_con)
     23 mll = SumMarginalLogLikelihood(model.likelihood, model)
---> 25 fit_gpytorch_mll(mll)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:305, in _fit_list(mll, _, __, **kwargs)
    303 mll.train()
    304 for sub_mll in mll.mlls:
--> 305     fit_gpytorch_mll(sub_mll, **kwargs)
    307 return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:252, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, warning_handler, caught_exception_types, **ignore)
    250 with catch_warnings(record=True) as warning_list, debug(True):
    251     simplefilter("always", category=OptimizationWarning)
--> 252     optimizer(mll, closure=closure, **optimizer_kwargs)
    254 # Resolved warnings and determine whether or not to retry
    255 done = True

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/fit.py:92, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
     89 if closure_kwargs is not None:
     90     closure = partial(closure, **closure_kwargs)
---> 92 result = scipy_minimize(
     93     closure=closure,
     94     parameters=parameters,
     95     bounds=bounds,
     96     method=method,
     97     options=options,
     98     callback=callback,
     99     timeout_sec=timeout_sec,
    100 )
    101 if result.status != OptimizationStatus.SUCCESS:
    102     warn(
    103         f"`scipy_minimize` terminated with status {result.status}, displaying"
    104         f" original message from `scipy.optimize.minimize`: {result.message}",
    105         OptimizationWarning,
    106     )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/core.py:109, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
    101         result = OptimizationResult(
    102             step=next(call_counter),
    103             fval=float(wrapped_closure(x)[0]),
    104             status=OptimizationStatus.RUNNING,
    105             runtime=monotonic() - start_time,
    106         )
    107         return callback(parameters, result)  # pyre-ignore [29]
--> 109 raw = minimize_with_timeout(
    110     wrapped_closure,
    111     wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
    112     jac=True,
    113     bounds=bounds_np,
    114     method=method,
    115     options=options,
    116     callback=wrapped_callback,
    117     timeout_sec=timeout_sec,
    118 )
    120 # Post-processing and outcome handling
    121 wrapped_closure.state = asarray(raw.x)  # set parameter state to optimal values

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/timeout.py:80, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
     77     wrapped_callback = callback
     79 try:
---> 80     return optimize.minimize(
     81         fun=fun,
     82         x0=x0,
     83         args=args,
     84         method=method,
     85         jac=jac,
     86         hess=hess,
     87         hessp=hessp,
     88         bounds=bounds,
     89         constraints=constraints,
     90         tol=tol,
     91         callback=wrapped_callback,
     92         options=options,
     93     )
     94 except OptimizationTimeoutError as e:
     95     msg = f"Optimization timed out after {e.runtime} seconds."

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_minimize.py:710, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    707     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    708                              **options)
    709 elif meth == 'l-bfgs-b':
--> 710     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    711                            callback=callback, **options)
    712 elif meth == 'tnc':
    713     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    714                         **options)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py:365, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    359 task_str = task.tobytes()
    360 if task_str.startswith(b'FG'):
    361     # The minimization routine wants f and g at the current x.
    362     # Note that interruptions due to maxfun are postponed
    363     # until the completion of the current minimization iteration.
    364     # Overwrite f and g:
--> 365     f, g = func_and_grad(x)
    366 elif task_str.startswith(b'NEW_X'):
    367     # new iteration
    368     n_iterations += 1

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:285, in ScalarFunction.fun_and_grad(self, x)
    283 if not np.array_equal(x, self.x):
    284     self._update_x_impl(x)
--> 285 self._update_fun()
    286 self._update_grad()
    287 return self.f, self.g

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
    249 def _update_fun(self):
    250     if not self.f_updated:
--> 251         self._update_fun_impl()
    252         self.f_updated = True

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
    154 def update_fun():
--> 155     self.f = fun_wrapped(self.x)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
    133 self.nfev += 1
    134 # Send a copy because the user may overwrite it.
    135 # Overwriting results in undefined behaviour because
    136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
    138 # Make sure the function returns a true scalar
    139 if not np.isscalar(fx):

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_optimize.py:77, in MemoizeJac.__call__(self, x, *args)
     75 def __call__(self, x, *args):
     76     """ returns the function value """
---> 77     self._compute_if_needed(x, *args)
     78     return self._value

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_optimize.py:71, in MemoizeJac._compute_if_needed(self, x, *args)
     69 if not np.all(x == self.x) or self._value is None or self.jac is None:
     70     self.x = np.asarray(x).copy()
---> 71     fg = self.fun(x, *args)
     72     self.jac = fg[1]
     73     self._value = fg[0]

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:160, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    158         index += size
    159 except RuntimeError as e:
--> 160     value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
    162 return value, grads

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/common.py:52, in _handle_numerical_errors(error, x, dtype)
     50     _dtype = x.dtype if dtype is None else dtype
     51     return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
---> 52 raise error

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:150, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    147     self.state = state
    149 try:
--> 150     value_tensor, grad_tensors = self.closure(**kwargs)
    151     value = self.as_array(value_tensor)
    152     grads = self._get_gradient_ndarray(fill_value=self.fill_value)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
     64 values = self.forward(**kwargs)
     65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
     68 grads = tuple(param.grad for param in self.parameters.values())
     69 if self.callback:

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/_tensor.py:522, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    512 if has_torch_function_unary(self):
    513     return handle_torch_function(
    514         Tensor.backward,
    515         (self,),
   (...)
    520         inputs=inputs,
    521     )
--> 522 torch.autograd.backward(
    523     self, gradient, retain_graph, create_graph, inputs=inputs
    524 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/autograd/__init__.py:266, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    261     retain_graph = create_graph
    263 # The reason we repeat the same comment below is that
    264 # some Python versions print out the first line of a multi-line function
    265 # calls in the traceback and some print out the last line
--> 266 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    267     tensors,
    268     grad_tensors_,
    269     retain_graph,
    270     create_graph,
    271     inputs,
    272     allow_unreachable=True,
    273     accumulate_grad=True,
    274 )

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
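
For context, the failure can be reproduced in plain PyTorch (a minimal sketch of my reading of the mechanism, not code from the BoTorch sources): the model stores the training targets together with their autograd graph, the first optimizer step backwards through that shared graph and frees its saved tensors, and every subsequent step then hits the freed graph.

import torch

x = torch.randn(3, requires_grad=True)
y = x**2                # the graph from x to y is built once and then shared

loss1 = y.sum()
loss1.backward()        # frees the saved tensors of the shared graph

loss2 = (2 * y).sum()   # reuses the (now freed) graph through y
loss2.backward()        # RuntimeError: Trying to backward through the graph a second time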

Expected Behavior

Rather than failing deep inside the optimizer with the opaque autograd error

"Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."

an informative exception should be raised up front when the training data has requires_grad=True.

System information

  • BoTorch 0.10.1.dev16+g3e34a4fc
  • GPyTorch 1.12.dev28+g392dd41e
  • PyTorch 2.2.1
  • macOS 14.4 (23E214) Sonoma on an M3 Pro
