Skip to content

Commit cb26382

Browse files
DistraxDev
authored and committed
Implements the BetaBinomial distribution in distrax._src.distributions.
This is a compound distribution, equivalent to a Beta-Binomial mixture. The implementation follows distrax conventions and uses the TFP implementation as a mathematical reference. PiperOrigin-RevId: 826581660
1 parent d901057 commit cb26382

File tree

2 files changed

+509
-0
lines changed

2 files changed

+509
-0
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""BetaBinomial distribution."""
16+
17+
from typing import Any, Tuple, Union
18+
19+
import chex
20+
from distrax._src.distributions import distribution
21+
from distrax._src.utils import conversion
22+
from distrax._src.utils import math
23+
import jax
24+
import jax.numpy as jnp
25+
from tensorflow_probability.substrates import jax as tfp
26+
27+
28+
# Short alias for the distributions namespace of the imported `tfp` module.
tfd = tfp.distributions
29+
30+
# Type aliases used in the annotations of this module.
Array = chex.Array
31+
Numeric = chex.Numeric
32+
PRNGKey = chex.PRNGKey
33+
EventT = distribution.EventT
34+
35+
36+
class BetaBinomial(distribution.Distribution):
  """Beta-Binomial compound distribution.

  The Beta-Binomial distribution is parameterized by `total_count`,
  `concentration1` (alpha), and `concentration0` (beta).
  It is a compound distribution, equivalent to sampling a probability `p`
  from a `Beta(concentration1, concentration0)` distribution, and then
  sampling a count from a `Binomial(total_count, p)` distribution.

  The probability mass function (pmf) is,

  ```none
  pmf(k; n, a, b) = C(n, k) * Beta(k + a, n - k + b) / Beta(a, b)
  ```

  where:
  * `total_count = n`
  * `concentration1 = a > 0`
  * `concentration0 = b > 0`
  * `k` is the number of successes (an integer from 0 to n)
  * `C(n, k)` is the binomial coefficient "n choose k".
  * `Beta(x, y)` is the Beta function.

  The three parameters may have different shapes as long as they broadcast
  against each other; `batch_shape` is the broadcast of their shapes.
  """

  equiv_tfp_cls = tfd.BetaBinomial

  def __init__(
      self,
      total_count: Numeric,
      concentration1: Numeric,
      concentration0: Numeric,
      dtype: Union[jnp.dtype, type[Any]] = int,
  ):
    """Initializes a BetaBinomial distribution.

    Args:
      total_count: Non-negative floating-point tensor, whose components should
        be equal to integer values. The number of trials.
      concentration1: Positive floating-point tensor, the alpha parameter of the
        Beta prior.
      concentration0: Positive floating-point tensor, the beta parameter of the
        Beta prior.
      dtype: The type of event samples. Defaults to `int`.

    Raises:
      ValueError: If `dtype` is neither an integer nor a floating-point type.
    """
    super().__init__()
    if not (
        jnp.issubdtype(dtype, jnp.integer)
        or jnp.issubdtype(dtype, jnp.floating)
    ):
      raise ValueError(
          f'The dtype of `{self.name}` must be integer or '
          f'floating-point, instead got `{dtype}`.'
      )
    # TFP implementation uses float for total_count, as do lbeta and
    # tfd.Binomial, so all three parameters are stored as float arrays.
    self._total_count = conversion.as_float_array(total_count)
    self._concentration1 = conversion.as_float_array(concentration1)
    self._concentration0 = conversion.as_float_array(concentration0)
    self._dtype = dtype

  @property
  def event_shape(self) -> Tuple[int, ...]:
    """See `Distribution.event_shape`."""
    return ()

  @property
  def batch_shape(self) -> Tuple[int, ...]:
    """See `Distribution.batch_shape`."""
    return jnp.broadcast_shapes(
        self._total_count.shape,
        self._concentration1.shape,
        self._concentration0.shape,
    )

  @property
  def total_count(self) -> Array:
    """Number of trials, as stored (not broadcast to `batch_shape`)."""
    return self._total_count

  @property
  def concentration1(self) -> Array:
    """Concentration parameter associated with a `success` outcome (alpha).

    Returned as stored, i.e. not broadcast to `batch_shape`.
    """
    return self._concentration1

  @property
  def concentration0(self) -> Array:
    """Concentration parameter associated with a `failure` outcome (beta).

    Returned as stored, i.e. not broadcast to `batch_shape`.
    """
    return self._concentration0

  def _log_combinations(self, n: Array, k: Array) -> Array:
    """Computes log(C(n, k)) using the log-Beta function.

    Args:
      n: Number of trials.
      k: Number of successes.

    Returns:
      The log of the binomial coefficient "n choose k".
    """
    # log(C(n, k)) = log(Gamma(n+1)) - log(Gamma(k+1)) - log(Gamma(n-k+1)).
    # With lbeta(x, y) = log(Gamma(x)) + log(Gamma(y)) - log(Gamma(x+y)):
    #   lbeta(k+1, n-k+1)
    #       = log(Gamma(k+1)) + log(Gamma(n-k+1)) - log(Gamma(n+2))
    # and log(Gamma(n+2)) = log(n+1) + log(Gamma(n+1)), hence
    #   log(C(n, k)) = -lbeta(k + 1, n - k + 1) - log(n + 1).
    return -math.log_beta(k + 1.0, n - k + 1.0) - jnp.log(n + 1.0)

  def _sample_n(self, key: PRNGKey, n: int) -> Array:
    """See `Distribution._sample_n`."""
    key1, key2, key3 = jax.random.split(key, 3)

    # Broadcast the parameters to the full sample shape (n,) + batch_shape.
    shape = (n,) + self.batch_shape
    total_count = jnp.broadcast_to(self.total_count, shape)
    concentration1 = jnp.broadcast_to(self.concentration1, shape)
    concentration0 = jnp.broadcast_to(self.concentration0, shape)

    # Sample probs ~ Beta(concentration1, concentration0) by drawing
    # g1 ~ Gamma(c1, 1) and g2 ~ Gamma(c0, 1) and taking g1 / (g1 + g2).
    g1 = jax.random.gamma(key1, concentration1)
    g2 = jax.random.gamma(key2, concentration0)

    g_sum = g1 + g2
    # Both gamma draws being exactly zero would give 0 / 0; fall back to 0.5
    # in that case, otherwise use g1 / g_sum.
    probs = jnp.where(g_sum == 0.0, 0.5, g1 / g_sum)

    # Sample counts ~ Binomial(total_count, probs).
    samples = tfd.Binomial(total_count=total_count, probs=probs).sample(
        seed=key3
    )

    return samples.astype(self._dtype)

  def log_prob(self, value: EventT) -> Array:
    """See `Distribution.log_prob`.

    No check is performed here that `value` lies in the support
    `{0, ..., total_count}`.
    """
    n = self.total_count
    # Go through jnp.asarray so the value has an .astype method for pytype.
    k = jnp.asarray(value).astype(n.dtype)
    c1 = self.concentration1
    c0 = self.concentration0

    # pmf(k; n, a, b) = C(n, k) * Beta(k + a, n - k + b) / Beta(a, b)
    # log_pmf = log(C(n, k)) + log(Beta(k + a, n - k + b)) - log(Beta(a, b))
    log_comb = self._log_combinations(n, k)
    log_beta_comp = math.log_beta(c1 + k, n - k + c0)
    log_beta_prior = math.log_beta(c1, c0)

    return log_comb + log_beta_comp - log_beta_prior

  def mean(self) -> Array:
    """See `Distribution.mean`."""
    n = self.total_count
    c1 = self.concentration1
    c0 = self.concentration0
    return n * c1 / (c1 + c0)

  def variance(self) -> Array:
    """See `Distribution.variance`."""
    n = self.total_count
    c1 = self.concentration1
    c0 = self.concentration0
    c_sum = c1 + c0
    # Formula from TFP:
    #   n * a * b * (a + b + n) / ((a + b)^2 * (a + b + 1)).
    return (n * c1 * c0 * (c_sum + n)) / (c_sum**2 * (c_sum + 1.0))

  def __getitem__(self, index) -> 'BetaBinomial':
    """See `Distribution.__getitem__`."""
    batch_shape = self.batch_shape
    index = distribution.to_batch_shape_index(batch_shape, index)
    # The stored parameters may have fewer dimensions than `batch_shape`
    # (e.g. a scalar `total_count` with vector concentrations). The index is
    # expressed against `batch_shape`, so broadcast each parameter to that
    # shape before slicing.
    return BetaBinomial(
        total_count=jnp.broadcast_to(self.total_count, batch_shape)[index],
        concentration1=jnp.broadcast_to(
            self.concentration1, batch_shape)[index],
        concentration0=jnp.broadcast_to(
            self.concentration0, batch_shape)[index],
        dtype=self._dtype,
    )

0 commit comments

Comments
 (0)