🐛 Bug
The train_X and train_Y tensors passed to SingleTaskGP cause fit_gpytorch_mll to fail if they have requires_grad=True, i.e. their grad_fn is not None. The error goes away when requires_grad=False.
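For reference, a minimal workaround sketch (not a fix of the underlying issue): since the report notes the error disappears when requires_grad=False, one way to get there is to detach the training tensors before constructing the model. The variable names below mirror the reproduction snippet that follows.

```python
# Hypothetical workaround sketch: pass detached copies of the training data so
# that the model's train inputs/targets have grad_fn=None (requires_grad=False).
import torch
from botorch.models import SingleTaskGP

train_x = torch.randn(30, 6, requires_grad=True)
train_obj = 3 * train_x + train_x**2

# detach() returns tensors with requires_grad=False, which is the setting
# under which fit_gpytorch_mll succeeds according to the report.
model_obj = SingleTaskGP(train_x.detach(), train_obj.detach())
```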
To reproduce
** Code snippet to reproduce **
```python
import botorch, gpytorch, torch
from botorch.models import FixedNoiseGP, ModelListGP, SingleTaskGP
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from botorch import fit_gpytorch_mll
print(botorch.__version__)
print(gpytorch.__version__)
print(torch.__version__)
# Inputs are created with requires_grad=True, so the derived targets carry a grad_fn
train_x = torch.randn(30, 6, requires_grad=True)
train_obj = 3 * train_x + train_x**2
train_con = 4 * train_x - train_x**3
print(train_obj.grad_fn)
print(train_con.grad_fn)
model_obj = SingleTaskGP(train_x, train_obj).to(train_x)
model_con = SingleTaskGP(train_x, train_con).to(train_x)
model = ModelListGP(model_obj, model_con)
mll = SumMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
```
** Stack trace/error message **
```
RuntimeError Traceback (most recent call last)
Cell In[55], line 25
22 model = ModelListGP(model_obj, model_con)
23 mll = SumMarginalLogLikelihood(model.likelihood, model)
---> 25 fit_gpytorch_mll(mll)
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
102 if optimizer is not None: # defer to per-method defaults
103 kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
106 mll,
107 type(mll.likelihood),
108 type(mll.model),
109 closure=closure,
110 closure_kwargs=closure_kwargs,
111 optimizer_kwargs=optimizer_kwargs,
112 **kwargs,
113 )
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
91 func = self.__getitem__(types=types)
92 try:
---> 93 return func(*args, **kwargs)
94 except MDNotImplementedError:
95 # Traverses registered methods in order, yields whenever a match is found
96 funcs = self.dispatch_iter(*types)
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:305, in _fit_list(mll, _, __, **kwargs)
303 mll.train()
304 for sub_mll in mll.mlls:
--> 305 fit_gpytorch_mll(sub_mll, **kwargs)
307 return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
102 if optimizer is not None: # defer to per-method defaults
103 kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
106 mll,
107 type(mll.likelihood),
108 type(mll.model),
109 closure=closure,
110 closure_kwargs=closure_kwargs,
111 optimizer_kwargs=optimizer_kwargs,
112 **kwargs,
113 )
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
91 func = self.__getitem__(types=types)
92 try:
---> 93 return func(*args, **kwargs)
94 except MDNotImplementedError:
95 # Traverses registered methods in order, yields whenever a match is found
96 funcs = self.dispatch_iter(*types)
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:252, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, warning_handler, caught_exception_types, **ignore)
250 with catch_warnings(record=True) as warning_list, debug(True):
251 simplefilter("always", category=OptimizationWarning)
--> 252 optimizer(mll, closure=closure, **optimizer_kwargs)
254 # Resolved warnings and determine whether or not to retry
255 done = True
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/fit.py:92, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
89 if closure_kwargs is not None:
90 closure = partial(closure, **closure_kwargs)
---> 92 result = scipy_minimize(
93 closure=closure,
94 parameters=parameters,
95 bounds=bounds,
96 method=method,
97 options=options,
98 callback=callback,
99 timeout_sec=timeout_sec,
100 )
101 if result.status != OptimizationStatus.SUCCESS:
102 warn(
103 f"`scipy_minimize` terminated with status {result.status}, displaying"
104 f" original message from `scipy.optimize.minimize`: {result.message}",
105 OptimizationWarning,
106 )
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/core.py:109, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
101 result = OptimizationResult(
102 step=next(call_counter),
103 fval=float(wrapped_closure(x)[0]),
104 status=OptimizationStatus.RUNNING,
105 runtime=monotonic() - start_time,
106 )
107 return callback(parameters, result) # pyre-ignore [29]
--> 109 raw = minimize_with_timeout(
110 wrapped_closure,
111 wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
112 jac=True,
113 bounds=bounds_np,
114 method=method,
115 options=options,
116 callback=wrapped_callback,
117 timeout_sec=timeout_sec,
118 )
120 # Post-processing and outcome handling
121 wrapped_closure.state = asarray(raw.x) # set parameter state to optimal values
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/timeout.py:80, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
77 wrapped_callback = callback
79 try:
---> 80 return optimize.minimize(
81 fun=fun,
82 x0=x0,
83 args=args,
84 method=method,
85 jac=jac,
86 hess=hess,
87 hessp=hessp,
88 bounds=bounds,
89 constraints=constraints,
90 tol=tol,
91 callback=wrapped_callback,
92 options=options,
93 )
94 except OptimizationTimeoutError as e:
95 msg = f"Optimization timed out after {e.runtime} seconds."
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_minimize.py:710, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
707 res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
708 **options)
709 elif meth == 'l-bfgs-b':
--> 710 res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
711 callback=callback, **options)
712 elif meth == 'tnc':
713 res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
714 **options)
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py:365, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
359 task_str = task.tobytes()
360 if task_str.startswith(b'FG'):
361 # The minimization routine wants f and g at the current x.
362 # Note that interruptions due to maxfun are postponed
363 # until the completion of the current minimization iteration.
364 # Overwrite f and g:
--> 365 f, g = func_and_grad(x)
366 elif task_str.startswith(b'NEW_X'):
367 # new iteration
368 n_iterations += 1
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:285, in ScalarFunction.fun_and_grad(self, x)
283 if not np.array_equal(x, self.x):
284 self._update_x_impl(x)
--> 285 self._update_fun()
286 self._update_grad()
287 return self.f, self.g
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
249 def _update_fun(self):
250 if not self.f_updated:
--> 251 self._update_fun_impl()
252 self.f_updated = True
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
154 def update_fun():
--> 155 self.f = fun_wrapped(self.x)
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
133 self.nfev += 1
134 # Send a copy because the user may overwrite it.
135 # Overwriting results in undefined behaviour because
136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
138 # Make sure the function returns a true scalar
139 if not np.isscalar(fx):
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_optimize.py:77, in MemoizeJac.__call__(self, x, *args)
75 def __call__(self, x, *args):
76 """ returns the function value """
---> 77 self._compute_if_needed(x, *args)
78 return self._value
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_optimize.py:71, in MemoizeJac._compute_if_needed(self, x, *args)
69 if not np.all(x == self.x) or self._value is None or self.jac is None:
70 self.x = np.asarray(x).copy()
---> 71 fg = self.fun(x, *args)
72 self.jac = fg[1]
73 self._value = fg[0]
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:160, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
158 index += size
159 except RuntimeError as e:
--> 160 value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
162 return value, grads
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/common.py:52, in _handle_numerical_errors(error, x, dtype)
50 _dtype = x.dtype if dtype is None else dtype
51 return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
---> 52 raise error
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:150, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
147 self.state = state
149 try:
--> 150 value_tensor, grad_tensors = self.closure(**kwargs)
151 value = self.as_array(value_tensor)
152 grads = self._get_gradient_ndarray(fill_value=self.fill_value)
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
64 values = self.forward(**kwargs)
65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
68 grads = tuple(param.grad for param in self.parameters.values())
69 if self.callback:
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/_tensor.py:522, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
512 if has_torch_function_unary(self):
513 return handle_torch_function(
514 Tensor.backward,
515 (self,),
(...)
520 inputs=inputs,
521 )
--> 522 torch.autograd.backward(
523 self, gradient, retain_graph, create_graph, inputs=inputs
524 )
File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/autograd/__init__.py:266, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
261 retain_graph = create_graph
263 # The reason we repeat the same comment below is that
264 # some Python versions print out the first line of a multi-line function
265 # calls in the traceback and some print out the last line
--> 266 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
267 tensors,
268 grad_tensors_,
269 retain_graph,
270 create_graph,
271 inputs,
272 allow_unreachable=True,
273 accumulate_grad=True,
274 )
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```
Expected Behavior
fit_gpytorch_mll should fit the model regardless of whether the training tensors have requires_grad=True, just as it does when requires_grad=False. Instead, the following error is thrown:
"Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."
System information
- BoTorch 0.10.1.dev16+g3e34a4fc
- GPyTorch 1.12.dev28+g392dd41e
- PyTorch 2.2.1
- macOS Sonoma 14.4 (23E214) on Apple M3 Pro