17
17
)
18
18
from botorch .utils .testing import BotorchTestCase
19
19
from gpytorch import settings as gpytorch_settings
20
+ from gpytorch .likelihoods .gaussian_likelihood import GaussianLikelihood
20
21
from gpytorch .mlls import ExactMarginalLogLikelihood , SumMarginalLogLikelihood
22
+ from gpytorch .mlls .marginal_log_likelihood import MarginalLogLikelihood
23
+ from gpytorch .module import Module
24
+ from torch import Tensor
21
25
from torch .utils .data import DataLoader , TensorDataset
22
26
23
27
28
# Mock wrapping the __call__ directly is leading to errors like
# TypeError: super(type, obj): obj must be an instance or subtype of type
# so, doing this manually here.
class WrapperLikelihood(GaussianLikelihood):
    def __init__(self, base_likelihood: GaussianLikelihood):
        """Wrap a ``GaussianLikelihood`` and record the args of each call.

        Args:
            base_likelihood: The likelihood that actually handles the calls.
        """
        # Deliberately skip GaussianLikelihood.__init__ -- the wrapped
        # likelihood already owns the noise model; we only need Module
        # machinery so that sub-module registration works.
        Module.__init__(self)
        self.base_likelihood = base_likelihood
        # One entry per call; used by tests to compare call args.
        self.call_args = []

    def __call__(self, *args, **kwargs):
        # Record the second positional arg (the train inputs) for testing,
        # then delegate to the wrapped likelihood unchanged.
        train_inputs = args[1]
        self.call_args.append(train_inputs)
        return self.base_likelihood(*args, **kwargs)
43
def _get_mlls(
    device: torch.device, wrap_likelihood: bool = False
) -> tuple[Tensor, list[MarginalLogLikelihood]]:
    """Construct the train X and two MLLs sharing the same underlying model.

    Returns the training inputs together with an ``ExactMarginalLogLikelihood``
    for a ``SingleTaskGP`` and a ``SumMarginalLogLikelihood`` for a
    ``ModelListGP`` built from that same ``SingleTaskGP`` twice.

    Args:
        device: The device to use.
        wrap_likelihood: If True, wrap the likelihood in a WrapperLikelihood.
            This is useful for comparing call args later.
    """
    with torch.random.fork_rng():
        # Fixed seed so both MLLs are built from identical data.
        torch.manual_seed(0)
        # Inputs are not in the unit cube to ensure input transform is applied.
        train_X = torch.linspace(0, 5, 10).unsqueeze(-1)
        noise = 0.1 * torch.randn_like(train_X)
        train_Y = torch.sin((2 * pi) * train_X) + noise
        mlls = []
        model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_Y,
            input_transform=Normalize(d=1),
            outcome_transform=Standardize(m=1),
        )
        if wrap_likelihood:
            model.likelihood = WrapperLikelihood(model.likelihood)
        single_mll = ExactMarginalLogLikelihood(model.likelihood, model)
        mlls.append(single_mll.to(device=device, dtype=torch.double))

        # The list model reuses the same child model twice, so the
        # (possibly wrapped) likelihood is shared across both entries.
        list_model = ModelListGP(model, model)
        list_mll = SumMarginalLogLikelihood(list_model.likelihood, list_model)
        mlls.append(list_mll.to(device=device, dtype=torch.double))
        return train_X.to(device=device, dtype=torch.double), mlls
78
+
24
79
class TestLossClosures (BotorchTestCase ):
25
- def setUp (self ):
26
- super ().setUp ()
27
- with torch .random .fork_rng ():
28
- torch .manual_seed (0 )
29
- train_X = torch .linspace (0 , 1 , 10 ).unsqueeze (- 1 )
30
- train_Y = torch .sin ((2 * pi ) * train_X )
31
- train_Y = train_Y + 0.1 * torch .randn_like (train_Y )
32
-
33
- self .mlls = {}
34
- model = SingleTaskGP (
35
- train_X = train_X ,
36
- train_Y = train_Y ,
37
- input_transform = Normalize (d = 1 ),
38
- outcome_transform = Standardize (m = 1 ),
39
- )
40
- mll = ExactMarginalLogLikelihood (model .likelihood , model )
41
- self .mlls [type (mll ), type (model .likelihood ), type (model )] = mll .to (self .device )
42
-
43
- model = ModelListGP (model , model )
44
- mll = SumMarginalLogLikelihood (model .likelihood , model )
45
- self .mlls [type (mll ), type (model .likelihood ), type (model )] = mll .to (self .device )
46
-
47
- def test_main (self ):
48
- for mll in self .mlls .values ():
80
+ def test_main (self ) -> None :
81
+ for mll in _get_mlls (device = self .device )[1 ]:
49
82
out = mll .model (* mll .model .train_inputs )
50
83
loss = - mll (out , mll .model .train_targets ).sum ()
51
84
loss .backward ()
@@ -63,8 +96,8 @@ def test_main(self):
63
96
self .assertTrue (loss .equal (_loss ))
64
97
self .assertTrue (all (a .equal (b ) for a , b in zip_longest (grads , _grads )))
65
98
66
- def test_data_loader (self ):
67
- for mll in self .mlls . values () :
99
+ def test_data_loader (self ) -> None :
100
+ for mll in _get_mlls ( device = self .device )[ 1 ] :
68
101
if type (mll ) is not ExactMarginalLogLikelihood :
69
102
continue
70
103
@@ -86,3 +119,38 @@ def test_data_loader(self):
86
119
closure = get_loss_closure_with_grads (mll , params , data_loader = loader )
87
120
with self .assertRaisesRegex (TypeError , "Expected .* a batch of tensors" ):
88
121
closure ()
122
+
123
+ def test_with_input_transforms (self ) -> None :
124
+ # This test reproduces the bug reported in issue #2515.
125
+ train_X , mlls = _get_mlls (device = self .device , wrap_likelihood = True )
126
+ for mll in mlls :
127
+ if isinstance (mll , SumMarginalLogLikelihood ):
128
+ # The likelihood is called twice here since it is the same
129
+ # likelihood in both child models.
130
+ likelihood = mll .model .models [0 ].likelihood
131
+ expected_calls1 = 2 # In the closure call.
132
+ expected_calls2 = 6 # Closure + posterior calls.
133
+ else :
134
+ likelihood = mll .model .likelihood
135
+ expected_calls1 = 1 # In the closure call.
136
+ expected_calls2 = 4 # Closure + posterior calls.
137
+ likelihood .call_args = [] # reset since it is shared between the models.
138
+ params = {n : p for n , p in mll .named_parameters () if p .requires_grad }
139
+ # Evaluate the closure to mimic the model fitting process.
140
+ mll .train ()
141
+ closure = get_loss_closure_with_grads (mll , params )
142
+ closure ()
143
+ self .assertEqual (len (likelihood .call_args ), expected_calls1 )
144
+ # Call the model posterior to reproduce post-fitting usage.
145
+ mll .model .posterior (train_X , observation_noise = True )
146
+ # Compare the call args to ensure they're all the same.
147
+ # Likelihood is called twice on model(X) and once for adding the noise.
148
+ self .assertEqual (len (likelihood .call_args ), expected_calls2 )
149
+ arg0 = likelihood .call_args [0 ]
150
+ for i in range (1 , expected_calls2 ):
151
+ argi = likelihood .call_args [i ]
152
+ # The arg may be a tensor or a single element list of the tensor.
153
+ self .assertAllClose (
154
+ arg0 if isinstance (arg0 , Tensor ) else arg0 [0 ],
155
+ argi if isinstance (argi , Tensor ) else argi [0 ],
156
+ )
0 commit comments