Skip to content

Commit 91df5fe

Browse files
esantorellameta-codesync[bot]
authored andcommitted
Replace GetLossClosure dispatcher with isinstance checks (#3235)
Summary: Pull Request resolved: #3235 **Context**: See D96592835 for stack overview. **This PR**: This commit removes the GetLossClosure dispatcher in `get_loss_closure` and replaces it with isinstance checks for the `custom_loss` methods added in D96592835. . The routing logic is now: 1. Check for compute_loss protocol on MLL or model (added in commit 1) 2. If data_loader is provided, use _get_loss_closure_fallback_external 3. isinstance(mll, ExactMarginalLogLikelihood) -> _get_loss_closure_exact_internal 4. isinstance(mll, SumMarginalLogLikelihood) -> _get_loss_closure_sum_internal 5. Fallback -> _get_loss_closure_fallback_internal Also cleans up: - Removes unused type arguments (_: object, __: object, ___: None) from closure factory functions - Removes **ignore: Any parameters that were only needed for dispatcher compatibility - Removes **kwargs from get_loss_closure and get_loss_closure_with_grads (no callers pass extra kwargs) - Removes imports of Dispatcher, type_bypassing_encoder, NoneType Reviewed By: saitcakmak Differential Revision: D96592824 fbshipit-source-id: 1a3672d9725ae3d27efb5bf78052ce18a9a5b0d7
1 parent b6d71f1 commit 91df5fe

2 files changed

Lines changed: 21 additions & 35 deletions

File tree

botorch/optim/closures/model_closures.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@
1111
from collections.abc import Callable, Sequence
1212
from functools import partial
1313
from itertools import chain, repeat
14-
from types import NoneType
1514
from typing import Any
1615

1716
from botorch.optim.closures.core import ForwardBackwardClosure
18-
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
1917
from gpytorch.mlls import (
2018
ExactMarginalLogLikelihood,
2119
MarginalLogLikelihood,
@@ -24,28 +22,18 @@
2422
from torch import Tensor
2523
from torch.utils.data import DataLoader
2624

27-
GetLossClosure = Dispatcher("get_loss_closure", encoder=type_bypassing_encoder)
28-
2925

3026
def get_loss_closure(
3127
mll: MarginalLogLikelihood,
3228
data_loader: DataLoader | None = None,
33-
**kwargs: Any,
3429
) -> Callable[[], Tensor]:
35-
r"""Public API for GetLossClosure dispatcher.
36-
37-
This method, and the dispatcher that powers it, acts as a clearing house
38-
for factory functions that define how ``mll`` is evaluated.
30+
r"""Factory function for creating loss closures from MarginalLogLikelihoods.
3931
40-
Users may specify custom evaluation routines by registering a factory function
41-
with GetLossClosure. These factories should be registered using the type signature
32+
This method acts as a clearing house for factory functions that define how
33+
``mll`` is evaluated.
4234
43-
``Type[MarginalLogLikelihood], Type[Likelihood], Type[Model],
44-
Type[DataLoader]``.
45-
46-
The final argument, Type[DataLoader], is optional. Evaluation routines that
47-
obtain training data from, e.g., ``mll.model`` should register this argument as
48-
``type(None)``.
35+
Users may specify custom evaluation routines by passing an ``mll`` or an
36+
``mll.model`` with a method ``compute_custom_loss``.
4937
5038
Args:
5139
mll: A MarginalLogLikelihood instance whose negative defines the loss.
@@ -61,16 +49,22 @@ def get_loss_closure(
6149
if hasattr(mll.model, "compute_custom_loss"):
6250
return partial(mll.model.compute_custom_loss, mll=mll)
6351

64-
return GetLossClosure(
65-
mll, type(mll.likelihood), type(mll.model), data_loader, **kwargs
66-
)
52+
if data_loader is not None:
53+
return _get_loss_closure_fallback_external(mll=mll, data_loader=data_loader)
54+
55+
if isinstance(mll, ExactMarginalLogLikelihood):
56+
return _get_loss_closure_exact_internal(mll=mll)
57+
58+
if isinstance(mll, SumMarginalLogLikelihood):
59+
return _get_loss_closure_sum_internal(mll=mll)
60+
61+
return _get_loss_closure_fallback_internal(mll=mll)
6762

6863

6964
def get_loss_closure_with_grads(
7065
mll: MarginalLogLikelihood,
7166
parameters: dict[str, Tensor],
7267
data_loader: DataLoader | None = None,
73-
**kwargs: Any,
7468
) -> ForwardBackwardClosure:
7569
"""
7670
Add a backward pass to a loss closure obtained by calling
@@ -83,23 +77,18 @@ def get_loss_closure_with_grads(
8377
parameters: A dictionary of tensors whose ``grad`` fields are to be returned.
8478
data_loader: An optional DataLoader instance for cases where training
8579
data is passed in rather than obtained from ``mll.model``.
86-
kwargs: Keyword arguments passed to ``get_loss_closure``.
8780
8881
Returns:
8982
A closure that takes zero positional arguments and returns the reduced and
9083
negated value of ``mll`` along with the gradients of ``parameters``.
9184
"""
92-
loss_closure = get_loss_closure(mll, data_loader=data_loader, **kwargs)
85+
loss_closure = get_loss_closure(mll=mll, data_loader=data_loader)
9386
return ForwardBackwardClosure(forward=loss_closure, parameters=parameters)
9487

9588

96-
@GetLossClosure.register(MarginalLogLikelihood, object, object, DataLoader)
9789
def _get_loss_closure_fallback_external(
9890
mll: MarginalLogLikelihood,
99-
_likelihood_type: object,
100-
_model_type: object,
10191
data_loader: DataLoader,
102-
**ignore: Any,
10392
) -> Callable[[], Tensor]:
10493
r"""Fallback loss closure with externally provided data."""
10594
batch_generator = chain.from_iterable(iter(data_loader) for _ in repeat(None))
@@ -120,9 +109,8 @@ def closure(**kwargs: Any) -> Tensor:
120109
return closure
121110

122111

123-
@GetLossClosure.register(MarginalLogLikelihood, object, object, NoneType)
124112
def _get_loss_closure_fallback_internal(
125-
mll: MarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
113+
mll: MarginalLogLikelihood,
126114
) -> Callable[[], Tensor]:
127115
r"""Fallback loss closure with internally managed data."""
128116

@@ -134,9 +122,8 @@ def closure(**kwargs: Any) -> Tensor:
134122
return closure
135123

136124

137-
@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, NoneType)
138125
def _get_loss_closure_exact_internal(
139-
mll: ExactMarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
126+
mll: ExactMarginalLogLikelihood,
140127
) -> Callable[[], Tensor]:
141128
r"""ExactMarginalLogLikelihood loss closure with internally managed data."""
142129

@@ -158,9 +145,8 @@ def closure(**kwargs: Any) -> Tensor:
158145
return closure
159146

160147

161-
@GetLossClosure.register(SumMarginalLogLikelihood, object, object, NoneType)
162148
def _get_loss_closure_sum_internal(
163-
mll: SumMarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
149+
mll: SumMarginalLogLikelihood,
164150
) -> Callable[[], Tensor]:
165151
r"""SumMarginalLogLikelihood loss closure with internally managed data."""
166152

test/optim/closures/test_model_closures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ def my_compute_custom_loss(
104104
self.assertEqual(result(), torch.tensor(-99.0))
105105
del mll.model.compute_custom_loss
106106

107-
# Without compute_custom_loss, get_loss_closure uses the dispatcher.
108-
with self.subTest("dispatcher_fallback"):
107+
# Without compute_custom_loss, get_loss_closure uses isinstance checks.
108+
with self.subTest("isinstance_fallback"):
109109
_, mlls = _get_mlls(device=self.device)
110110
mll = mlls[0]
111111
self.assertFalse(hasattr(mll, "compute_custom_loss"))

0 commit comments

Comments
 (0)