Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,27 @@ def entropy(self) -> ArrayLike:


class Cauchy(Distribution):
r"""Cauchy distribution parameterized by location (:attr:`loc`) and
scale (:attr:`scale`).

The probability density function (PDF) is defined as:

.. math::
f(x; x_0, \gamma) = \frac{1}{\pi \gamma \left[1 +
\left(\frac{x - x_0}{\gamma}\right)^2\right]}

where :math:`x \in \mathbb{R}`, :math:`x_0 \in \mathbb{R}` is the location,
and :math:`\gamma > 0` is the scale. The Cauchy distribution has no finite
mean or variance.

:param loc: Location parameter (:math:`x_0`).
:type loc: ArrayLike
:param scale: Scale parameter (:math:`\gamma`).
:type scale: ArrayLike
:param validate_args: Whether to validate input constraints, defaults to None.
:type validate_args: bool, optional
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
reparametrized_params = ["loc", "scale"]
Expand All @@ -352,12 +373,32 @@ def __init__(
)

def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
r"""Generates samples using the inverse CDF method via :func:`~jax.random.cauchy`.

:param key: JAX PRNGKey for reproducibility.
:type key: jax.Array
:param sample_shape: The shape of the samples to be generated.
:type sample_shape: tuple[int, ...]
:return: Samples from the Cauchy distribution of shape ``sample_shape + batch_shape``.
:rtype: ArrayLike
"""
assert is_prng_key(key)
eps = random.cauchy(key, shape=sample_shape + self.batch_shape)
return self.loc + eps * self.scale

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
r"""Calculates the log of the probability density function.

.. math::
\log f(x; x_0, \gamma) = -\log(\pi) - \log(\gamma)
- \log\!\left[1 + \left(\frac{x - x_0}{\gamma}\right)^2\right]

:param value: Values at which to evaluate the log density.
:type value: ArrayLike
:return: Log probability density.
:rtype: ArrayLike
"""
return (
-jnp.log(jnp.pi)
- jnp.log(self.scale)
Expand All @@ -366,20 +407,51 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:

@property
def mean(self) -> ArrayLike:
r"""The mean of the Cauchy distribution is undefined.

Returns ``NaN`` for all batch elements.
"""
return jnp.full(self.batch_shape, jnp.nan)

@property
def variance(self) -> ArrayLike:
r"""The variance of the Cauchy distribution is undefined.

Returns ``NaN`` for all batch elements.
"""
return jnp.full(self.batch_shape, jnp.nan)

def cdf(self, value: ArrayLike) -> ArrayLike:
r"""Cumulative distribution function.

.. math::
F(x; x_0, \gamma) = \frac{1}{\pi}\arctan\!\left(\frac{x - x_0}{\gamma}\right)
+ \frac{1}{2}

:param value: Value to evaluate.
:type value: ArrayLike
"""
scaled = (value - self.loc) / self.scale
return jnp.arctan(scaled) / jnp.pi + 0.5

def icdf(self, q: ArrayLike) -> ArrayLike:
r"""Inverse cumulative distribution function (Quantile function).

.. math::
F^{-1}(q; x_0, \gamma) = x_0 + \gamma \tan\!\left[\pi\!\left(q
- \frac{1}{2}\right)\right]

:param q: Probability value in :math:`[0,1]`.
:type q: ArrayLike
"""
return self.loc + self.scale * jnp.tan(jnp.pi * (q - 0.5))

def entropy(self) -> ArrayLike:
r"""Entropy of the Cauchy distribution.

.. math::
H(X) = \log(4\pi\gamma)
"""
return jnp.broadcast_to(jnp.log(4 * np.pi * self.scale), self.batch_shape)


Expand Down Expand Up @@ -589,6 +661,21 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:


class Exponential(Distribution):
r"""Exponential distribution parameterized by rate (:attr:`rate`).

The probability density function (PDF) is defined as:

.. math::
f(x; \lambda) = \lambda e^{-\lambda x}

where :math:`x \geq 0` and :math:`\lambda > 0` is the rate parameter.

:param rate: Rate parameter (:math:`\lambda`), the inverse of the mean.
:type rate: ArrayLike
:param validate_args: Whether to validate input constraints, defaults to None.
:type validate_args: bool, optional
"""

reparametrized_params = ["rate"]
arg_constraints = {"rate": constraints.positive}
support = constraints.positive
Expand All @@ -605,30 +692,79 @@ def __init__(
)

def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
r"""Generates samples by scaling standard exponential draws by the
inverse rate: :math:`X = E / \lambda`, where :math:`E \sim \mathrm{Exp}(1)`.

:param key: JAX PRNGKey for reproducibility.
:type key: jax.Array
:param sample_shape: The shape of the samples to be generated.
:type sample_shape: tuple[int, ...]
:return: Samples from the Exponential distribution of shape ``sample_shape + batch_shape``.
:rtype: ArrayLike
"""
assert is_prng_key(key)
return (
random.exponential(key, shape=sample_shape + self.batch_shape) / self.rate
)

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
r"""Calculates the log of the probability density function.

.. math::
\log f(x; \lambda) = \log \lambda - \lambda x

:param value: Values at which to evaluate the log density.
:type value: ArrayLike
:return: Log probability density.
:rtype: ArrayLike
"""
return jnp.log(self.rate) - self.rate * value

@property
def mean(self) -> ArrayLike:
r"""Calculates the analytical mean.

.. math:: E[X] = \frac{1}{\lambda}
"""
return jnp.reciprocal(self.rate)

@property
def variance(self) -> ArrayLike:
r"""Calculates the analytical variance.

