|
8 | 8 | from unittest import mock |
9 | 9 |
|
10 | 10 | import torch |
11 | | -from botorch.acquisition.cached_cholesky import CachedCholeskyMCSamplerMixin |
| 11 | +from botorch.acquisition.cached_cholesky import ( |
| 12 | + CachedCholeskyMCSamplerMixin, |
| 13 | + supports_cache_root, |
| 14 | +) |
12 | 15 | from botorch.acquisition.monte_carlo import MCAcquisitionFunction |
13 | 16 | from botorch.acquisition.objective import GenericMCObjective, MCAcquisitionObjective |
14 | 17 | from botorch.exceptions.warnings import BotorchWarning |
15 | 18 | from botorch.models import SingleTaskGP |
16 | 19 | from botorch.models.deterministic import GenericDeterministicModel |
17 | 20 | from botorch.models.higher_order_gp import HigherOrderGP |
18 | 21 | from botorch.models.model import Model, ModelList |
| 22 | +from botorch.models.model_list_gp_regression import ModelListGP |
| 23 | +from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP |
19 | 24 | from botorch.models.transforms.outcome import Log |
20 | 25 | from botorch.sampling.normal import IIDNormalSampler, MCSampler |
21 | 26 | from botorch.utils.low_rank import extract_batch_covar |
@@ -146,6 +151,89 @@ def test_cache_root_decomposition(self): |
146 | 151 | mock_cholesky.assert_called_once() |
147 | 152 | self.assertTrue(torch.equal(baseline_L_acqf, baseline_L)) |
148 | 153 |
|
| 154 | + def test_supports_cache_root_opt_out(self): |
| 155 | + """Test that models can opt out of cache_root via _supports_cache_root. |
| 156 | +
|
| 157 | + Models with low-rank kernels (e.g., SphericalLinearSingleTaskGP using |
| 158 | + LinearPredictionStrategy) are incompatible with cache_root because |
| 159 | + base_samples are generated for rank r < n. These models set |
| 160 | + _supports_cache_root = False so that cache_root is automatically |
| 161 | + disabled. |
| 162 | + """ |
| 163 | + tkwargs = {"device": self.device} |
| 164 | + for dtype in (torch.float, torch.double): |
| 165 | + with self.subTest(dtype=dtype): |
| 166 | + tkwargs["dtype"] = dtype |
| 167 | + |
| 168 | + # Standard models support cache_root by default |
| 169 | + stgp = SingleTaskGP( |
| 170 | + torch.zeros(2, 1, **tkwargs), torch.zeros(2, 1, **tkwargs) |
| 171 | + ) |
| 172 | + self.assertTrue(supports_cache_root(stgp)) |
| 173 | + |
| 174 | + # Models with _supports_cache_root = False do not |
| 175 | + stgp._supports_cache_root = False |
| 176 | + self.assertFalse(supports_cache_root(stgp)) |
| 177 | + |
| 178 | + # This propagates through ModelListGP |
| 179 | + stgp2 = SingleTaskGP( |
| 180 | + torch.zeros(2, 1, **tkwargs), torch.zeros(2, 1, **tkwargs) |
| 181 | + ) |
| 182 | + stgp2._supports_cache_root = False |
| 183 | + model_list = ModelListGP(stgp2) |
| 184 | + self.assertFalse(supports_cache_root(model_list)) |
| 185 | + |
| 186 | + # CachedCholeskyMCSamplerMixin respects the opt-out |
| 187 | + sampler = IIDNormalSampler(sample_shape=torch.Size([2])) |
| 188 | + acqf = DummyCachedCholeskyAcqf( |
| 189 | + model=stgp, |
| 190 | + sampler=sampler, |
| 191 | + ) |
| 192 | + self.assertFalse(acqf._cache_root) |
| 193 | + |
| 194 | + # Explicitly passing cache_root=True warns and gets disabled |
| 195 | + with self.assertWarnsRegex(RuntimeWarning, "cache_root"): |
| 196 | + acqf = DummyCachedCholeskyAcqf( |
| 197 | + model=stgp, |
| 198 | + sampler=sampler, |
| 199 | + cache_root=True, |
| 200 | + ) |
| 201 | + self.assertFalse(acqf._cache_root) |
| 202 | + |
| 203 | + def test_unsupported_models_have_supports_cache_root_false(self): |
| 204 | + """Test that MultiTaskGP, KroneckerMultiTaskGP, and HigherOrderGP |
| 205 | + set _supports_cache_root = False as a class attribute.""" |
| 206 | + # Check the class attribute directly |
| 207 | + self.assertFalse(MultiTaskGP._supports_cache_root) |
| 208 | + self.assertFalse(KroneckerMultiTaskGP._supports_cache_root) |
| 209 | + self.assertFalse(HigherOrderGP._supports_cache_root) |
| 210 | + |
| 211 | + # Check that instances also have the attribute set to False |
| 212 | + tkwargs = {"device": self.device, "dtype": torch.double} |
| 213 | + |
| 214 | + # MultiTaskGP |
| 215 | + train_X = torch.cat( |
| 216 | + [torch.rand(5, 1, **tkwargs), torch.zeros(5, 1, **tkwargs)], dim=-1 |
| 217 | + ) |
| 218 | + train_Y = torch.rand(5, 1, **tkwargs) |
| 219 | + mtgp = MultiTaskGP(train_X, train_Y, task_feature=-1) |
| 220 | + self.assertFalse(mtgp._supports_cache_root) |
| 221 | + self.assertFalse(supports_cache_root(mtgp)) |
| 222 | + |
| 223 | + # KroneckerMultiTaskGP |
| 224 | + train_X = torch.rand(5, 2, **tkwargs) |
| 225 | + train_Y = torch.rand(5, 2, **tkwargs) |
| 226 | + kmtgp = KroneckerMultiTaskGP(train_X, train_Y) |
| 227 | + self.assertFalse(kmtgp._supports_cache_root) |
| 228 | + self.assertFalse(supports_cache_root(kmtgp)) |
| 229 | + |
| 230 | + # HigherOrderGP |
| 231 | + train_X = torch.rand(5, 2, **tkwargs) |
| 232 | + train_Y = torch.rand(5, 1, 1, **tkwargs) |
| 233 | + hogp = HigherOrderGP(train_X, train_Y) |
| 234 | + self.assertFalse(hogp._supports_cache_root) |
| 235 | + self.assertFalse(supports_cache_root(hogp)) |
| 236 | + |
149 | 237 | def test_get_f_X_samples(self): |
150 | 238 | sample_cached_cholesky_path = ( |
151 | 239 | "botorch.acquisition.cached_cholesky.sample_cached_cholesky" |
|
0 commit comments