## 🚀 Feature Request
Now that `fit_gpytorch_mll` exists and dispatches on type via multiple dispatch, it seems like it'd be fairly straightforward to support minibatch training by registering a `fit_gpytorch_torch_stochastic` (or similar) as the optimizer for `_ApproximateMarginalLogLikelihood` mlls.
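For concreteness, here is a minimal sketch of what the registration might look like. It assumes a module-level `dispatcher` in `botorch.fit` that backs `fit_gpytorch_mll` (the registration pattern below is my reading of the current code and may differ in detail), and `fit_gpytorch_torch_stochastic` is the proposed function sketched under Pitch below:

```python
# Hypothetical registration in botorch/fit.py -- nothing here exists yet.
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood

from botorch.models.approximate_gp import ApproximateGPyTorchModel


@dispatcher.register(
    _ApproximateMarginalLogLikelihood, Likelihood, ApproximateGPyTorchModel
)
def _fit_approximategp_stochastic(mll, _, __, **kwargs):
    # Hand off to the proposed minibatch optimizer (sketched under Pitch).
    return fit_gpytorch_torch_stochastic(mll, **kwargs)
```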
## Motivation

**Is your feature request related to a problem? Please describe.**
As far as I can tell from browsing the code, running `fit_gpytorch_mll` on an `ApproximateGPyTorchModel` would just use full-batch training. As a result, we have typically still been brewing our own GPyTorch models + training code (e.g., for latent space optimization tasks), despite the existence of `ApproximateGPyTorchModel`. We're planning on submitting a PR with a latent space bayesopt tutorial, but I'd like it to be more BoTorch-y than it currently is -- right now the actual model handling is entirely outside of BoTorch for this reason.
## Pitch

**Describe the solution you'd like**
- Write `fit_gpytorch_torch_stochastic` in `botorch.optim.fit` that does minibatch training with a user-specified batch size (see the sketch after this list). For now, I was thinking this can just make a standard `DataLoader` over the train targets and inputs -- handling the case where `train_targets` is actually a tuple might require more thought if we wanted to support that out of the gate. `maxiter` in the stopping criterion would refer to a number of epochs of training.
- Register `fit_gpytorch_torch_stochastic` as the default optimizer by adding a `_fit_approximategp_stochastic` in `botorch.fit` to the dispatcher for `(_ApproximateMarginalLogLikelihood, Likelihood, ApproximateGPyTorchModel)`, as in the sketch above.
- (Possibly breaking) As described above, this would leave it to the user to decide to do full-batch optimization, either by specifying `fit_gpytorch_torch` manually as the optimizer or (equivalently, with negligible overhead) specifying the batch size to be the full N. One solution might be to just call the fallback fit if a minibatch size / optimizer isn't specified by the user. On the other hand, in the long run it probably makes sense to assume by default that the user wants to do some kind of stochastic optimization if they're going to the trouble of using a variational approximate GP specifically, rather than just e.g. an inducing point kernel on an `ExactGP`.
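As a rough illustration of the first item, here is a minimal sketch of the minibatch loop. The name, signature, and defaults are all part of this proposal rather than existing BoTorch API; for simplicity the sketch takes the training data explicitly (a real implementation would pull it off the mll's model) and assumes a single input tensor with tensor-valued (not tuple-valued) targets:

```python
import torch
from gpytorch.mlls import MarginalLogLikelihood
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


def fit_gpytorch_torch_stochastic(
    mll: MarginalLogLikelihood,
    train_X: Tensor,  # (n, d) training inputs
    train_Y: Tensor,  # (n,) training targets for a single-output model
    batch_size: int = 256,
    maxiter: int = 100,  # per the proposal, counts epochs rather than steps
    lr: float = 1e-2,
) -> MarginalLogLikelihood:
    """Proposed minibatch fitting routine (sketch only, not existing API)."""
    loader = DataLoader(
        TensorDataset(train_X, train_Y), batch_size=batch_size, shuffle=True
    )
    optimizer = torch.optim.Adam(mll.parameters(), lr=lr)
    mll.train()
    for _ in range(maxiter):  # one full pass over the data per epoch
        for x_batch, y_batch in loader:
            optimizer.zero_grad()
            # _ApproximateMarginalLogLikelihood subclasses (e.g. VariationalELBO)
            # already rescale the minibatch term via num_data, so the batch loss
            # is a proper stochastic estimate of the full objective.
            loss = -mll(mll.model(x_batch), y_batch)
            loss.backward()
            optimizer.step()
    return mll.eval()
```

A real version would presumably also want the usual bells and whistles from `fit_gpytorch_torch` (optimizer choice, convergence-based stopping, etc.); this just shows the core loop.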
**Are you willing to open a pull request?** (See CONTRIBUTING)
Yeah