.. math:: \mathrm{Var}(X) = \frac{1}{\lambda^2}
"""
return jnp.reciprocal(self.rate**2)

def cdf(self, value: ArrayLike) -> ArrayLike:
r"""Cumulative distribution function.

.. math::
F(x; \lambda) = 1 - e^{-\lambda x}

:param value: Value to evaluate.
:type value: ArrayLike
"""
return -jnp.expm1(-self.rate * value)

def icdf(self, q: ArrayLike) -> ArrayLike:
r"""Inverse cumulative distribution function (Quantile function).

.. math::
F^{-1}(q; \lambda) = -\frac{\ln(1 - q)}{\lambda}

:param q: Probability value in :math:`[0,1]`.
:type q: ArrayLike
"""
return -jnp.log1p(-q) / self.rate

def entropy(self) -> ArrayLike:
r"""Entropy of the Exponential distribution.

.. math::
H(X) = 1 - \ln \lambda
"""
return 1 - jnp.log(self.rate)


Expand Down Expand Up @@ -2513,6 +2649,26 @@ def infer_shapes(loc, cov_factor, cov_diag):


class Normal(Distribution):
r"""Normal (Gaussian) distribution parameterized by mean (:attr:`loc`) and
standard deviation (:attr:`scale`).

The probability density function (PDF) is defined as:

.. math::
f(x; \mu, \sigma) = \frac{1}{\sigma \sqrt{2\pi}}
\exp\!\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)

where :math:`x \in \mathbb{R}`, :math:`\mu \in \mathbb{R}` is the mean,
and :math:`\sigma > 0` is the standard deviation.

:param loc: Mean of the distribution (:math:`\mu`).
:type loc: ArrayLike
:param scale: Standard deviation of the distribution (:math:`\sigma`).
:type scale: ArrayLike
:param validate_args: Whether to validate input constraints, defaults to None.
:type validate_args: bool, optional
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
reparametrized_params = ["loc", "scale"]
Expand All @@ -2531,6 +2687,16 @@ def __init__(
)

def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
r"""Generates samples via the reparameterization trick:
:math:`X = \mu + \sigma \epsilon`, where :math:`\epsilon \sim \mathcal{N}(0,1)`.

:param key: JAX PRNGKey for reproducibility.
:type key: jax.Array
:param sample_shape: The shape of the samples to be generated.
:type sample_shape: tuple[int, ...]
:return: Samples from the Normal distribution of shape ``sample_shape + batch_shape``.
:rtype: ArrayLike
"""
assert is_prng_key(key)
eps = random.normal(
key, shape=sample_shape + self.batch_shape + self.event_shape
Expand All @@ -2539,29 +2705,83 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
r"""Calculates the log of the probability density function.

.. math::
\log f(x; \mu, \sigma) = -\frac{(x - \mu)^2}{2\sigma^2}
- \log(\sigma \sqrt{2\pi})

:param value: Values at which to evaluate the log density.
:type value: ArrayLike
:return: Log probability density.
:rtype: ArrayLike
"""
normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
value_scaled = (value - self.loc) / self.scale
return -0.5 * value_scaled**2 - normalize_term

def cdf(self, value: ArrayLike) -> ArrayLike:
r"""Cumulative distribution function.

.. math::
F(x; \mu, \sigma) = \Phi\!\left(\frac{x-\mu}{\sigma}\right)

where, :math:`\Phi` is the
`cumulative distribution function of standard normal distribution <https://en.wikipedia.org/wiki/Normal_distribution#Cumulative_distribution_function>`_.
Implementation uses :func:`jax.scipy.special.ndtr` for :math:`\Phi`.

:param value: Value to evaluate.
:type value: ArrayLike
"""
scaled = (value - self.loc) / self.scale
return ndtr(scaled)

def log_cdf(self, value: ArrayLike) -> ArrayLike:
r"""Log of the cumulative distribution function. Implementation
calls :func:`jax.scipy.stats.norm.logcdf`.

:param value: Value to evaluate.
:type value: ArrayLike
"""
return jax_norm.logcdf(value, loc=self.loc, scale=self.scale)

def icdf(self, q: ArrayLike) -> ArrayLike:
r"""Inverse cumulative distribution function (Quantile function).
Comment thread
Qazalbash marked this conversation as resolved.

.. math::
F^{-1}(q; \mu, \sigma) = \mu + \sigma\,\Phi^{-1}(q)

where, :math:`\mathrm{\Phi^{-1}}` is inverse
`cumulative distribution function of standard normal distribution <https://en.wikipedia.org/wiki/Normal_distribution#Cumulative_distribution_function>`_.
Implementation uses :func:`jax.scipy.special.ndtri` for :math:`\Phi^{-1}`.

:param q: Probability value in :math:`[0,1]`.
:type q: ArrayLike
"""
return self.loc + self.scale * ndtri(q)

@property
def mean(self) -> ArrayLike:
r"""Calculates the analytical mean.

.. math:: E[X] = \mu
"""
return jnp.broadcast_to(self.loc, self.batch_shape)

@property
def variance(self) -> ArrayLike:
r"""Calculates the analytical variance.

.. math:: \mathrm{Var}(X) = \sigma^2
"""
return jnp.broadcast_to(self.scale**2, self.batch_shape)

def entropy(self) -> ArrayLike:
r"""Entropy of the Normal distribution.

.. math::
H(X) = \frac{1}{2} \ln(2\pi e \sigma^2)
"""
return jnp.broadcast_to(
(jnp.log(2 * np.pi * self.scale**2) + 1) / 2, self.batch_shape
)
Expand Down
Loading