Skip to content

Commit b6d71f1

Browse files
esantorella and meta-codesync[bot]
authored and committed
Replace FitGPyTorchMLL dispatcher with isinstance checks (#3233)
Summary:

Pull Request resolved: #3233

**Context**: See D96592835 for stack overview.

**This PR**: Replace the FitGPyTorchMLL multiple-dispatch mechanism with simple `isinstance` checks in `fit_gpytorch_mll`, using the `custom_fit` method defined in D96592835. This reduces stack depth and makes the fitting code path easier to follow and debug.

Changes:
- `fit_gpytorch_mll` now uses `isinstance` checks to route to `_fit_list` (for SumMarginalLogLikelihood + ModelListGP), `_fit_fallback_approximate` (for _ApproximateMarginalLogLikelihood), or `_fit_fallback` (default).
- Removed the FitGPyTorchMLL dispatcher and its import of Dispatcher.
- Converted RobustRelevancePursuitMixin from two `FitGPyTorchMLL.register` calls to a `custom_fit` method on the mixin class.
- Simplified the `_fit_fallback`, `_fit_list`, and `_fit_fallback_approximate` signatures by removing unused type arguments that were only needed for dispatching.

Reviewed By: saitcakmak

Differential Revision: D96592852

fbshipit-source-id: 57e7e58a421fc055637f1e9fefaf8b8ad654a819
1 parent feb9088 commit b6d71f1

3 files changed

Lines changed: 163 additions & 228 deletions

File tree

botorch/fit.py

Lines changed: 31 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
from botorch.exceptions.warnings import OptimizationWarning
2020
from botorch.logging import logger
2121
from botorch.models import SingleTaskGP
22-
from botorch.models.approximate_gp import ApproximateGPyTorchModel
2322
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
2423
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
2524
from botorch.models.map_saas import get_map_saas_model
@@ -39,8 +38,6 @@
3938
parameter_rollback_ctx,
4039
TensorCheckpoint,
4140
)
42-
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
43-
from gpytorch.likelihoods import Likelihood
4441
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
4542
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
4643
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
@@ -73,7 +70,6 @@ def _rethrow_warn(w: WarningMessage) -> bool:
7370
debug=_debug_warn,
7471
rethrow=_rethrow_warn,
7572
)
76-
FitGPyTorchMLL = Dispatcher("fit_gpytorch_mll", encoder=type_bypassing_encoder)
7773

7874

