Skip to content

Commit 4decc06

Browse files
David Eriksson and meta-codesync[bot]
authored and committed
Support cache_root for low-rank kernels (#3223)
Summary: Pull Request resolved: #3223. This allows turning off `cache_root` for models that don't support it. Reviewed By: saitcakmak. Differential Revision: D95317067. fbshipit-source-id: 6b2cff028c008998cf7e4fb7054f8f7447e0d81f
1 parent 870219b commit 4decc06

4 files changed

Lines changed: 100 additions & 7 deletions

File tree

botorch/acquisition/cached_cholesky.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from botorch.models.higher_order_gp import HigherOrderGP
2121
from botorch.models.model import Model
2222
from botorch.models.model_list_gp_regression import ModelListGP
23-
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
2423
from botorch.posteriors.gpytorch import GPyTorchPosterior
2524
from botorch.posteriors.posterior import Posterior
2625
from botorch.sampling.base import MCSampler
@@ -39,10 +38,11 @@ def supports_cache_root(model: Model) -> bool:
3938
"""
4039
if isinstance(model, ModelListGP):
4140
return all(supports_cache_root(m) for m in model.models)
41+
# Allow models to explicitly opt out of cache_root support.
42+
if getattr(model, "_supports_cache_root", True) is False:
43+
return False
4244
# Multi task models and non-GPyTorch models are not supported.
43-
if isinstance(
44-
model, (MultiTaskGP, KroneckerMultiTaskGP, HigherOrderGP)
45-
) or not isinstance(model, GPyTorchModel):
45+
if not isinstance(model, GPyTorchModel):
4646
return False
4747
# Models that return a TransformedPosterior are not supported.
4848
if hasattr(model, "outcome_transform") and (not model.outcome_transform._is_linear):

botorch/models/higher_order_gp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,14 @@ class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
144144
r"""
145145
A model for high-dimensional output regression.
146146
147-
As described in [Zhe2019hogp]_. Higher-order means that the predictions
147+
As described in [Zhe2019hogp]_. "Higher-order" means that the predictions
148148
are matrices (tensors) with at least two dimensions, such as images or
149149
grids of images, or measurements taken from a region of at least two
150150
dimensions.
151151
The posterior uses Matheron's rule [Doucet2010sampl]_
152152
as described in [Maddox2021bohdo]_.
153153
154-
``HigherOrderGP`` differs from a "vector multi-output model in that it uses
154+
``HigherOrderGP`` differs from a "vector" multi-output model in that it uses
155155
Kronecker algebra to obtain parsimonious covariance matrices for these
156156
outputs (see ``KroneckerMultiTaskGP`` for more information). For example,
157157
imagine a 10 x 20 x 30 grid of images. If we were to vectorize the
@@ -177,6 +177,8 @@ class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
177177
>>> samples = model.posterior(test_X).rsample()
178178
"""
179179

180+
_supports_cache_root = False
181+
180182
def __init__(
181183
self,
182184
train_X: Tensor,

botorch/models/multitask.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
147147
"""
148148

149149
_supports_batched_models = False
150+
_supports_cache_root = False
150151

151152
def __init__(
152153
self,
@@ -564,6 +565,8 @@ class KroneckerMultiTaskGP(ExactGP, GPyTorchModel, FantasizeMixin):
564565
>>> model = KroneckerMultiTaskGP(train_X, train_Y)
565566
"""
566567

568+
_supports_cache_root = False
569+
567570
def __init__(
568571
self,
569572
train_X: Tensor,

test/acquisition/test_cached_cholesky.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,19 @@
88
from unittest import mock
99

1010
import torch
11-
from botorch.acquisition.cached_cholesky import CachedCholeskyMCSamplerMixin
11+
from botorch.acquisition.cached_cholesky import (
12+
CachedCholeskyMCSamplerMixin,
13+
supports_cache_root,
14+
)
1215
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
1316
from botorch.acquisition.objective import GenericMCObjective, MCAcquisitionObjective
1417
from botorch.exceptions.warnings import BotorchWarning
1518
from botorch.models import SingleTaskGP
1619
from botorch.models.deterministic import GenericDeterministicModel
1720
from botorch.models.higher_order_gp import HigherOrderGP
1821
from botorch.models.model import Model, ModelList
22+
from botorch.models.model_list_gp_regression import ModelListGP
23+
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
1924
from botorch.models.transforms.outcome import Log
2025
from botorch.sampling.normal import IIDNormalSampler, MCSampler
2126
from botorch.utils.low_rank import extract_batch_covar
@@ -146,6 +151,89 @@ def test_cache_root_decomposition(self):
146151
mock_cholesky.assert_called_once()
147152
self.assertTrue(torch.equal(baseline_L_acqf, baseline_L))
148153

154+
def test_supports_cache_root_opt_out(self):
155+
"""Test that models can opt out of cache_root via _supports_cache_root.
156+
157+
Models with low-rank kernels (e.g., SphericalLinearSingleTaskGP using
158+
LinearPredictionStrategy) are incompatible with cache_root because
159+
base_samples are generated for rank r < n. These models set
160+
_supports_cache_root = False so that cache_root is automatically
161+
disabled.
162+
"""
163+
tkwargs = {"device": self.device}
164+
for dtype in (torch.float, torch.double):
165+
with self.subTest(dtype=dtype):
166+
tkwargs["dtype"] = dtype
167+
168+
# Standard models support cache_root by default
169+
stgp = SingleTaskGP(
170+
torch.zeros(2, 1, **tkwargs), torch.zeros(2, 1, **tkwargs)
171+
)
172+
self.assertTrue(supports_cache_root(stgp))
173+
174+
# Models with _supports_cache_root = False do not
175+
stgp._supports_cache_root = False
176+
self.assertFalse(supports_cache_root(stgp))
177+
178+
# This propagates through ModelListGP
179+
stgp2 = SingleTaskGP(
180+
torch.zeros(2, 1, **tkwargs), torch.zeros(2, 1, **tkwargs)
181+
)
182+
stgp2._supports_cache_root = False
183+
model_list = ModelListGP(stgp2)
184+
self.assertFalse(supports_cache_root(model_list))
185+
186+
# CachedCholeskyMCSamplerMixin respects the opt-out
187+
sampler = IIDNormalSampler(sample_shape=torch.Size([2]))
188+
acqf = DummyCachedCholeskyAcqf(
189+
model=stgp,
190+
sampler=sampler,
191+
)
192+
self.assertFalse(acqf._cache_root)
193+
194+
# Explicitly passing cache_root=True warns and gets disabled
195+
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
196+
acqf = DummyCachedCholeskyAcqf(
197+
model=stgp,
198+
sampler=sampler,
199+
cache_root=True,
200+
)
201+
self.assertFalse(acqf._cache_root)
202+
203+
def test_unsupported_models_have_supports_cache_root_false(self):
204+
"""Test that MultiTaskGP, KroneckerMultiTaskGP, and HigherOrderGP
205+
set _supports_cache_root = False as a class attribute."""
206+
# Check the class attribute directly
207+
self.assertFalse(MultiTaskGP._supports_cache_root)
208+
self.assertFalse(KroneckerMultiTaskGP._supports_cache_root)
209+
self.assertFalse(HigherOrderGP._supports_cache_root)
210+
211+
# Check that instances also have the attribute set to False
212+
tkwargs = {"device": self.device, "dtype": torch.double}
213+
214+
# MultiTaskGP
215+
train_X = torch.cat(
216+
[torch.rand(5, 1, **tkwargs), torch.zeros(5, 1, **tkwargs)], dim=-1
217+
)
218+
train_Y = torch.rand(5, 1, **tkwargs)
219+
mtgp = MultiTaskGP(train_X, train_Y, task_feature=-1)
220+
self.assertFalse(mtgp._supports_cache_root)
221+
self.assertFalse(supports_cache_root(mtgp))
222+
223+
# KroneckerMultiTaskGP
224+
train_X = torch.rand(5, 2, **tkwargs)
225+
train_Y = torch.rand(5, 2, **tkwargs)
226+
kmtgp = KroneckerMultiTaskGP(train_X, train_Y)
227+
self.assertFalse(kmtgp._supports_cache_root)
228+
self.assertFalse(supports_cache_root(kmtgp))
229+
230+
# HigherOrderGP
231+
train_X = torch.rand(5, 2, **tkwargs)
232+
train_Y = torch.rand(5, 1, 1, **tkwargs)
233+
hogp = HigherOrderGP(train_X, train_Y)
234+
self.assertFalse(hogp._supports_cache_root)
235+
self.assertFalse(supports_cache_root(hogp))
236+
149237
def test_get_f_X_samples(self):
150238
sample_cached_cholesky_path = (
151239
"botorch.acquisition.cached_cholesky.sample_cached_cholesky"

0 commit comments

Comments
 (0)