Skip to content

Commit d7fe4a5

Browse files
saitcakmak authored and meta-codesync[bot] committed
Return TorchPosterior for non-MVN distributions in ApproximateGPyTorchModel (#3242)
Summary: When using SingleTaskVariationalGP with non-Gaussian likelihoods (e.g., BetaLikelihood), the posterior distribution is not a MultivariateNormal. The model now returns a TorchPosterior instead of a GPyTorchPosterior in that case, since GPyTorchPosterior's methods assume an MVN.

Closes #3066

Pull Request resolved: #3242

Reviewed By: esantorella

Differential Revision: D97317429

Pulled By: saitcakmak

fbshipit-source-id: 71792043c3bea6051b7d7dca6329bc15faeeaedb
1 parent 9a296b6 commit d7fe4a5

2 files changed

Lines changed: 38 additions & 3 deletions

File tree

botorch/models/approximate_gp.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@
4949
InducingPointAllocator,
5050
)
5151
from botorch.posteriors.gpytorch import GPyTorchPosterior
52+
from botorch.posteriors.torch import TorchPosterior
5253
from gpytorch.distributions import MultivariateNormal
5354
from gpytorch.kernels import Kernel
5455
from gpytorch.likelihoods import (
@@ -150,7 +151,7 @@ def posterior(
150151
output_indices: list[int] | None = None,
151152
observation_noise: bool = False,
152153
posterior_transform: PosteriorTransform | None = None,
153-
) -> GPyTorchPosterior:
154+
) -> TorchPosterior:
154155
if output_indices is not None:
155156
raise NotImplementedError( # pragma: no cover
156157
f"{self.__class__.__name__}.posterior does not support output indices."
@@ -170,7 +171,10 @@ def posterior(
170171
if observation_noise:
171172
dist = self.likelihood(dist)
172173

173-
posterior = GPyTorchPosterior(distribution=dist)
174+
if isinstance(dist, MultivariateNormal):
175+
posterior = GPyTorchPosterior(distribution=dist)
176+
else:
177+
posterior = TorchPosterior(distribution=dist)
174178
if hasattr(self, "outcome_transform"):
175179
posterior = self.outcome_transform.untransform_posterior(posterior, X=X)
176180
if posterior_transform is not None:

test/models/test_approximate_gp.py

Lines changed: 32 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,13 @@
2323
GreedyVarianceReduction,
2424
)
2525
from botorch.posteriors import GPyTorchPosterior, TransformedPosterior
26+
from botorch.posteriors.torch import TorchPosterior
2627
from botorch.utils.testing import BotorchTestCase
27-
from gpytorch.likelihoods import GaussianLikelihood, MultitaskGaussianLikelihood
28+
from gpytorch.likelihoods import (
29+
BetaLikelihood,
30+
GaussianLikelihood,
31+
MultitaskGaussianLikelihood,
32+
)
2833
from gpytorch.mlls import VariationalELBO
2934
from gpytorch.variational import (
3035
IndependentMultitaskVariationalStrategy,
@@ -342,6 +347,32 @@ def test_custom_inducing_point_init(self):
342347
self.assertAllClose(model_1_inducing, model_2_inducing)
343348
self.assertFalse(model_1_inducing[0, 0] == model_3_inducing[0, 0])
344349

350+
def test_non_gaussian_likelihood_posterior(self) -> None:
351+
"""Test that non-Gaussian likelihoods return TorchPosterior."""
352+
train_X = torch.rand(10, 1, device=self.device)
353+
model = SingleTaskVariationalGP(
354+
train_X=train_X,
355+
likelihood=BetaLikelihood(),
356+
).to(self.device)
357+
test_X = torch.rand(5, 1, device=self.device)
358+
359+
# Without observation noise, the distribution is MVN (from the GP),
360+
# so it should return GPyTorchPosterior.
361+
posterior = model.posterior(test_X, observation_noise=False)
362+
self.assertIsInstance(posterior, GPyTorchPosterior)
363+
364+
# With observation noise, the likelihood transforms the MVN into a
365+
# Beta distribution, so it should return TorchPosterior.
366+
posterior = model.posterior(test_X, observation_noise=True)
367+
self.assertIsInstance(posterior, TorchPosterior)
368+
self.assertNotIsInstance(posterior, GPyTorchPosterior)
369+
370+
# Verify that sampling works with the TorchPosterior and that the
371+
# shape doesn't have a spurious trailing dim from GPyTorchPosterior.
372+
# GPyTorchPosterior.rsample would unsqueeze(-1), adding an extra dim.
373+
samples = posterior.rsample(sample_shape=torch.Size([2]))
374+
self.assertEqual(samples.shape, torch.Size([2, 10, 5]))
375+
345376
def test_input_transform(self) -> None:
346377
train_X = torch.linspace(1, 3, 10, dtype=torch.double)[:, None]
347378
y = -3 * train_X + 5

0 commit comments

Comments
 (0)