|
| 1 | +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""BetaBinomial distribution.""" |
| 16 | + |
| 17 | +from typing import Any, Tuple, Union |
| 18 | + |
| 19 | +import chex |
| 20 | +from distrax._src.distributions import distribution |
| 21 | +from distrax._src.utils import conversion |
| 22 | +from distrax._src.utils import math |
| 23 | +import jax |
| 24 | +import jax.numpy as jnp |
| 25 | +from tensorflow_probability.substrates import jax as tfp |
| 26 | + |
| 27 | + |
# Shorthand for the distributions namespace of the TFP JAX substrate.
tfd = tfp.distributions

# Local aliases for the types used in this module's annotations.
Array = chex.Array
Numeric = chex.Numeric
PRNGKey = chex.PRNGKey
EventT = distribution.EventT
| 34 | + |
| 35 | + |
class BetaBinomial(distribution.Distribution):
  """Beta-Binomial compound distribution.

  The Beta-Binomial distribution is parameterized by `total_count`,
  `concentration1` (alpha), and `concentration0` (beta).
  It is a compound distribution, equivalent to sampling a probability `p`
  from a `Beta(concentration1, concentration0)` distribution, and then
  sampling a count from a `Binomial(total_count, p)` distribution.

  The probability mass function (pmf) is,

  ```none
  pmf(k; n, a, b) = C(n, k) * Beta(k + a, n - k + b) / Beta(a, b)
  ```

  where:
  * `total_count = n`
  * `concentration1 = a > 0`
  * `concentration0 = b > 0`
  * `k` is the number of successes (an integer from 0 to n)
  * `C(n, k)` is the binomial coefficient "n choose k".
  * `Beta(x, y)` is the Beta function.

  The three parameters may have different shapes as long as they broadcast
  against each other; the batch shape is the broadcast of their shapes.
  """

  equiv_tfp_cls = tfd.BetaBinomial

  def __init__(
      self,
      total_count: Numeric,
      concentration1: Numeric,
      concentration0: Numeric,
      dtype: Union[jnp.dtype, type[Any]] = int,
  ):
    """Initializes a BetaBinomial distribution.

    Args:
      total_count: Non-negative floating-point tensor, whose components should
        be equal to integer values. The number of trials.
      concentration1: Positive floating-point tensor, the alpha parameter of the
        Beta prior.
      concentration0: Positive floating-point tensor, the beta parameter of the
        Beta prior.
      dtype: The type of event samples. Defaults to `int`.

    Raises:
      ValueError: If `dtype` is neither an integer nor a floating-point type.
    """
    super().__init__()
    # TFP implementation uses float for total_count, as do lbeta and
    # tfd.Binomial.
    if not (
        jnp.issubdtype(dtype, jnp.integer)
        or jnp.issubdtype(dtype, jnp.floating)
    ):
      raise ValueError(
          f'The dtype of `{self.name}` must be integer or '
          f'floating-point, instead got `{dtype}`.'
      )
    # Parameters are stored as given (converted to float arrays); they are not
    # broadcast against each other at construction time.
    self._total_count = conversion.as_float_array(total_count)
    self._concentration1 = conversion.as_float_array(concentration1)
    self._concentration0 = conversion.as_float_array(concentration0)
    self._dtype = dtype

  @property
  def event_shape(self) -> Tuple[int, ...]:
    """See `Distribution.event_shape`."""
    return ()

  @property
  def batch_shape(self) -> Tuple[int, ...]:
    """See `Distribution.batch_shape`."""
    return jnp.broadcast_shapes(
        self._total_count.shape,
        self._concentration1.shape,
        self._concentration0.shape,
    )

  @property
  def total_count(self) -> Array:
    """Number of trials (as stored, not broadcast to the batch shape)."""
    return self._total_count

  @property
  def concentration1(self) -> Array:
    """Concentration parameter associated with a `success` outcome (alpha)."""
    return self._concentration1

  @property
  def concentration0(self) -> Array:
    """Concentration parameter associated with a `failure` outcome (beta)."""
    return self._concentration0

  def _log_combinations(self, n: Array, k: Array) -> Array:
    """Computes log(C(n, k)) using the log-Beta function.

    Args:
      n: Total number of trials.
      k: Number of successes.

    Returns:
      The log of the binomial coefficient "n choose k".
    """
    # log(C(n, k)) = log(Gamma(n+1)) - log(Gamma(k+1)) - log(Gamma(n-k+1)).
    # With lbeta(x, y) = log(Gamma(x)) + log(Gamma(y)) - log(Gamma(x+y)) and
    # x = k + 1, y = n - k + 1:
    #   lbeta(k+1, n-k+1) = log(Gamma(k+1)) + log(Gamma(n-k+1))
    #                       - log(Gamma(n+2)).
    # Since log(Gamma(n+2)) = log(Gamma(n+1)) + log(n+1), it follows that
    #   log(C(n, k)) = -lbeta(k + 1, n - k + 1) - log(n + 1).
    return -math.log_beta(k + 1.0, n - k + 1.0) - jnp.log(n + 1.0)

  def _sample_n(self, key: PRNGKey, n: int) -> Array:
    """See `Distribution._sample_n`."""
    key1, key2, key3 = jax.random.split(key, 3)

    # Get parameters and broadcast them to (n,) + batch_shape.
    shape = (n,) + self.batch_shape
    total_count = jnp.broadcast_to(self.total_count, shape)
    concentration1 = jnp.broadcast_to(self.concentration1, shape)
    concentration0 = jnp.broadcast_to(self.concentration0, shape)

    # Sample probs ~ Beta(concentration1, concentration0).
    # This is done by sampling g1 ~ Gamma(c1, 1) and g2 ~ Gamma(c0, 1)
    # and computing probs = g1 / (g1 + g2).
    g1 = jax.random.gamma(key1, concentration1)
    g2 = jax.random.gamma(key2, concentration0)

    g_sum = g1 + g2
    # Fall back to 0.5 when both gamma draws are exactly 0 (e.g. c1 = c0 = 0),
    # which would otherwise give 0 / 0.
    probs = jnp.where(g_sum == 0.0, 0.5, g1 / g_sum)

    # Sample counts ~ Binomial(total_count, probs).
    samples = tfd.Binomial(total_count=total_count, probs=probs).sample(
        seed=key3
    )

    return samples.astype(self._dtype)

  def log_prob(self, value: EventT) -> Array:
    """See `Distribution.log_prob`.

    No support check is performed on `value`; the formula is evaluated as is.

    Args:
      value: Number of successes; cast to the dtype of `total_count`.

    Returns:
      The log-pmf evaluated at `value`.
    """
    n = self.total_count
    # Cast value to jnp.asarray to ensure it has .astype method for pytype.
    k = jnp.asarray(value).astype(n.dtype)
    c1 = self.concentration1
    c0 = self.concentration0

    # pmf(k; n, a, b) = C(n, k) * Beta(k + a, n - k + b) / Beta(a, b)
    # log_pmf = log(C(n, k)) + log(Beta(k + a, n - k + b)) - log(Beta(a, b))
    log_comb = self._log_combinations(n, k)
    log_beta_comp = math.log_beta(c1 + k, n - k + c0)
    log_beta_prior = math.log_beta(c1, c0)

    return log_comb + log_beta_comp - log_beta_prior

  def mean(self) -> Array:
    """See `Distribution.mean`."""
    n = self.total_count
    c1 = self.concentration1
    c0 = self.concentration0
    return n * c1 / (c1 + c0)

  def variance(self) -> Array:
    """See `Distribution.variance`."""
    n = self.total_count
    c1 = self.concentration1
    c0 = self.concentration0
    c_sum = c1 + c0
    # Formula from TFP:
    return (n * c1 * c0 * (c_sum + n)) / (c_sum**2 * (c_sum + 1.0))

  def __getitem__(self, index) -> 'BetaBinomial':
    """See `Distribution.__getitem__`.

    The stored parameters may have different (broadcastable) shapes, while the
    index is built for the full batch shape. Each parameter is therefore
    broadcast to the batch shape before it is indexed.
    """
    batch_shape = self.batch_shape
    index = distribution.to_batch_shape_index(batch_shape, index)
    return BetaBinomial(
        total_count=jnp.broadcast_to(self.total_count, batch_shape)[index],
        concentration1=jnp.broadcast_to(
            self.concentration1, batch_shape)[index],
        concentration0=jnp.broadcast_to(
            self.concentration0, batch_shape)[index],
        dtype=self._dtype,
    )
0 commit comments