Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 97 additions & 7 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,25 @@ def icdf(self, value: ArrayLike) -> ArrayLike:


class Beta(Distribution):
r"""Beta distribution parameterized by concentration parameters alpha (:attr:`concentration1`)
and beta (:attr:`concentration0`), on the unit interval :math:`[0,1]`.

The probability density function (PDF) is defined as:

.. math::
f(x; \alpha, \beta) = \frac{x^{\alpha - 1} (1 - x)^{\beta - 1}}{\text{B}(\alpha, \beta)}

Where, :math:`x \in [0, 1]`, :math:`\alpha > 0`, :math:`\beta > 0`,
and :math:`\text{B}(\alpha, \beta)` is the Beta function.

:param concentration1: Alpha parameter (1st shape parameter).
:type concentration1: ArrayLike
:param concentration0: Beta parameter (2nd shape parameter).
:type concentration0: ArrayLike
:param validate_args: Whether to validate input constraints, defaults to None.
:type validate_args: bool, optional
"""

arg_constraints = {
"concentration1": constraints.positive,
"concentration0": constraints.positive,
Expand Down Expand Up @@ -206,11 +225,36 @@ def __init__(
)

def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
r"""Generates samples from the distribution using the underlying Dirichlet implementation.

Since a :math:`\mathrm{Beta}(\alpha, \beta)` distribution is equivalent to a
2-category :math:`\mathrm{Dirichlet}(\alpha, \beta)`, this method samples from
the Dirichlet and slices the first component.

:param key: JAX PRNGKey for reproducibility.
:type key: jax.Array
:param sample_shape: The shape of the samples to be generated.
:type sample_shape: tuple[int, ...]
:return: Samples from the Beta distribution of shape ``sample_shape + batch_shape``.
:rtype: ArrayLike
"""
assert is_prng_key(key)
return self._dirichlet.sample(key, sample_shape)[..., 0]

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
r"""Calculates the log of the probability density function.

To avoid `NaN` gradients at the boundaries :math:`x=0` or :math:`x=1`, this
implementation masks boundary values with a safe constant (0.5) during the
differentiation path. The forward pass value is then corrected using
:func:`~jax.lax.stop_gradient` to ensure numerical stability without sacrificing accuracy.

:param value: Values at which to evaluate the log density.
:type value: ArrayLike
:return: Log probability density.
:rtype: ArrayLike
"""
# Use double-where trick to avoid NaN gradients at boundary conditions
# Reference: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
is_boundary = (value == 0.0) | (value == 1.0)
Expand Down Expand Up @@ -238,20 +282,48 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:

@property
def mean(self) -> ArrayLike:
r"""Calculates the analytical mean.

.. math:: E[X] = \frac{\alpha}{\alpha + \beta}
"""
return self.concentration1 / (self.concentration1 + self.concentration0)

@property
def variance(self) -> ArrayLike:
r"""Calculates the analytical variance.

.. math:: Var(X) = \frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)}
"""
total = self.concentration1 + self.concentration0
return self.concentration1 * self.concentration0 / (total**2 * (total + 1))

def cdf(self, value: ArrayLike) -> ArrayLike:
r"""Cumulative distribution function using the regularized incomplete beta function.

.. math:: I_x(\alpha, \beta) = \frac{\text{B}(x; \alpha, \beta)}{\text{B}(\alpha, \beta)}

:param value: Value to evaluate.
:type value: ArrayLike
"""
return betainc(self.concentration1, self.concentration0, value)

def icdf(self, q: ArrayLike) -> ArrayLike:
r"""Inverse cumulative distribution function (Quantile function).

:param q: Probability value in :math:`[0,1]`.
:type q: ArrayLike
"""
return betaincinv(self.concentration1, self.concentration0, q)

def entropy(self) -> ArrayLike:
r"""Entropy of the Beta distribution.

.. math::
H(X) = \ln \text{B}(\alpha, \beta) - (\alpha - 1)\psi(\alpha) -
(\beta - 1)\psi(\beta) + (\alpha + \beta - 2)\psi(\alpha + \beta)

where :math:`\psi` is the digamma function.
"""
total = self.concentration0 + self.concentration1
return (
betaln(self.concentration0, self.concentration1)
Expand Down Expand Up @@ -2842,14 +2914,32 @@ def entropy(self) -> ArrayLike:


class BetaProportion(Beta):
"""
The BetaProportion distribution is a reparameterization of the conventional
Beta distribution in terms of a the variate mean and a
precision parameter.
r"""Beta distribution reparameterized in terms of a mean (:attr:`mean`) and a
precision (:attr:`concentration`).

Given mean :math:`\mu` and precision :math:`\phi`, the standard Beta
parameters are derived as:

.. math::
\alpha = \mu \phi, \quad \beta = (1 - \mu) \phi

The resulting PDF is:

.. math::
f(x; \mu, \phi) =
\frac{x^{\mu\phi - 1} (1 - x)^{(1 - \mu)\phi - 1}}{\text{B}(\mu\phi, (1 - \mu)\phi)}

**Reference**

Ferrari, Silvia, and Francisco Cribari-Neto. "Beta regression for modelling
rates and proportions." *Journal of Applied Statistics* 31.7 (2004): 799-815.

**Reference:**
`Beta regression for modelling rates and proportion`, Ferrari Silvia, and
Francisco Cribari-Neto. Journal of Applied Statistics 31.7 (2004): 799-815.
:param mean: Mean of the distribution, restricted to the open interval (0, 1).
:type mean: ArrayLike
:param concentration: Precision parameter (:math:`\phi`), must be positive.
:type concentration: ArrayLike
:param validate_args: Whether to validate input constraints, defaults to None.
:type validate_args: bool, optional
"""

arg_constraints = {
Expand Down
Loading