Commit c895a8d

saitcakmak authored and facebook-github-bot committed
Apply input transforms when computing MLL in model closures (#2527)
Summary:
Pull Request resolved: #2527

During model training, the input transforms are applied in `model.forward`. While evaluating the model closures, we pass the train inputs to the `mll`, which passes them down to the `likelihood`. If we don't transform the inputs before passing them to the `mll`, we end up evaluating `model.forward` and the `likelihood` with different inputs. This is not an issue during `posterior` evaluation, since the transforms are applied in `model.posterior` before the inputs are passed to `model.__call__` and the `likelihood`.

This diff updates the model closures to transform the inputs before passing them into the `mll`.

Fixes #2515

Reviewed By: SebastianAment

Differential Revision: D62497392

fbshipit-source-id: 9850c6529eea336589c2e2bbae400a4a9dc87f12
1 parent db96db3 commit c895a8d
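To make the mismatch concrete, here is a minimal sketch of the fixed training-time call pattern. It is an illustration of the idea, not code from the patch; the example data and variable names are hypothetical, assuming a `SingleTaskGP` with a `Normalize` input transform:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood

# Hypothetical training data, outside the unit cube so Normalize matters.
train_X = torch.linspace(0, 5, 10, dtype=torch.double).unsqueeze(-1)
train_Y = torch.sin(train_X)
model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=1),
    outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
mll.train()

# In train mode, `model.forward` applies the input transform internally...
output = model(*model.train_inputs)
# ...so any extra inputs handed to the mll (and through it, to the
# likelihood) must be transformed the same way. Passing the raw
# `train_inputs` here was the bug this commit fixes.
transformed = (model.transform_inputs(X=t_in) for t_in in model.train_inputs)
loss = -mll(output, model.train_targets, *transformed)
```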

3 files changed (+122 −34 lines)


botorch/models/gp_regression.py (+7 −2)
@@ -345,12 +345,17 @@ def __init__(
                 MIN_INFERRED_NOISE_LEVEL, transform=None, initial_value=1.0
             ),
         )
+        # Likelihood will always get evaluated with transformed X, so we need to
+        # transform the training data before constructing the noise model.
+        with torch.no_grad():
+            transformed_X = self.transform_inputs(
+                X=train_X, input_transform=input_transform
+            )
         noise_model = SingleTaskGP(
-            train_X=train_X,
+            train_X=transformed_X,
             train_Y=train_Yvar,
             likelihood=noise_likelihood,
             outcome_transform=Log(),
-            input_transform=input_transform,
         )
         likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
         # This is hacky -- this class used to inherit from SingleTaskGP, but it
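The new `torch.no_grad()` block simply pre-applies the transform to the training data before the noise model is built, so the noise model no longer needs (and no longer receives) the `input_transform` itself. A rough standalone sketch of the same idea, with hypothetical example data, assuming a `Normalize` transform:

```python
import torch
from botorch.models.transforms import Normalize

train_X = 5 * torch.rand(10, 2, dtype=torch.double)  # not in the unit cube
transform = Normalize(d=2)
with torch.no_grad():
    # In train mode, Normalize learns its bounds from the data it sees
    # and maps the inputs into the unit cube.
    transformed_X = transform(train_X)
# The noise model is then constructed on `transformed_X` directly, so the
# inputs are never transformed twice.
```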

botorch/optim/closures/model_closures.py (+21 −6)
@@ -9,7 +9,6 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-
 from itertools import chain, repeat
 from types import NoneType
 from typing import Any, Callable, Optional
@@ -174,9 +173,17 @@ def _get_loss_closure_exact_internal(
     r"""ExactMarginalLogLikelihood loss closure with internally managed data."""
 
     def closure(**kwargs: Any) -> Tensor:
-        model_output = mll.model(*mll.model.train_inputs)
+        model = mll.model
+        # The inputs will get transformed in forward here.
+        model_output = model(*model.train_inputs)
         log_likelihood = mll(
-            model_output, mll.model.train_targets, *mll.model.train_inputs, **kwargs
+            model_output,
+            model.train_targets,
+            # During model training, the model inputs get transformed in the forward
+            # pass. The train_inputs property is not transformed yet, so we need to
+            # transform it before passing it to the likelihood for consistency.
+            *(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
+            **kwargs,
         )
         return -log_likelihood
@@ -190,11 +197,19 @@ def _get_loss_closure_sum_internal(
     r"""SumMarginalLogLikelihood loss closure with internally managed data."""
 
     def closure(**kwargs: Any) -> Tensor:
-        model_output = mll.model(*mll.model.train_inputs)
+        model = mll.model
+        # The inputs will get transformed in forward here.
+        model_output = model(*model.train_inputs)
         log_likelihood = mll(
             model_output,
-            mll.model.train_targets,
-            *map(list, mll.model.train_inputs),
+            model.train_targets,
+            # During model training, the model inputs get transformed in the forward
+            # pass. The train_inputs property is not transformed yet, so we need to
+            # transform it before passing it to the likelihood for consistency.
+            *(
+                (model.transform_inputs(X=t_in) for t_in in sub_t_in)
+                for sub_t_in in model.train_inputs
+            ),
             **kwargs,
         )
         return -log_likelihood
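For reference, this is roughly how such a closure is built and evaluated during model fitting (mirroring the tests below; the `mll` here is assumed to come from a setup like `_get_mlls`):

```python
from botorch.optim.closures import get_loss_closure_with_grads

# Collect the trainable parameters, as done during model fitting.
params = {n: p for n, p in mll.named_parameters() if p.requires_grad}
mll.train()
closure = get_loss_closure_with_grads(mll, params)
# With this fix, the closure transforms `train_inputs` before they
# reach the likelihood; it returns the loss and gradients w.r.t. params.
loss, grads = closure()
```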

test/optim/closures/test_model_closures.py (+94 −26)
@@ -17,35 +17,68 @@
 )
 from botorch.utils.testing import BotorchTestCase
 from gpytorch import settings as gpytorch_settings
+from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
 from gpytorch.mlls import ExactMarginalLogLikelihood, SumMarginalLogLikelihood
+from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
+from gpytorch.module import Module
+from torch import Tensor
 from torch.utils.data import DataLoader, TensorDataset
 
 
+# Mock wrapping the __call__ directly is leading to errors like
+# TypeError: super(type, obj): obj must be an instance or subtype of type
+# so, doing this manually here.
+class WrapperLikelihood(GaussianLikelihood):
+    def __init__(self, base_likelihood: GaussianLikelihood):
+        """A wrapper around a GaussianLikelihood that stores the call args."""
+        Module.__init__(self)
+        self.base_likelihood = base_likelihood
+        self.call_args = []
+
+    def __call__(self, *args, **kwargs):
+        # Store the train inputs arg for testing.
+        self.call_args.append(args[1])
+        return self.base_likelihood(*args, **kwargs)
+
+
+def _get_mlls(
+    device: torch.device, wrap_likelihood: bool = False
+) -> tuple[Tensor, list[MarginalLogLikelihood]]:
+    """Returns the train X, along with two MLLs: one for a SingleTaskGP and
+    one for a ModelListGP.
+
+    Args:
+        device: The device to use.
+        wrap_likelihood: If True, wrap the likelihood in a WrapperLikelihood.
+            This is useful for comparing call args later.
+    """
+    with torch.random.fork_rng():
+        torch.manual_seed(0)
+        # Inputs are not in the unit cube to ensure input transform is applied.
+        train_X = torch.linspace(0, 5, 10).unsqueeze(-1)
+        train_Y = torch.sin((2 * pi) * train_X)
+        train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
+    mlls = []
+    model = SingleTaskGP(
+        train_X=train_X,
+        train_Y=train_Y,
+        input_transform=Normalize(d=1),
+        outcome_transform=Standardize(m=1),
+    )
+    if wrap_likelihood:
+        model.likelihood = WrapperLikelihood(model.likelihood)
+    mll = ExactMarginalLogLikelihood(model.likelihood, model)
+    mlls.append(mll.to(device=device, dtype=torch.double))
+
+    model = ModelListGP(model, model)
+    mll = SumMarginalLogLikelihood(model.likelihood, model)
+    mlls.append(mll.to(device=device, dtype=torch.double))
+    return train_X.to(device=device, dtype=torch.double), mlls
+
+
 class TestLossClosures(BotorchTestCase):
-    def setUp(self):
-        super().setUp()
-        with torch.random.fork_rng():
-            torch.manual_seed(0)
-            train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
-            train_Y = torch.sin((2 * pi) * train_X)
-            train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
-
-        self.mlls = {}
-        model = SingleTaskGP(
-            train_X=train_X,
-            train_Y=train_Y,
-            input_transform=Normalize(d=1),
-            outcome_transform=Standardize(m=1),
-        )
-        mll = ExactMarginalLogLikelihood(model.likelihood, model)
-        self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device)
-
-        model = ModelListGP(model, model)
-        mll = SumMarginalLogLikelihood(model.likelihood, model)
-        self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device)
-
-    def test_main(self):
-        for mll in self.mlls.values():
+    def test_main(self) -> None:
+        for mll in _get_mlls(device=self.device)[1]:
             out = mll.model(*mll.model.train_inputs)
             loss = -mll(out, mll.model.train_targets).sum()
             loss.backward()
@@ -63,8 +96,8 @@ def test_main(self):
             self.assertTrue(loss.equal(_loss))
             self.assertTrue(all(a.equal(b) for a, b in zip_longest(grads, _grads)))
 
-    def test_data_loader(self):
-        for mll in self.mlls.values():
+    def test_data_loader(self) -> None:
+        for mll in _get_mlls(device=self.device)[1]:
             if type(mll) is not ExactMarginalLogLikelihood:
                 continue
 
@@ -86,3 +119,38 @@ def test_data_loader(self):
             closure = get_loss_closure_with_grads(mll, params, data_loader=loader)
             with self.assertRaisesRegex(TypeError, "Expected .* a batch of tensors"):
                 closure()
+
+    def test_with_input_transforms(self) -> None:
+        # This test reproduces the bug reported in issue #2515.
+        train_X, mlls = _get_mlls(device=self.device, wrap_likelihood=True)
+        for mll in mlls:
+            if isinstance(mll, SumMarginalLogLikelihood):
+                # The likelihood is called twice here since it is the same
+                # likelihood in both child models.
+                likelihood = mll.model.models[0].likelihood
+                expected_calls1 = 2  # In the closure call.
+                expected_calls2 = 6  # Closure + posterior calls.
+            else:
+                likelihood = mll.model.likelihood
+                expected_calls1 = 1  # In the closure call.
+                expected_calls2 = 4  # Closure + posterior calls.
+            likelihood.call_args = []  # reset since it is shared between the models.
+            params = {n: p for n, p in mll.named_parameters() if p.requires_grad}
+            # Evaluate the closure to mimic the model fitting process.
+            mll.train()
+            closure = get_loss_closure_with_grads(mll, params)
+            closure()
+            self.assertEqual(len(likelihood.call_args), expected_calls1)
+            # Call the model posterior to reproduce post-fitting usage.
+            mll.model.posterior(train_X, observation_noise=True)
+            # Compare the call args to ensure they're all the same.
+            # Likelihood is called twice on model(X) and once for adding the noise.
+            self.assertEqual(len(likelihood.call_args), expected_calls2)
+            arg0 = likelihood.call_args[0]
+            for i in range(1, expected_calls2):
+                argi = likelihood.call_args[i]
+                # The arg may be a tensor or a single element list of the tensor.
+                self.assertAllClose(
+                    arg0 if isinstance(arg0, Tensor) else arg0[0],
+                    argi if isinstance(argi, Tensor) else argi[0],
+                )
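Finally, a short sketch of the post-fitting `posterior` path the test compares against. Here the input transform is applied inside `model.posterior` before `model.__call__`, so no manual transformation is needed (assuming the `SingleTaskGP` `mll` and `train_X` from `_get_mlls`):

```python
model = mll.model
model.eval()
# `observation_noise=True` routes the already-transformed inputs to the
# likelihood, so it sees the same inputs as `model.forward`.
posterior = model.posterior(train_X, observation_noise=True)
mean, variance = posterior.mean, posterior.variance
```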
