|
44 | 44 | from gpytorch.likelihoods import GaussianLikelihood |
45 | 45 | from gpytorch.means import ConstantMean |
46 | 46 | from gpytorch.models import ExactGP, IndependentModelList |
| 47 | +from gpytorch.priors import LogNormalPrior |
47 | 48 | from gpytorch.settings import trace_mode |
48 | 49 | from torch import Tensor |
49 | 50 | from torch.nn.functional import one_hot |
@@ -1042,6 +1043,104 @@ def test_load_state_dict_with_transforms(self): |
1042 | 1043 | ) |
1043 | 1044 | ) |
1044 | 1045 |
|
| 1046 | + def test_load_state_dict_assign_parameter(self): |
| 1047 | + """Test that the assign parameter correctly controls tensor property |
| 1048 | + preservation. |
| 1049 | +
|
| 1050 | + With assign=False (default): properties of the current model's tensors are |
| 1051 | + preserved. |
| 1052 | + With assign=True: properties of the state dict's tensors are preserved. |
| 1053 | + """ |
| 1054 | + # Create base model with double precision |
| 1055 | + tkwargs_double = {"device": self.device, "dtype": torch.double} |
| 1056 | + train_X_double = torch.rand(5, 2, **tkwargs_double) |
| 1057 | + train_Y_double = torch.sin(train_X_double).sum(dim=1, keepdim=True) |
| 1058 | + |
| 1059 | + # NOTE Due to issues with transformed priors in gpytorch, we refrain from |
| 1060 | + # instantiating a model with a LogNormal prior here. |
| 1061 | + model_specs_without_priors = { |
| 1062 | + "covar_module": RBFKernel(ard_num_dims=2), |
| 1063 | + "likelihood": GaussianLikelihood(), |
| 1064 | + } |
| 1065 | + base_model = SingleTaskGP( |
| 1066 | + train_X=train_X_double, |
| 1067 | + train_Y=train_Y_double, |
| 1068 | + **model_specs_without_priors, |
| 1069 | + **_get_input_output_transform(d=2, indices=[0, 1], m=1), |
| 1070 | + ) |
| 1071 | + state_dict_double = base_model.state_dict() |
| 1072 | + |
| 1073 | + # Create a new model with float32 precision (different dtype) |
| 1074 | + tkwargs_float = {"device": self.device, "dtype": torch.float} |
| 1075 | + train_X_float = torch.rand(5, 2, **tkwargs_float) |
| 1076 | + train_Y_float = torch.sin(train_X_float).sum(dim=1, keepdim=True) |
| 1077 | + |
| 1078 | + # Test assign=False (default behavior) |
| 1079 | + model_assign_false = SingleTaskGP( |
| 1080 | + train_X=train_X_float, |
| 1081 | + train_Y=train_Y_float, |
| 1082 | + **model_specs_without_priors, |
| 1083 | + **_get_input_output_transform(d=2, indices=[0, 1], m=1), |
| 1084 | + ) |
| 1085 | + |
| 1086 | + # Load double precision state dict with assign=False |
| 1087 | + model_assign_false.load_state_dict( |
| 1088 | + state_dict_double, keep_transforms=False, assign=False |
| 1089 | + ) |
| 1090 | + |
| 1091 | + # With assign=False, the model should keep its original float32 dtype |
| 1092 | + self.assertEqual(model_assign_false.train_inputs[0].dtype, torch.float) |
| 1093 | + |
| 1094 | + # Test assign=True |
| 1095 | + model_assign_true = SingleTaskGP( |
| 1096 | + train_X=train_X_float, |
| 1097 | + train_Y=train_Y_float, |
| 1098 | + **model_specs_without_priors, |
| 1099 | + **_get_input_output_transform(d=2, indices=[0, 1], m=1), |
| 1100 | + ) |
| 1101 | + |
| 1102 | + # Load double precision state dict with assign=True |
| 1103 | + model_assign_true.load_state_dict( |
| 1104 | + state_dict_double, keep_transforms=False, assign=True |
| 1105 | + ) |
| 1106 | + |
| 1107 | + # With assign=True, the model should adopt the state dict's double dtype |
| 1108 | + self.assertEqual(model_assign_true.train_inputs[0].dtype, torch.double) |
| 1109 | + self.assertEqual( |
| 1110 | + model_assign_true.train_inputs[0].dtype, |
| 1111 | + next(iter(state_dict_double.values())).dtype, |
| 1112 | + ) |
| 1113 | + |
| 1114 | + # Verify the two models have different dtypes |
| 1115 | + self.assertNotEqual( |
| 1116 | + model_assign_false.train_inputs[0].dtype, |
| 1117 | + model_assign_true.train_inputs[0].dtype, |
| 1118 | + ) |
| 1119 | + |
| 1120 | + base_model_with_prior = SingleTaskGP( |
| 1121 | + train_X=train_X_double, |
| 1122 | + train_Y=train_Y_double, |
| 1123 | + **_get_input_output_transform(d=2, indices=[0, 1], m=1), |
| 1124 | + ) |
| 1125 | + state_dict_with_prior = base_model_with_prior.state_dict() |
| 1126 | + state_dict_double = base_model.state_dict() |
| 1127 | + model_assign_true_with_prior = SingleTaskGP( |
| 1128 | + train_X=train_X_float, |
| 1129 | + train_Y=train_Y_float, |
| 1130 | + covar_module=RBFKernel( |
| 1131 | + ard_num_dims=2, lengthscale_prior=LogNormalPrior(1.23, 2.34) |
| 1132 | + ), |
| 1133 | + **_get_input_output_transform(d=2, indices=[0, 1], m=1), |
| 1134 | + ) |
| 1135 | + |
| 1136 | + model_assign_true_with_prior.load_state_dict( |
| 1137 | + state_dict_with_prior, keep_transforms=False, assign=True |
| 1138 | + ) |
| 1139 | + self.assertAlmostEqual( |
| 1140 | + model_assign_true_with_prior.covar_module.lengthscale_prior.loc, |
| 1141 | + base_model_with_prior.covar_module.lengthscale_prior.loc, |
| 1142 | + ) |
| 1143 | + |
1045 | 1144 | def test_load_state_dict_no_transforms(self): |
1046 | 1145 | tkwargs = {"device": self.device, "dtype": torch.double} |
1047 | 1146 |
|
|
0 commit comments