Skip to content

Commit a4088cb

Browse files
Carl Hvarfner authored and meta-codesync[bot] committed
Adding assign to load_state_dict implementations (meta-pytorch#3193)
Summary:
Pull Request resolved: meta-pytorch#3193
Pull Request resolved: meta-pytorch#3080

This commit adds the `assign` argument to `GPyTorchModel.load_state_dict` and to the `load_state_dict` implementations of the other model types, to ensure consistency with `Module.load_state_dict`. Dependent on D87084496 (OSS here: https://github.com/cornellius-gp/gpytorch/pull/2691/commits).

Reviewed By: saitcakmak
Differential Revision: D86870038
fbshipit-source-id: bbbcd33dc3edd991e963a4a2554054fe0bd3551d
1 parent 9b192a4 commit a4088cb

4 files changed

Lines changed: 129 additions & 9 deletions

File tree

botorch/models/fully_bayesian.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1111,7 +1111,10 @@ def median_lengthscale(self) -> Tensor:
11111111
return lengthscale.median(0).values.squeeze(0)
11121112

11131113
def load_state_dict(
1114-
self, state_dict: Mapping[str, Any], strict: bool = True
1114+
self,
1115+
state_dict: Mapping[str, Any],
1116+
strict: bool = True,
1117+
assign: bool = False,
11151118
) -> None:
11161119
r"""Custom logic for loading the state dict.
11171120
@@ -1133,7 +1136,7 @@ def load_state_dict(
11331136
)
11341137
self.load_mcmc_samples(mcmc_samples=mcmc_samples)
11351138
# Load the actual samples from the state dict
1136-
super().load_state_dict(state_dict=state_dict, strict=strict)
1139+
super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
11371140

11381141

11391142
class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP):
@@ -1184,7 +1187,10 @@ def median_weight_variance(self) -> Tensor:
11841187
return weight_variance.median(0).values.squeeze(0)
11851188

11861189
def load_state_dict(
1187-
self, state_dict: Mapping[str, Any], strict: bool = True
1190+
self,
1191+
state_dict: Mapping[str, Any],
1192+
strict: bool = True,
1193+
assign: bool = False,
11881194
) -> None:
11891195
r"""Custom logic for loading the state dict.
11901196
@@ -1205,4 +1211,4 @@ def load_state_dict(
12051211
)
12061212
self.load_mcmc_samples(mcmc_samples=mcmc_samples)
12071213
# Load the actual samples from the state dict
1208-
super().load_state_dict(state_dict=state_dict, strict=strict)
1214+
super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)

botorch/models/gpytorch.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -329,6 +329,7 @@ def load_state_dict(
329329
state_dict: Mapping[str, Any],
330330
strict: bool = True,
331331
keep_transforms: bool = True,
332+
assign: bool = False,
332333
) -> None:
333334
r"""Load the model state.
334335
@@ -338,9 +339,17 @@ def load_state_dict(
338339
keep_transforms: A boolean indicating whether to keep the input and outcome
339340
transforms. Doing so is useful when loading a model that was trained on
340341
a full set of data, and is later loaded with a subset of the data.
342+
assign: When set to ``False``, the properties of the tensors in the current
343+
module are preserved whereas setting it to ``True`` preserves
344+
properties of the Tensors in the state dict. The only
345+
exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`
346+
for which the value from the module is preserved. Default: ``False``.
341347
"""
348+
if assign:
349+
first_item = next(iter(state_dict.values()))
350+
self.to(first_item)
342351
if not keep_transforms:
343-
super().load_state_dict(state_dict, strict)
352+
super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
344353
return
345354

346355
should_outcome_transform = (
@@ -369,10 +378,12 @@ def load_state_dict(
369378
BotorchWarning,
370379
stacklevel=3,
371380
)
372-
super().load_state_dict(state_dict, strict)
381+
super().load_state_dict(
382+
state_dict=state_dict, strict=strict, assign=assign
383+
)
373384
return
374385

375-
super().load_state_dict(state_dict, strict)
386+
super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
376387

377388
if getattr(self, "input_transform", None) is not None:
378389
self.input_transform.eval()
@@ -764,8 +775,11 @@ def load_state_dict(
764775
self,
765776
state_dict: Mapping[str, Any],
766777
strict: bool = True,
778+
assign: bool = False,
767779
) -> None:
768-
return ModelList.load_state_dict(self, state_dict, strict)
780+
return ModelList.load_state_dict(
781+
self, state_dict=state_dict, strict=strict, assign=assign
782+
)
769783

770784
# pyre-fixme[14]: Inconsistent override in return types
771785
def posterior(

botorch/models/model.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -582,6 +582,7 @@ def load_state_dict(
582582
state_dict: Mapping[str, Any],
583583
strict: bool = True,
584584
keep_transforms: bool = True,
585+
assign: bool = False,
585586
) -> None:
586587
"""Initialize the fully Bayesian models before loading the state dict."""
587588
for i, m in enumerate(self.models):
@@ -590,7 +591,7 @@ def load_state_dict(
590591
for k, v in state_dict.items()
591592
if k.startswith(f"models.{i}.")
592593
}
593-
m.load_state_dict(filtered_dict, strict=strict)
594+
m.load_state_dict(filtered_dict, strict=strict, assign=assign)
594595

595596
def fantasize(
596597
self,

test/models/test_gpytorch.py

Lines changed: 99 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@
4444
from gpytorch.likelihoods import GaussianLikelihood
4545
from gpytorch.means import ConstantMean
4646
from gpytorch.models import ExactGP, IndependentModelList
47+
from gpytorch.priors import LogNormalPrior
4748
from gpytorch.settings import trace_mode
4849
from torch import Tensor
4950
from torch.nn.functional import one_hot
@@ -1042,6 +1043,104 @@ def test_load_state_dict_with_transforms(self):
10421043
)
10431044
)
10441045

1046+
def test_load_state_dict_assign_parameter(self):
1047+
"""Test that the assign parameter correctly controls tensor property
1048+
preservation.
1049+
1050+
With assign=False (default): properties of the current model's tensors are
1051+
preserved.
1052+
With assign=True: properties of the state dict's tensors are preserved.
1053+
"""
1054+
# Create base model with double precision
1055+
tkwargs_double = {"device": self.device, "dtype": torch.double}
1056+
train_X_double = torch.rand(5, 2, **tkwargs_double)
1057+
train_Y_double = torch.sin(train_X_double).sum(dim=1, keepdim=True)
1058+
1059+
# NOTE Due to issues with transformed priors in gpytorch, we refrain from
1060+
# instantiating a model with a LogNormal prior here.
1061+
model_specs_without_priors = {
1062+
"covar_module": RBFKernel(ard_num_dims=2),
1063+
"likelihood": GaussianLikelihood(),
1064+
}
1065+
base_model = SingleTaskGP(
1066+
train_X=train_X_double,
1067+
train_Y=train_Y_double,
1068+
**model_specs_without_priors,
1069+
**_get_input_output_transform(d=2, indices=[0, 1], m=1),
1070+
)
1071+
state_dict_double = base_model.state_dict()
1072+
1073+
# Create a new model with float32 precision (different dtype)
1074+
tkwargs_float = {"device": self.device, "dtype": torch.float}
1075+
train_X_float = torch.rand(5, 2, **tkwargs_float)
1076+
train_Y_float = torch.sin(train_X_float).sum(dim=1, keepdim=True)
1077+
1078+
# Test assign=False (default behavior)
1079+
model_assign_false = SingleTaskGP(
1080+
train_X=train_X_float,
1081+
train_Y=train_Y_float,
1082+
**model_specs_without_priors,
1083+
**_get_input_output_transform(d=2, indices=[0, 1], m=1),
1084+
)
1085+
1086+
# Load double precision state dict with assign=False
1087+
model_assign_false.load_state_dict(
1088+
state_dict_double, keep_transforms=False, assign=False
1089+
)
1090+
1091+
# With assign=False, the model should keep its original float32 dtype
1092+
self.assertEqual(model_assign_false.train_inputs[0].dtype, torch.float)
1093+
1094+
# Test assign=True
1095+
model_assign_true = SingleTaskGP(
1096+
train_X=train_X_float,
1097+
train_Y=train_Y_float,
1098+
**model_specs_without_priors,
1099+
**_get_input_output_transform(d=2, indices=[0, 1], m=1),
1100+
)
1101+
1102+
# Load double precision state dict with assign=True
1103+
model_assign_true.load_state_dict(
1104+
state_dict_double, keep_transforms=False, assign=True
1105+
)
1106+
1107+
# With assign=True, the model should adopt the state dict's double dtype
1108+
self.assertEqual(model_assign_true.train_inputs[0].dtype, torch.double)
1109+
self.assertEqual(
1110+
model_assign_true.train_inputs[0].dtype,
1111+
next(iter(state_dict_double.values())).dtype,
1112+
)
1113+
1114+
# Verify the two models have different dtypes
1115+
self.assertNotEqual(
1116+
model_assign_false.train_inputs[0].dtype,
1117+
model_assign_true.train_inputs[0].dtype,
1118+
)
1119+
1120+
base_model_with_prior = SingleTaskGP(
1121+
train_X=train_X_double,
1122+
train_Y=train_Y_double,
1123+
**_get_input_output_transform(d=2, indices=[0, 1], m=1),
1124+
)
1125+
state_dict_with_prior = base_model_with_prior.state_dict()
1126+
state_dict_double = base_model.state_dict()
1127+
model_assign_true_with_prior = SingleTaskGP(
1128+
train_X=train_X_float,
1129+
train_Y=train_Y_float,
1130+
covar_module=RBFKernel(
1131+
ard_num_dims=2, lengthscale_prior=LogNormalPrior(1.23, 2.34)
1132+
),
1133+
**_get_input_output_transform(d=2, indices=[0, 1], m=1),
1134+
)
1135+
1136+
model_assign_true_with_prior.load_state_dict(
1137+
state_dict_with_prior, keep_transforms=False, assign=True
1138+
)
1139+
self.assertAlmostEqual(
1140+
model_assign_true_with_prior.covar_module.lengthscale_prior.loc,
1141+
base_model_with_prior.covar_module.lengthscale_prior.loc,
1142+
)
1143+
10451144
def test_load_state_dict_no_transforms(self):
10461145
tkwargs = {"device": self.device, "dtype": torch.double}
10471146

0 commit comments

Comments
 (0)