Skip to content

Commit 870219b

Browse files
David Erikssonmeta-codesync[bot]
authored andcommitted
Add _supports_batched_models attribute to models that don't support batching (meta-pytorch#3239)
Summary: Pull Request resolved: meta-pytorch#3239 This is needed update `use_model_list` in the stacked diff. Reviewed By: saitcakmak Differential Revision: D97148274 fbshipit-source-id: 146666b5ab93059ba39b2040e0a6a69cd972703b
1 parent 91df5fe commit 870219b

6 files changed

Lines changed: 13 additions & 0 deletions

File tree

botorch/models/fully_bayesian.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,7 @@ class AbstractFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel
809809

810810
_is_fully_bayesian = True
811811
_is_ensemble = True
812+
_supports_batched_models = False
812813
_pyro_model_class: type[PyroModel] = PyroModel
813814

814815
def __init__(

botorch/models/map_saas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ def __init__(
455455

456456
class EnsembleMapSaasSingleTaskGP(SingleTaskGP):
457457
_is_ensemble = True
458+
_supports_batched_models = False
458459

459460
def __init__(
460461
self,

botorch/models/multitask.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
146146
different noise levels for the different tasks.
147147
"""
148148

149+
_supports_batched_models = False
150+
149151
def __init__(
150152
self,
151153
train_X: Tensor,

test/models/test_fully_bayesian.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ def _get_mcmc_samples(
347347
mcmc_samples[k] = torch.rand(num_samples, 1, dim, **tkwargs)
348348
return mcmc_samples
349349

350+
def test_supports_batched_models(self) -> None:
351+
self.assertFalse(self.model_cls._supports_batched_models)
352+
350353
def test_raises(self) -> None:
351354
tkwargs = {"device": self.device, "dtype": torch.double}
352355
with self.assertRaisesRegex(

test/models/test_map_saas.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,9 @@ def test_emsemble_map_saas(self) -> None:
541541
self.assertIsInstance(model.outcome_transform, Standardize)
542542
self.assertFalse(hasattr(model, "input_transform"))
543543

544+
def test_ensemble_map_saas_supports_batched_models(self) -> None:
545+
self.assertFalse(EnsembleMapSaasSingleTaskGP._supports_batched_models)
546+
544547
def test_ensemble_map_saas_validation(self) -> None:
545548
with self.assertRaisesRegex(ValueError, "Expected taus to be of shape"):
546549
EnsembleMapSaasSingleTaskGP(

test/models/test_multitask.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ def _gen_kronecker_model_and_data(model_kwargs=None, batch_shape=None, **tkwargs
126126

127127

128128
class TestMultiTaskGP(BotorchTestCase):
129+
def test_supports_batched_models(self) -> None:
130+
self.assertFalse(MultiTaskGP._supports_batched_models)
131+
129132
def test_MultiTaskGP(self) -> None:
130133
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
131134
for dtype, use_intf, use_octf, task_values, fixed_noise in zip(

0 commit comments

Comments
 (0)