Skip to content

Commit 2b5189e

Browse files
esantorellafacebook-github-bot
authored andcommitted
Fix loading state dict for SingleTaskVariationalGP/ApproximateGPyTorchModel (meta-pytorch#3251)
Summary: **Context**: See `https://github.com/meta-pytorch/botorch/issues/3250` . GPyTorchModel.load_state_dict previously accessed self.train_targets and self.train_inputs unconditionally, which failed for ApproximateGPyTorchModel because its training data lives on its `model` attribute (e.g. `model.model.train_targets`). **Changes**: - Extracted two new overridable methods on GPyTorchModel: _untransform_targets() (undo outcome transform and return raw targets) and _retransform_and_set_targets(). Note: `None` typing effectively replaces `hasattr` checks for `train_targets` - Used these methods in `load_state_dict` - Overrode both methods on ApproximateGPyTorchModel to read from / write to self.model.train_targets and self.model.train_inputs instead of self. Differential Revision: D98021112
1 parent 719d7b2 commit 2b5189e

3 files changed

Lines changed: 262 additions & 51 deletions

File tree

botorch/models/approximate_gp.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,41 @@ def __init__(
128128
self.likelihood = likelihood
129129
self._desired_num_outputs = num_outputs
130130

131+
def _untransform_targets(
132+
self,
133+
) -> tuple[Tensor, Tensor | None, Tensor] | None:
134+
r"""Extract and untransform training targets from the inner model.
135+
136+
Overrides ``GPyTorchModel._untransform_targets`` because
137+
``ApproximateGPyTorchModel`` stores ``train_targets`` and
138+
``train_inputs`` on ``self.model`` (the inner ``ApproximateGP``),
139+
not directly on ``self``.
140+
"""
141+
if not hasattr(self.model, "train_targets"):
142+
return None
143+
if getattr(self, "outcome_transform", None) is None:
144+
return None
145+
146+
Y = self.model.train_targets.unsqueeze(-1)
147+
X = self.model.train_inputs[0]
148+
Y, Yvar = self.outcome_transform.untransform(Y=Y, Yvar=None, X=X)
149+
return Y, Yvar, X
150+
151+
def _retransform_and_set_targets(
152+
self,
153+
Y: Tensor,
154+
Yvar: Tensor | None,
155+
X: Tensor,
156+
) -> None:
157+
r"""Re-apply the outcome transform and store targets on the inner model.
158+
159+
Overrides ``GPyTorchModel._retransform_and_set_targets`` because
160+
targets must be written to ``self.model.train_targets``.
161+
"""
162+
self.outcome_transform.eval()
163+
retransformed_Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar, X=X)
164+
self.model.train_targets = retransformed_Y.squeeze(-1)
165+
131166
@property
132167
def num_outputs(self):
133168
return self._desired_num_outputs

botorch/models/gpytorch.py

Lines changed: 82 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -304,15 +304,12 @@ def _extract_targets_and_noise(self) -> tuple[Tensor, Tensor | None]:
304304
Y, Yvar = extract_targets_and_noise_single_output(self)
305305
return Y, Yvar
306306

307-
def _restore_targets_and_noise(
308-
self, Y: Tensor, Yvar: Tensor | None, strict: bool
309-
) -> None:
307+
def _restore_targets_and_noise(self, Y: Tensor, Yvar: Tensor | None) -> None:
310308
r"""Restore targets and noise variance to the model.
311309
312310
Args:
313311
Y: Targets tensor in shape ``batch_shape x n x m``.
314312
Yvar: Optional noise variance tensor in shape ``batch_shape x n x m``.
315-
strict: Whether to strictly enforce shape constraints.
316313
"""
317314
if self.num_outputs > 1:
318315
Y = Y.transpose(-1, -2)
@@ -321,9 +318,63 @@ def _restore_targets_and_noise(
321318
):
322319
Yvar = Yvar.transpose(-1, -2)
323320
self.likelihood.noise_covar.noise = Yvar
324-
self.set_train_data(targets=Y, strict=strict)
321+
self.set_train_data(targets=Y, strict=False)
325322
else:
326-
restore_targets_and_noise_single_output(self, Y, Yvar, strict)
323+
restore_targets_and_noise_single_output(
324+
model=self, Y=Y, Yvar=Yvar, strict=False
325+
)
326+
327+
def _untransform_targets(
328+
self,
329+
) -> tuple[Tensor, Tensor | None, Tensor] | None:
330+
r"""Extract training targets, undo the outcome transform, and return them.
331+
332+
Used by ``load_state_dict`` to save the untransformed targets before
333+
loading new parameters, so that the outcome transform can be re-applied
334+
afterward with the new transform state.
335+
336+
Subclasses that store training data somewhere other than ``self`` (e.g.
337+
``ApproximateGPyTorchModel`` stores it on ``self.model``) should
338+
override this method.
339+
340+
Returns:
341+
A tuple ``(Y, Yvar, X)`` of untransformed targets, noise variance,
342+
and training inputs — or ``None`` if the model has no outcome
343+
transform.
344+
"""
345+
if getattr(self, "outcome_transform", None) is None:
346+
return None
347+
348+
Y, Yvar = self._extract_targets_and_noise()
349+
X = self.train_inputs[0]
350+
Y, Yvar = self.outcome_transform.untransform(Y=Y, Yvar=Yvar, X=X)
351+
return Y, Yvar, X
352+
353+
def _retransform_and_set_targets(
354+
self,
355+
Y: Tensor,
356+
Yvar: Tensor | None,
357+
X: Tensor,
358+
) -> None:
359+
r"""Re-apply the outcome transform to targets and store them.
360+
361+
Called by ``load_state_dict`` after new parameters have been loaded,
362+
to re-transform the training targets under the updated outcome
363+
transform.
364+
365+
Subclasses that store training data somewhere other than ``self``
366+
should override this method.
367+
368+
Args:
369+
Y: Untransformed targets, shape ``batch_shape x n x m``.
370+
Yvar: Untransformed noise variance, or ``None``.
371+
X: Training inputs, shape ``batch_shape x n x d``.
372+
"""
373+
self.outcome_transform.eval()
374+
retransformed_Y, retransformed_Yvar = self.outcome_transform(
375+
Y=Y, Yvar=Yvar, X=X
376+
)
377+
self._restore_targets_and_noise(Y=retransformed_Y, Yvar=retransformed_Yvar)
327378

328379
def load_state_dict(
329380
self,
@@ -353,48 +404,34 @@ def load_state_dict(
353404
super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
354405
return
355406

356-
should_outcome_transform = (
357-
hasattr(self, "train_targets")
358-
and getattr(self, "outcome_transform", None) is not None
359-
)
360-
407+
# Before loading new parameters, untransform the current training
408+
# targets so they can be re-transformed under the new outcome
409+
# transform state. Returns None if no outcome transform or no
410+
# training data.
361411
with torch.no_grad():
362-
untransformed_Y, untransformed_Yvar = self._extract_targets_and_noise()
363-
X = self.train_inputs[0]
364-
365-
if should_outcome_transform:
366-
try:
367-
untransformed_Y, untransformed_Yvar = (
368-
self.outcome_transform.untransform(
369-
Y=untransformed_Y,
370-
Yvar=untransformed_Yvar,
371-
X=X,
372-
)
373-
)
374-
except NotImplementedError:
375-
warnings.warn(
376-
"Outcome transform does not support untransforming."
377-
"Cannot load the state dict with transforms preserved."
378-
"Setting keep_transforms=False.",
379-
BotorchWarning,
380-
stacklevel=3,
381-
)
382-
super().load_state_dict(
383-
state_dict=state_dict, strict=strict, assign=assign
384-
)
385-
return
412+
try:
413+
untransformed = self._untransform_targets()
414+
except NotImplementedError:
415+
warnings.warn(
416+
"Outcome transform does not support untransforming. "
417+
"Cannot load the state dict with transforms preserved. "
418+
"Setting keep_transforms=False.",
419+
BotorchWarning,
420+
stacklevel=3,
421+
)
422+
super().load_state_dict(
423+
state_dict=state_dict, strict=strict, assign=assign
424+
)
425+
return
386426

387427
super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
388428

389429
if getattr(self, "input_transform", None) is not None:
390430
self.input_transform.eval()
391431

392-
if should_outcome_transform:
393-
self.outcome_transform.eval()
394-
retransformed_Y, retransformed_Yvar = self.outcome_transform(
395-
Y=untransformed_Y, Yvar=untransformed_Yvar, X=X
396-
)
397-
self._restore_targets_and_noise(retransformed_Y, retransformed_Yvar, strict)
432+
if untransformed is not None:
433+
Y, Yvar, X = untransformed
434+
self._retransform_and_set_targets(Y=Y, Yvar=Yvar, X=X)
398435

399436

400437
# pyre-fixme[13]: uninitialized attributes _num_outputs, _input_batch_shape,
@@ -935,17 +972,16 @@ def _extract_targets_and_noise(self) -> tuple[Tensor, Tensor | None]:
935972
"""
936973
return extract_targets_and_noise_single_output(self)
937974

938-
def _restore_targets_and_noise(
939-
self, Y: Tensor, Yvar: Tensor | None, strict: bool
940-
) -> None:
975+
def _restore_targets_and_noise(self, Y: Tensor, Yvar: Tensor | None) -> None:
941976
r"""Restore targets and noise variance for multi-task models.
942977
943978
Args:
944979
Y: Targets tensor in shape ``batch_shape x n x m``.
945980
Yvar: Optional noise variance tensor in shape ``batch_shape x n x m``.
946-
strict: Whether to strictly enforce shape constraints.
947981
"""
948-
restore_targets_and_noise_single_output(self, Y, Yvar, strict)
982+
restore_targets_and_noise_single_output(
983+
model=self, Y=Y, Yvar=Yvar, strict=False
984+
)
949985

950986
def _apply_noise(
951987
self,

test/models/test_approximate_gp.py

Lines changed: 145 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
SingleTaskVariationalGP,
1818
)
1919
from botorch.models.transforms.input import Normalize
20-
from botorch.models.transforms.outcome import Log
20+
from botorch.models.transforms.outcome import Log, Standardize
2121
from botorch.models.utils.inducing_point_allocators import (
2222
GreedyImprovementReduction,
2323
GreedyVarianceReduction,
@@ -63,21 +63,161 @@ def test_initialization(self):
6363
)
6464
self.assertEqual(model.num_outputs, 2)
6565

66+
def test_load_state_dict(self) -> None:
67+
test_X = torch.rand(5, 1, device=self.device)
68+
69+
for label, train_Y in [
70+
("with_train_Y", self.train_Y),
71+
("no_train_Y", None),
72+
]:
73+
with self.subTest(label=label):
74+
model = ApproximateGPyTorchModel(
75+
train_X=self.train_X,
76+
train_Y=train_Y,
77+
likelihood=BetaLikelihood().to(self.device),
78+
)
79+
state_dict = model.state_dict()
80+
81+
restored = ApproximateGPyTorchModel(
82+
train_X=self.train_X,
83+
train_Y=train_Y,
84+
likelihood=BetaLikelihood().to(self.device),
85+
)
86+
restored.load_state_dict(state_dict=state_dict)
87+
restored_state = restored.state_dict()
88+
self.assertEqual(set(state_dict.keys()), set(restored_state.keys()))
89+
for key in state_dict:
90+
self.assertTrue(
91+
torch.equal(state_dict[key], restored_state[key]),
92+
msg=f"Mismatch for key {key}",
93+
)
94+
95+
model.eval()
96+
restored.eval()
97+
torch.manual_seed(0)
98+
orig_posterior = model.posterior(test_X)
99+
torch.manual_seed(0)
100+
restored_posterior = restored.posterior(test_X)
101+
self.assertAllClose(orig_posterior.mean, restored_posterior.mean)
102+
self.assertAllClose(
103+
orig_posterior.variance, restored_posterior.variance
104+
)
105+
66106

67107
class TestSingleTaskVariationalGP(BotorchTestCase):
68108
def setUp(self):
69109
super().setUp()
70-
train_X = torch.rand(10, 1, device=self.device)
71-
train_y = torch.sin(train_X) + torch.randn_like(train_X) * 0.2
110+
self.train_X = torch.rand(10, 1, device=self.device)
111+
self.train_y = torch.sin(self.train_X) + torch.randn_like(self.train_X) * 0.2
72112

73113
self.model = SingleTaskVariationalGP(
74-
train_X=train_X, likelihood=GaussianLikelihood()
114+
train_X=self.train_X,
115+
train_Y=self.train_y,
116+
likelihood=GaussianLikelihood(),
117+
outcome_transform=Standardize(m=1),
75118
).to(self.device)
76119

77120
mll = VariationalELBO(self.model.likelihood, self.model.model, num_data=10)
78-
loss = -mll(self.model.likelihood(self.model(train_X)), train_y).sum()
121+
loss = -mll(self.model.likelihood(self.model(self.train_X)), self.train_y).sum()
79122
loss.backward()
80123

124+
def test_load_state_dict(self) -> None:
125+
test_X = torch.rand(5, 1, device=self.device)
126+
127+
for label, train_Y in [
128+
("with_train_Y", self.train_y),
129+
("no_train_Y", None),
130+
]:
131+
with self.subTest(label=label):
132+
model = SingleTaskVariationalGP(
133+
train_X=self.train_X,
134+
train_Y=train_Y,
135+
likelihood=BetaLikelihood(),
136+
).to(self.device)
137+
state_dict = model.state_dict()
138+
139+
restored = SingleTaskVariationalGP(
140+
train_X=self.train_X,
141+
train_Y=train_Y,
142+
likelihood=BetaLikelihood(),
143+
).to(self.device)
144+
restored.load_state_dict(state_dict=state_dict)
145+
restored_state = restored.state_dict()
146+
self.assertEqual(set(state_dict.keys()), set(restored_state.keys()))
147+
for key in state_dict:
148+
self.assertTrue(
149+
torch.equal(state_dict[key], restored_state[key]),
150+
msg=f"Mismatch for key {key}",
151+
)
152+
153+
# Posterior numerical identity. manual_seed is needed because
154+
# CholeskyVariationalDistribution.initialize_variational_distribution
155+
# adds random noise on the first posterior call.
156+
model.eval()
157+
restored.eval()
158+
torch.manual_seed(0)
159+
orig_posterior = model.posterior(test_X)
160+
torch.manual_seed(0)
161+
restored_posterior = restored.posterior(test_X)
162+
self.assertAllClose(orig_posterior.mean, restored_posterior.mean)
163+
self.assertAllClose(
164+
orig_posterior.variance, restored_posterior.variance
165+
)
166+
167+
# Test keep_transforms with different training data (CV-style).
168+
# The restored model is built with one fewer data point, so its
169+
# Standardize means/stdvs differ from the source model's.
170+
with self.subTest("keep_transforms"):
171+
model = SingleTaskVariationalGP(
172+
train_X=self.train_X,
173+
train_Y=self.train_y,
174+
outcome_transform=Standardize(m=1),
175+
input_transform=Normalize(d=1),
176+
).to(self.device)
177+
state_dict = model.state_dict()
178+
original_train_targets = model.model.train_targets.clone()
179+
180+
cv_X = self.train_X[:-1]
181+
cv_Y = self.train_y[:-1]
182+
183+
for keep_transforms in [True, False]:
184+
with self.subTest(keep_transforms=keep_transforms):
185+
restored = SingleTaskVariationalGP(
186+
train_X=cv_X,
187+
train_Y=cv_Y,
188+
outcome_transform=Standardize(m=1),
189+
input_transform=Normalize(d=1),
190+
).to(self.device)
191+
restored.load_state_dict(
192+
state_dict=state_dict, keep_transforms=keep_transforms
193+
)
194+
195+
if keep_transforms:
196+
# Transform params are loaded from state_dict, and
197+
# train_targets are re-standardized under the loaded
198+
# transform, so they match the original (minus the
199+
# dropped point).
200+
self.assertAllClose(
201+
restored.model.train_targets,
202+
original_train_targets[..., :-1],
203+
)
204+
self.assertTrue(
205+
torch.equal(
206+
restored.outcome_transform.means,
207+
state_dict["outcome_transform.means"],
208+
)
209+
)
210+
else:
211+
# Transform params are loaded but train_targets are
212+
# NOT re-standardized, so they still reflect the
213+
# cv_Y-based standardization and won't match.
214+
self.assertFalse(
215+
torch.allclose(
216+
restored.model.train_targets,
217+
original_train_targets[..., :-1],
218+
)
219+
)
220+
81221
def test_posterior(self):
82222
# basic test of checking that the posterior works as intended
83223
test_x = torch.rand(30, 1, device=self.device)

0 commit comments

Comments
 (0)