Skip to content

Commit 8dbe44a

Browse files
kimjune01Qazalbash
andauthored
doc(gh-2187): add math explanations for Normal, Cauchy, and Exponential distributions (#2188)
* doc: add math explanations for Normal, Cauchy, and Exponential distributions (#2187) Add LaTeX math explanations, parameter documentation, and method docstrings following the format established in #2185 (Beta/BetaProportion). Covers PDF, log_prob, CDF, ICDF, mean, variance, entropy, and sample for each distribution. Cauchy docstrings note the undefined mean/variance. * Address review feedback: add log_cdf equation and use \frac - Add equation for Normal.log_cdf with note about jax.scipy.stats.norm.logcdf implementation - Replace \tfrac with \frac in Cauchy.icdf equation Addresses review comments from @Qazalbash in PR #2188 * add icdf equation for Normal distribution * fix: documenting equations as per the implementation --------- Co-authored-by: Meesum Qazalbash <meesumqazalbash@gmail.com>
1 parent db5eff9 commit 8dbe44a

1 file changed

Lines changed: 220 additions & 0 deletions

File tree

numpyro/distributions/continuous.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,27 @@ def entropy(self) -> ArrayLike:
334334

335335

336336
class Cauchy(Distribution):
337+
r"""Cauchy distribution parameterized by location (:attr:`loc`) and
338+
scale (:attr:`scale`).
339+
340+
The probability density function (PDF) is defined as:
341+
342+
.. math::
343+
f(x; x_0, \gamma) = \frac{1}{\pi \gamma \left[1 +
344+
\left(\frac{x - x_0}{\gamma}\right)^2\right]}
345+
346+
where :math:`x \in \mathbb{R}`, :math:`x_0 \in \mathbb{R}` is the location,
347+
and :math:`\gamma > 0` is the scale. The Cauchy distribution has no finite
348+
mean or variance.
349+
350+
:param loc: Location parameter (:math:`x_0`).
351+
:type loc: ArrayLike
352+
:param scale: Scale parameter (:math:`\gamma`).
353+
:type scale: ArrayLike
354+
:param validate_args: Whether to validate input constraints, defaults to None.
355+
:type validate_args: bool, optional
356+
"""
357+
337358
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
338359
support = constraints.real
339360
reparametrized_params = ["loc", "scale"]
@@ -352,12 +373,32 @@ def __init__(
352373
)
353374

354375
def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
376+
r"""Generates samples using the inverse CDF method via :func:`~jax.random.cauchy`.
377+
378+
:param key: JAX PRNGKey for reproducibility.
379+
:type key: jax.Array
380+
:param sample_shape: The shape of the samples to be generated.
381+
:type sample_shape: tuple[int, ...]
382+
:return: Samples from the Cauchy distribution of shape ``sample_shape + batch_shape``.
383+
:rtype: ArrayLike
384+
"""
355385
assert is_prng_key(key)
356386
eps = random.cauchy(key, shape=sample_shape + self.batch_shape)
357387
return self.loc + eps * self.scale
358388

359389
@validate_sample
360390
def log_prob(self, value: ArrayLike) -> ArrayLike:
391+
r"""Calculates the log of the probability density function.
392+
393+
.. math::
394+
\log f(x; x_0, \gamma) = -\log(\pi) - \log(\gamma)
395+
- \log\!\left[1 + \left(\frac{x - x_0}{\gamma}\right)^2\right]
396+
397+
:param value: Values at which to evaluate the log density.
398+
:type value: ArrayLike
399+
:return: Log probability density.
400+
:rtype: ArrayLike
401+
"""
361402
return (
362403
-jnp.log(jnp.pi)
363404
- jnp.log(self.scale)
@@ -366,20 +407,51 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
366407

367408
@property
368409
def mean(self) -> ArrayLike:
410+
r"""The mean of the Cauchy distribution is undefined.
411+
412+
Returns ``NaN`` for all batch elements.
413+
"""
369414
return jnp.full(self.batch_shape, jnp.nan)
370415

371416
@property
372417
def variance(self) -> ArrayLike:
418+
r"""The variance of the Cauchy distribution is undefined.
419+
420+
Returns ``NaN`` for all batch elements.
421+
"""
373422
return jnp.full(self.batch_shape, jnp.nan)
374423

375424
def cdf(self, value: ArrayLike) -> ArrayLike:
425+
r"""Cumulative distribution function.
426+
427+
.. math::
428+
F(x; x_0, \gamma) = \frac{1}{\pi}\arctan\!\left(\frac{x - x_0}{\gamma}\right)
429+
+ \frac{1}{2}
430+
431+
:param value: Value to evaluate.
432+
:type value: ArrayLike
433+
"""
376434
scaled = (value - self.loc) / self.scale
377435
return jnp.arctan(scaled) / jnp.pi + 0.5
378436

379437
def icdf(self, q: ArrayLike) -> ArrayLike:
438+
r"""Inverse cumulative distribution function (Quantile function).
439+
440+
.. math::
441+
F^{-1}(q; x_0, \gamma) = x_0 + \gamma \tan\!\left[\pi\!\left(q
442+
- \frac{1}{2}\right)\right]
443+
444+
:param q: Probability value in :math:`[0,1]`.
445+
:type q: ArrayLike
446+
"""
380447
return self.loc + self.scale * jnp.tan(jnp.pi * (q - 0.5))
381448

382449
def entropy(self) -> ArrayLike:
450+
r"""Entropy of the Cauchy distribution.
451+
452+
.. math::
453+
H(X) = \log(4\pi\gamma)
454+
"""
383455
return jnp.broadcast_to(jnp.log(4 * np.pi * self.scale), self.batch_shape)
384456

385457

@@ -589,6 +661,21 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
589661

590662

591663
class Exponential(Distribution):
664+
r"""Exponential distribution parameterized by rate (:attr:`rate`).
665+
666+
The probability density function (PDF) is defined as:
667+
668+
.. math::
669+
f(x; \lambda) = \lambda e^{-\lambda x}
670+
671+
where :math:`x \geq 0` and :math:`\lambda > 0` is the rate parameter.
672+
673+
:param rate: Rate parameter (:math:`\lambda`), the inverse of the mean.
674+
:type rate: ArrayLike
675+
:param validate_args: Whether to validate input constraints, defaults to None.
676+
:type validate_args: bool, optional
677+
"""
678+
592679
reparametrized_params = ["rate"]
593680
arg_constraints = {"rate": constraints.positive}
594681
support = constraints.positive
@@ -605,30 +692,79 @@ def __init__(
605692
)
606693

607694
def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
695+
r"""Generates samples by scaling standard exponential draws by the
696+
inverse rate: :math:`X = E / \lambda`, where :math:`E \sim \mathrm{Exp}(1)`.
697+
698+
:param key: JAX PRNGKey for reproducibility.
699+
:type key: jax.Array
700+
:param sample_shape: The shape of the samples to be generated.
701+
:type sample_shape: tuple[int, ...]
702+
:return: Samples from the Exponential distribution of shape ``sample_shape + batch_shape``.
703+
:rtype: ArrayLike
704+
"""
608705
assert is_prng_key(key)
609706
return (
610707
random.exponential(key, shape=sample_shape + self.batch_shape) / self.rate
611708
)
612709

613710
@validate_sample
614711
def log_prob(self, value: ArrayLike) -> ArrayLike:
712+
r"""Calculates the log of the probability density function.
713+
714+
.. math::
715+
\log f(x; \lambda) = \log \lambda - \lambda x
716+
717+
:param value: Values at which to evaluate the log density.
718+
:type value: ArrayLike
719+
:return: Log probability density.
720+
:rtype: ArrayLike
721+
"""
615722
return jnp.log(self.rate) - self.rate * value
616723

617724
@property
618725
def mean(self) -> ArrayLike:
726+
r"""Calculates the analytical mean.
727+
728+
.. math:: E[X] = \frac{1}{\lambda}
729+
"""
619730
return jnp.reciprocal(self.rate)
620731

621732
@property
622733
def variance(self) -> ArrayLike:
734+
r"""Calculates the analytical variance.
735+
736+
.. math:: \mathrm{Var}(X) = \frac{1}{\lambda^2}
737+
"""
623738
return jnp.reciprocal(self.rate**2)
624739

625740
def cdf(self, value: ArrayLike) -> ArrayLike:
741+
r"""Cumulative distribution function.
742+
743+
.. math::
744+
F(x; \lambda) = 1 - e^{-\lambda x}
745+
746+
:param value: Value to evaluate.
747+
:type value: ArrayLike
748+
"""
626749
return -jnp.expm1(-self.rate * value)
627750

628751
def icdf(self, q: ArrayLike) -> ArrayLike:
752+
r"""Inverse cumulative distribution function (Quantile function).
753+
754+
.. math::
755+
F^{-1}(q; \lambda) = -\frac{\ln(1 - q)}{\lambda}
756+
757+
:param q: Probability value in :math:`[0,1]`.
758+
:type q: ArrayLike
759+
"""
629760
return -jnp.log1p(-q) / self.rate
630761

631762
def entropy(self) -> ArrayLike:
763+
r"""Entropy of the Exponential distribution.
764+
765+
.. math::
766+
H(X) = 1 - \ln \lambda
767+
"""
632768
return 1 - jnp.log(self.rate)
633769

634770

@@ -2513,6 +2649,26 @@ def infer_shapes(loc, cov_factor, cov_diag):
25132649

25142650

25152651
class Normal(Distribution):
2652+
r"""Normal (Gaussian) distribution parameterized by mean (:attr:`loc`) and
2653+
standard deviation (:attr:`scale`).
2654+
2655+
The probability density function (PDF) is defined as:
2656+
2657+
.. math::
2658+
f(x; \mu, \sigma) = \frac{1}{\sigma \sqrt{2\pi}}
2659+
\exp\!\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)
2660+
2661+
where :math:`x \in \mathbb{R}`, :math:`\mu \in \mathbb{R}` is the mean,
2662+
and :math:`\sigma > 0` is the standard deviation.
2663+
2664+
:param loc: Mean of the distribution (:math:`\mu`).
2665+
:type loc: ArrayLike
2666+
:param scale: Standard deviation of the distribution (:math:`\sigma`).
2667+
:type scale: ArrayLike
2668+
:param validate_args: Whether to validate input constraints, defaults to None.
2669+
:type validate_args: bool, optional
2670+
"""
2671+
25162672
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
25172673
support = constraints.real
25182674
reparametrized_params = ["loc", "scale"]
@@ -2531,6 +2687,16 @@ def __init__(
25312687
)
25322688

25332689
def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
2690+
r"""Generates samples via the reparameterization trick:
2691+
:math:`X = \mu + \sigma \epsilon`, where :math:`\epsilon \sim \mathcal{N}(0,1)`.
2692+
2693+
:param key: JAX PRNGKey for reproducibility.
2694+
:type key: jax.Array
2695+
:param sample_shape: The shape of the samples to be generated.
2696+
:type sample_shape: tuple[int, ...]
2697+
:return: Samples from the Normal distribution of shape ``sample_shape + batch_shape``.
2698+
:rtype: ArrayLike
2699+
"""
25342700
assert is_prng_key(key)
25352701
eps = random.normal(
25362702
key, shape=sample_shape + self.batch_shape + self.event_shape
@@ -2539,29 +2705,83 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
25392705

25402706
@validate_sample
25412707
def log_prob(self, value: ArrayLike) -> ArrayLike:
2708+
r"""Calculates the log of the probability density function.
2709+
2710+
.. math::
2711+
\log f(x; \mu, \sigma) = -\frac{(x - \mu)^2}{2\sigma^2}
2712+
- \log(\sigma \sqrt{2\pi})
2713+
2714+
:param value: Values at which to evaluate the log density.
2715+
:type value: ArrayLike
2716+
:return: Log probability density.
2717+
:rtype: ArrayLike
2718+
"""
25422719
normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
25432720
value_scaled = (value - self.loc) / self.scale
25442721
return -0.5 * value_scaled**2 - normalize_term
25452722

25462723
def cdf(self, value: ArrayLike) -> ArrayLike:
2724+
r"""Cumulative distribution function.
2725+
2726+
.. math::
2727+
F(x; \mu, \sigma) = \Phi\!\left(\frac{x-\mu}{\sigma}\right)
2728+
2729+
where, :math:`\Phi` is the
2730+
`cumulative distribution function of standard normal distribution <https://en.wikipedia.org/wiki/Normal_distribution#Cumulative_distribution_function>`_.
2731+
Implementation uses :func:`jax.scipy.special.ndtr` for :math:`\Phi`.
2732+
2733+
:param value: Value to evaluate.
2734+
:type value: ArrayLike
2735+
"""
25472736
scaled = (value - self.loc) / self.scale
25482737
return ndtr(scaled)
25492738

25502739
def log_cdf(self, value: ArrayLike) -> ArrayLike:
2740+
r"""Log of the cumulative distribution function. Implementation
2741+
calls :func:`jax.scipy.stats.norm.logcdf`.
2742+
2743+
:param value: Value to evaluate.
2744+
:type value: ArrayLike
2745+
"""
25512746
return jax_norm.logcdf(value, loc=self.loc, scale=self.scale)
25522747

25532748
def icdf(self, q: ArrayLike) -> ArrayLike:
2749+
r"""Inverse cumulative distribution function (Quantile function).
2750+
2751+
.. math::
2752+
F^{-1}(q; \mu, \sigma) = \mu + \sigma\,\Phi^{-1}(q)
2753+
2754+
where, :math:`\mathrm{\Phi^{-1}}` is inverse
2755+
`cumulative distribution function of standard normal distribution <https://en.wikipedia.org/wiki/Normal_distribution#Cumulative_distribution_function>`_.
2756+
Implementation uses :func:`jax.scipy.special.ndtri` for :math:`\Phi^{-1}`.
2757+
2758+
:param q: Probability value in :math:`[0,1]`.
2759+
:type q: ArrayLike
2760+
"""
25542761
return self.loc + self.scale * ndtri(q)
25552762

25562763
@property
25572764
def mean(self) -> ArrayLike:
2765+
r"""Calculates the analytical mean.
2766+
2767+
.. math:: E[X] = \mu
2768+
"""
25582769
return jnp.broadcast_to(self.loc, self.batch_shape)
25592770

25602771
@property
25612772
def variance(self) -> ArrayLike:
2773+
r"""Calculates the analytical variance.
2774+
2775+
.. math:: \mathrm{Var}(X) = \sigma^2
2776+
"""
25622777
return jnp.broadcast_to(self.scale**2, self.batch_shape)
25632778

25642779
def entropy(self) -> ArrayLike:
2780+
r"""Entropy of the Normal distribution.
2781+
2782+
.. math::
2783+
H(X) = \frac{1}{2} \ln(2\pi e \sigma^2)
2784+
"""
25652785
return jnp.broadcast_to(
25662786
(jnp.log(2 * np.pi * self.scale**2) + 1) / 2, self.batch_shape
25672787
)

0 commit comments

Comments
 (0)