7975
def fit_gpytorch_mll(
@@ -86,18 +82,22 @@ def fit_gpytorch_mll(
8682
) -> MarginalLogLikelihood:
8783
r"""Clearing house for fitting models passed as GPyTorch MarginalLogLikelihoods.
8884
85+
If a model defines a ``custom_fit`` method, it will be called directly.
86+
Otherwise, a fit method is determined based on the types of the model and
87+
MLL.
88+
8989
Args:
9090
mll: A GPyTorch MarginalLogLikelihood instance.
9191
closure: Forward-backward closure for obtaining objective values and gradients.
9292
Responsible for setting parameters' ``grad`` attributes. If no closure is
9393
provided, one will be obtained by calling ``get_loss_closure_with_grads``.
9494
optimizer: User specified optimization algorithm. When ``optimizer is None``,
95-
this keyword argument is omitted when calling the dispatcher.
95+
this keyword argument is omitted when calling the underlying fit routine.
9696
closure_kwargs: Keyword arguments passed when calling ``closure``.
9797
optimizer_kwargs: A dictionary of keyword arguments passed when
9898
calling ``optimizer``.
99-
**kwargs: Keyword arguments passed down through the dispatcher to
100-
fit subroutines. Unexpected keywords are ignored.
99+
**kwargs: Keyword arguments passed to the underlying fit routine.
100+
Unexpected keywords are ignored.
101101
102102
Returns:
103103
The ``mll`` instance. If fitting succeeded, then ``mll`` will be in
@@ -116,22 +116,38 @@ def fit_gpytorch_mll(
116116
**kwargs,
117117
)
118118

119-
return FitGPyTorchMLL(
120-
mll,
121-
type(mll.likelihood),
122-
type(mll.model),
119+
if isinstance(mll, SumMarginalLogLikelihood) and isinstance(mll.model, ModelListGP):
120+
mll.train()
121+
for sub_mll in mll.mlls:
122+
fit_gpytorch_mll(
123+
mll=sub_mll,
124+
closure=closure,
125+
closure_kwargs=closure_kwargs,
126+
optimizer_kwargs=optimizer_kwargs,
127+
**kwargs,
128+
)
129+
return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll
130+
131+
if isinstance(mll, _ApproximateMarginalLogLikelihood):
132+
return _fit_fallback_approximate(
133+
mll=mll,
134+
closure=closure,
135+
closure_kwargs=closure_kwargs,
136+
optimizer_kwargs=optimizer_kwargs,
137+
**kwargs,
138+
)
139+
140+
return _fit_fallback(
141+
mll=mll,
123142
closure=closure,
124143
closure_kwargs=closure_kwargs,
125144
optimizer_kwargs=optimizer_kwargs,
126145
**kwargs,
127146
)
128147

129148

130-
@FitGPyTorchMLL.register(MarginalLogLikelihood, object, object)
131149
def _fit_fallback(
132150
mll: MarginalLogLikelihood,
133-
_: type[object],
134-
__: type[object],
135151
*,
136152
closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
137153
optimizer: Callable = fit_gpytorch_mll_scipy,
@@ -272,35 +288,8 @@ def _fit_fallback(
272288
raise ModelFittingError("All attempts to fit the model have failed.")
273289

274290

275-
@FitGPyTorchMLL.register(SumMarginalLogLikelihood, object, ModelListGP)
276-
def _fit_list(
277-
mll: SumMarginalLogLikelihood,
278-
_: type[Likelihood],
279-
__: type[ModelListGP],
280-
**kwargs: Any,
281-
) -> SumMarginalLogLikelihood:
282-
r"""Fitting routine for lists of independent Gaussian processes.
283-
284-
Args:
285-
**kwargs: Passed to each of ``mll.mlls``.
286-
287-
Returns:
288-
The ``mll`` instance. If fitting succeeded for all of ``mll.mlls``,
289-
then ``mll`` will be in evaluation mode, i.e. ``mll.training == False``.
290-
Otherwise, ``mll`` will be in training mode.
291-
"""
292-
mll.train()
293-
for sub_mll in mll.mlls:
294-
fit_gpytorch_mll(sub_mll, **kwargs)
295-
296-
return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll
297-
298-
299-
@FitGPyTorchMLL.register(_ApproximateMarginalLogLikelihood, object, object)
300291
def _fit_fallback_approximate(
301292
mll: _ApproximateMarginalLogLikelihood,
302-
_: type[Likelihood],
303-
__: type[ApproximateGPyTorchModel],
304293
*,
305294
closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
306295
data_loader: DataLoader | None = None,
@@ -342,7 +331,7 @@ def _fit_fallback_approximate(
342331
else fit_gpytorch_mll_torch
343332
)
344333

345-
return _fit_fallback(mll, _, __, closure=closure, optimizer=optimizer, **kwargs)
334+
return _fit_fallback(mll=mll, closure=closure, optimizer=optimizer, **kwargs)
346335

347336

348337
def fit_fully_bayesian_model_nuts(

botorch/models/robust_relevance_pursuit_model.py

Lines changed: 97 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,6 @@
3838

3939
import torch
4040
from botorch.exceptions.errors import UnsupportedError
41-
from botorch.fit import FitGPyTorchMLL
4241
from botorch.models import SingleTaskGP
4342
from botorch.models.likelihoods.sparse_outlier_noise import (
4443
SparseOutlierGaussianLikelihood,
@@ -157,6 +156,103 @@ def load_standard_model(self, standard_model: Model) -> Self:
157156
self.load_state_dict(standard_model.state_dict())
158157
return self
159158

159+
def custom_fit(
160+
self,
161+
mll: MarginalLogLikelihood,
162+
*,
163+
numbers_of_outliers: list[int] | None = None,
164+
fractions_of_outliers: list[float] | None = None,
165+
timeout_sec: float | None = None,
166+
relevance_pursuit_optimizer: Callable = backward_relevance_pursuit,
167+
reset_parameters: bool = True,
168+
reset_dense_parameters: bool = False,
169+
closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
170+
optimizer: Callable | None = None,
171+
closure_kwargs: dict[str, Any] | None = None,
172+
optimizer_kwargs: Mapping[str, Any] | None = None,
173+
) -> MarginalLogLikelihood:
174+
"""Fits a RobustRelevancePursuitGP model using the given marginal likelihood.
175+
176+
For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222.
177+
178+
Args:
179+
mll: The marginal likelihood to fit.
180+
numbers_of_outliers: An optional list of numbers of outliers to consider
181+
during relevance pursuit. By default, the algorithm falls back to a
182+
default list of fractions of outliers, see below.
183+
fractions_of_outliers: An optional list of fractions of outliers to
184+
consider if numbers_of_outliers is None. By default, the algorithm
185+
uses ``[0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0]``.
186+
relevance_pursuit_optimizer: The relevance pursuit optimizer to use.
187+
reset_parameters: If True, reset sparse parameters after each iteration.
188+
reset_dense_parameters: If True, reset dense parameters after each
189+
iteration.
190+
closure: A closure to compute loss and gradients.
191+
optimizer: The numerical optimizer.
192+
closure_kwargs: Additional arguments to pass to the closure.
193+
optimizer_kwargs: Additional arguments to pass to fit_gpytorch_mll.
194+
195+
Returns:
196+
The fitted marginal likelihood.
197+
"""
198+
if isinstance(mll, _ApproximateMarginalLogLikelihood):
199+
raise UnsupportedError(
200+
"Relevance Pursuit does not yet support approximate inference. "
201+
)
202+
203+
sparse_module = SparseOutlierNoise._from_model(mll.model)
204+
n = sparse_module.dim # equal to the number of training data points
205+
206+
if numbers_of_outliers is None:
207+
if fractions_of_outliers is None:
208+
fractions_of_outliers = FRACTIONS_OF_OUTLIERS
209+
210+
# list from which BMC chooses
211+
numbers_of_outliers = [int(p * n) for p in fractions_of_outliers]
212+
213+
optimizer_kwargs_: dict[str, Any] = (
214+
{} if optimizer_kwargs is None else dict(optimizer_kwargs)
215+
)
216+
if timeout_sec is not None:
217+
optimizer_kwargs_["timeout_sec"] = timeout_sec / len(numbers_of_outliers)
218+
219+
# Need to convert model to avoid recursion through fit_gpytorch_mll,
220+
# since relevance pursuit expects to call the base fit_gpytorch_mll.
221+
original_model = mll.model # Robust Relevance Pursuit Model
222+
mll.model = original_model.to_standard_model()
223+
sparse_module = SparseOutlierNoise._from_model(mll.model)
224+
sparse_module, model_trace = relevance_pursuit_optimizer(
225+
sparse_module=sparse_module,
226+
mll=mll,
227+
sparsity_levels=numbers_of_outliers,
228+
reset_parameters=reset_parameters,
229+
reset_dense_parameters=reset_dense_parameters,
230+
record_model_trace=True,
231+
# These are the args of the canonical mll fit routine
232+
closure=closure,
233+
optimizer=optimizer,
234+
closure_kwargs=closure_kwargs,
235+
optimizer_kwargs=optimizer_kwargs_,
236+
)
237+
238+
# Bayesian model comparison
239+
bmc_support_sizes, bmc_probabilities = get_posterior_over_support(
240+
SparseOutlierNoise,
241+
model_trace,
242+
prior_mean_of_support=original_model.prior_mean_of_support,
243+
)
244+
map_index = torch.argmax(bmc_probabilities)
245+
map_model = model_trace[map_index] # choosing model with highest BMC score
246+
# overwrite mll.model with chosen model
247+
mll.model = original_model # first restore original model pointer
248+
mll.model.load_standard_model(map_model)
249+
# Store the bmc results
250+
mll.model.bmc_support_sizes = bmc_support_sizes
251+
mll.model.bmc_probabilities = bmc_probabilities
252+
if mll.model.cache_model_trace:
253+
mll.model.model_trace = model_trace
254+
return mll
255+
160256

161257
class RobustRelevancePursuitSingleTaskGP(SingleTaskGP, RobustRelevancePursuitMixin):
162258
def __init__(
@@ -252,127 +348,3 @@ def to_standard_model(self) -> Model:
252348
if not is_training:
253349
model.eval()
254350
return model
255-
256-
257-
@FitGPyTorchMLL.register(
258-
MarginalLogLikelihood,
259-
SparseOutlierGaussianLikelihood,
260-
RobustRelevancePursuitMixin,
261-
)
262-
def _fit_rrp(
263-
mll: MarginalLogLikelihood,
264-
_: type[SparseOutlierGaussianLikelihood],
265-
__: type[RobustRelevancePursuitMixin],
266-
*,
267-
numbers_of_outliers: list[int] | None = None,
268-
fractions_of_outliers: list[float] | None = None,
269-
timeout_sec: float | None = None,
270-
relevance_pursuit_optimizer: Callable = backward_relevance_pursuit,
271-
reset_parameters: bool = True,
272-
reset_dense_parameters: bool = False,
273-
# fit_gpytorch_mll kwargs
274-
closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
275-
optimizer: Callable | None = None,
276-
closure_kwargs: dict[str, Any] | None = None,
277-
optimizer_kwargs: Mapping[str, Any] | None = None,
278-
) -> MarginalLogLikelihood:
279-
"""Fits a RobustRelevancePursuitGP model using the given marginal likelihood.
280-
281-
For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222.
282-
283-
Args:
284-
mll: The marginal likelihood to fit.
285-
_: A likelihood, only directly used for dispatching.
286-
_: A model, only directly used for dispatching.
287-
numbers_of_outliers: An optional list of numbers of outliers to consider during
288-
relevance pursuit. By default, the algorithm falls back to a default list
289-
of fractions of outliers, see below.
290-
fractions_of_outliers: An optional list of fractions of outliers to consider if
291-
numbers_of_outliers is None. By default, the algorithm uses
292-
``[0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0]``.
293-
relevance_pursuit_optimizer: The relevance pursuit optimizer to use. By default,
294-
uses ``backward_relevance_pursuit``, which is generally the most powerful
295-
algorithm for challenging problems with a wide range of outliers. The
296-
``forward_relevance_pursuit`` algorithm can be efficient when the number of
297-
outliers is relatively small.
298-
reset_parameters: If True, we will reset the sparse parameters of the model
299-
after each iteration of the relevance pursuit algorithm.
300-
reset_dense_parameters: If True, we will reset the dense parameters of the model
301-
after each iteration of the relevance pursuit algorithm.
302-
closure: A closure to use to compute the loss and the gradients, see docstring
303-
of ``fit_gpytorch_mll`` for details.
304-
optimizer: The numerical optimizer, see docstring of ``fit_gpytorch_mll``.
305-
closure_kwargs: Additional arguments to pass to the ``closure`` function.
306-
optimizer_kwargs: Additional arguments to pass to ``fit_gpytorch_mll``.
307-
308-
Returns:
309-
The fitted marginal likelihood.
310-
"""
311-
sparse_module = SparseOutlierNoise._from_model(mll.model)
312-
n = sparse_module.dim # equal to the number of training data points
313-
314-
if numbers_of_outliers is None:
315-
if fractions_of_outliers is None:
316-
fractions_of_outliers = FRACTIONS_OF_OUTLIERS
317-
318-
# list from which BMC chooses
319-
numbers_of_outliers = [int(p * n) for p in fractions_of_outliers]
320-
321-
optimizer_kwargs_: dict[str, Any] = (
322-
{} if optimizer_kwargs is None else dict(optimizer_kwargs)
323-
)
324-
if timeout_sec is not None:
325-
optimizer_kwargs_["timeout_sec"] = timeout_sec / len(numbers_of_outliers)
326-
327-
# Need to convert model to avoid recursion through fit_gpytorch_mll dispatch, since
328-
# relevance pursuit expects to call the base fit_gpytorch_mll.
329-
original_model = mll.model # Robust Relevance Pursuit Model
330-
mll.model = original_model.to_standard_model()
331-
sparse_module = SparseOutlierNoise._from_model(mll.model)
332-
sparse_module, model_trace = relevance_pursuit_optimizer(
333-
sparse_module=sparse_module,
334-
mll=mll,
335-
sparsity_levels=numbers_of_outliers,
336-
reset_parameters=reset_parameters,
337-
reset_dense_parameters=reset_dense_parameters,
338-
record_model_trace=True,
339-
# These are the args of the canonical mll fit routine
340-
closure=closure,
341-
optimizer=optimizer,
342-
closure_kwargs=closure_kwargs,
343-
optimizer_kwargs=optimizer_kwargs_,
344-
)
345-
346-
# Bayesian model comparison
347-
bmc_support_sizes, bmc_probabilities = get_posterior_over_support(
348-
SparseOutlierNoise,
349-
model_trace,
350-
prior_mean_of_support=original_model.prior_mean_of_support,
351-
)
352-
map_index = torch.argmax(bmc_probabilities)
353-
map_model = model_trace[map_index] # choosing model with highest BMC score
354-
# overwrite mll.model with chosen model
355-
mll.model = original_model # first restore original model pointer
356-
mll.model.load_standard_model(map_model)
357-
# Store the bmc results
358-
mll.model.bmc_support_sizes = bmc_support_sizes
359-
mll.model.bmc_probabilities = bmc_probabilities
360-
if mll.model.cache_model_trace:
361-
mll.model.model_trace = model_trace
362-
return mll
363-
364-
365-
@FitGPyTorchMLL.register(
366-
_ApproximateMarginalLogLikelihood,
367-
SparseOutlierGaussianLikelihood,
368-
RobustRelevancePursuitMixin,
369-
)
370-
def _fit_rrp_approximate_mll(
371-
mll: _ApproximateMarginalLogLikelihood,
372-
_: type[SparseOutlierGaussianLikelihood],
373-
__: type[RobustRelevancePursuitMixin],
374-
**kwargs: Any,
375-
) -> None:
376-
raise UnsupportedError(
377-
"Relevance Pursuit does not yet support approximate inference. "
378-
)

0 commit comments

Comments
 (0)