Skip to content

Commit 1d5fd01

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
Add BetaPrior for correlation parameter priors (#3266)
Summary: Pull Request resolved: #3266 X-link: facebook/Ax#5152 Adds `BetaPrior` — a Beta distribution prior supported on [0, 1], useful for correlation parameters in multi-task GPs. Registers it in Ax's FB serialization registries. Reviewed By: sdaulton Differential Revision: D99841563 fbshipit-source-id: 4064e1d0a303aa554086c1a633072b788212a97a
1 parent 94766a7 commit 1d5fd01

3 files changed

Lines changed: 165 additions & 0 deletions

File tree

botorch/models/utils/priors.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from typing import Callable, Optional
10+
11+
import torch
12+
from gpytorch.priors import Prior
13+
from gpytorch.priors.utils import BUFFERED_PREFIX
14+
from torch import Tensor
15+
from torch.distributions import Beta
16+
from torch.nn import Module as TModule
17+
18+
19+
class BetaPrior(Prior, Beta):
20+
"""Beta Prior parameterized by concentration1 (alpha) and concentration0 (beta).
21+
22+
pdf(x) = x^(alpha - 1) * (1 - x)^(beta - 1) / B(alpha, beta)
23+
24+
where alpha > 0 and beta > 0 are the concentration parameters.
25+
Supported on [0, 1], useful as a prior on correlation parameters.
26+
"""
27+
28+
# Beta.concentration1/concentration0 are @property descriptors (they
29+
# delegate to an internal Dirichlet), so _bufferize_attributes cannot
30+
# delattr them. We store separate buffers with the BUFFERED_PREFIX and
31+
# sync them back after load_state_dict.
32+
_PARAM_NAMES = ("concentration1", "concentration0")
33+
34+
def __init__(
35+
self,
36+
concentration1: float,
37+
concentration0: float,
38+
validate_args: bool = False,
39+
transform: Optional[Callable[[Tensor], Tensor]] = None,
40+
) -> None:
41+
"""Initialize BetaPrior.
42+
43+
Args:
44+
concentration1: Alpha (first concentration) parameter.
45+
concentration0: Beta (second concentration) parameter.
46+
validate_args: Whether to validate input arguments.
47+
transform: Optional transform to apply before computing log_prob.
48+
"""
49+
TModule.__init__(self)
50+
Beta.__init__(
51+
self,
52+
concentration1=concentration1,
53+
concentration0=concentration0,
54+
validate_args=validate_args,
55+
)
56+
for attr in self._PARAM_NAMES:
57+
self.register_buffer(
58+
f"{BUFFERED_PREFIX}{attr}", getattr(self, attr).clone()
59+
)
60+
self._transform = transform
61+
62+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
63+
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
64+
# Sync buffered values back into the underlying Dirichlet distribution.
65+
c1 = getattr(self, f"{BUFFERED_PREFIX}concentration1")
66+
c0 = getattr(self, f"{BUFFERED_PREFIX}concentration0")
67+
self._dirichlet.concentration = torch.stack([c1, c0], dim=-1)
68+
69+
def expand(self, batch_shape: torch.Size) -> "BetaPrior":
70+
batch_shape = torch.Size(batch_shape)
71+
return BetaPrior(
72+
self.concentration1.expand(batch_shape),
73+
self.concentration0.expand(batch_shape),
74+
validate_args=self._validate_args,
75+
transform=self._transform,
76+
)

sphinx/source/models.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ Inducing Point Allocators
209209
:members:
210210
:private-members: _pivoted_cholesky_init
211211

212+
Priors
213+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
214+
.. automodule:: botorch.models.utils.priors
215+
:members:
216+
212217
Other Utilties
213218
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
214219
.. automodule:: botorch.models.utils.assorted

test/models/utils/test_priors.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from botorch.models.utils.priors import BetaPrior
11+
from gpytorch.priors.utils import BUFFERED_PREFIX
12+
from torch.distributions import Beta
13+
14+
15+
class TestBetaPrior(unittest.TestCase):
16+
def test_init(self):
17+
prior = BetaPrior(1.2, 0.9)
18+
self.assertAlmostEqual(prior.concentration1.item(), 1.2)
19+
self.assertAlmostEqual(prior.concentration0.item(), 0.9)
20+
self.assertIsNone(prior._transform)
21+
22+
def test_init_with_transform(self):
23+
prior = BetaPrior(2.0, 3.0, transform=torch.sigmoid)
24+
self.assertIs(prior._transform, torch.sigmoid)
25+
26+
def test_log_prob_matches_torch(self):
27+
prior = BetaPrior(1.2, 0.9)
28+
ref = Beta(torch.tensor(1.2), torch.tensor(0.9))
29+
x = torch.rand(5, 4)
30+
self.assertTrue(torch.allclose(prior.log_prob(x), ref.log_prob(x)))
31+
32+
def test_log_prob_with_transform(self):
33+
def transform(x):
34+
return x.clamp(0.01, 0.99)
35+
36+
prior = BetaPrior(2.0, 2.0, transform=transform)
37+
prior_no_transform = BetaPrior(2.0, 2.0)
38+
x = torch.rand(3, 2)
39+
lp = prior.log_prob(x)
40+
self.assertEqual(lp.shape, torch.Size([3, 2]))
41+
# Values at boundaries should differ due to clamping
42+
x_boundary = torch.tensor([[0.001, 0.999]])
43+
self.assertFalse(
44+
torch.allclose(
45+
prior.log_prob(x_boundary),
46+
prior_no_transform.log_prob(x_boundary),
47+
)
48+
)
49+
50+
def test_log_prob_batch(self):
51+
prior = BetaPrior(1.5, 2.5)
52+
x = torch.rand(7, 3, 2)
53+
lp = prior.log_prob(x)
54+
self.assertEqual(lp.shape, torch.Size([7, 3, 2]))
55+
56+
def test_rsample(self):
57+
prior = BetaPrior(1.2, 0.9)
58+
samples = prior.rsample(torch.Size([10, 5]))
59+
self.assertEqual(samples.shape, torch.Size([10, 5]))
60+
self.assertTrue(torch.all(samples >= 0))
61+
self.assertTrue(torch.all(samples <= 1))
62+
63+
def test_state_dict_roundtrip(self):
64+
prior = BetaPrior(1.2, 0.9)
65+
sd = prior.state_dict()
66+
self.assertIn(f"{BUFFERED_PREFIX}concentration1", sd)
67+
self.assertIn(f"{BUFFERED_PREFIX}concentration0", sd)
68+
69+
prior2 = BetaPrior(999.0, 999.0)
70+
prior2.load_state_dict(sd)
71+
self.assertAlmostEqual(prior2.concentration1.item(), 1.2)
72+
self.assertAlmostEqual(prior2.concentration0.item(), 0.9)
73+
74+
def test_expand(self):
75+
prior = BetaPrior(1.2, 0.9)
76+
expanded = prior.expand(torch.Size([3, 2]))
77+
self.assertEqual(expanded.concentration1.shape, torch.Size([3, 2]))
78+
self.assertEqual(expanded.concentration0.shape, torch.Size([3, 2]))
79+
80+
def test_expand_preserves_transform(self):
81+
prior = BetaPrior(1.2, 0.9, transform=torch.sigmoid)
82+
expanded = prior.expand(torch.Size([3, 2]))
83+
self.assertIs(expanded._transform, torch.sigmoid)
84+
self.assertEqual(expanded._validate_args, prior._validate_args)

0 commit comments

Comments
 (0)