@@ -334,6 +334,27 @@ def entropy(self) -> ArrayLike:
334334
335335
336336class Cauchy (Distribution ):
337+ r"""Cauchy distribution parameterized by location (:attr:`loc`) and
338+ scale (:attr:`scale`).
339+
340+ The probability density function (PDF) is defined as:
341+
342+ .. math::
343+ f(x; x_0, \gamma) = \frac{1}{\pi \gamma \left[1 +
344+ \left(\frac{x - x_0}{\gamma}\right)^2\right]}
345+
346+ where :math:`x \in \mathbb{R}`, :math:`x_0 \in \mathbb{R}` is the location,
347+ and :math:`\gamma > 0` is the scale. The Cauchy distribution has no finite
348+ mean or variance.
349+
350+ :param loc: Location parameter (:math:`x_0`).
351+ :type loc: ArrayLike
352+ :param scale: Scale parameter (:math:`\gamma`).
353+ :type scale: ArrayLike
354+ :param validate_args: Whether to validate input constraints, defaults to None.
355+ :type validate_args: bool, optional
356+ """
357+
337358 arg_constraints = {"loc" : constraints .real , "scale" : constraints .positive }
338359 support = constraints .real
339360 reparametrized_params = ["loc" , "scale" ]
@@ -352,12 +373,32 @@ def __init__(
352373 )
353374
354375 def sample (self , key : jax .Array , sample_shape : tuple [int , ...] = ()) -> ArrayLike :
376+ r"""Generates samples using the inverse CDF method via :func:`~jax.random.cauchy`.
377+
378+ :param key: JAX PRNGKey for reproducibility.
379+ :type key: jax.Array
380+ :param sample_shape: The shape of the samples to be generated.
381+ :type sample_shape: tuple[int, ...]
382+ :return: Samples from the Cauchy distribution of shape ``sample_shape + batch_shape``.
383+ :rtype: ArrayLike
384+ """
355385 assert is_prng_key (key )
356386 eps = random .cauchy (key , shape = sample_shape + self .batch_shape )
357387 return self .loc + eps * self .scale
358388
359389 @validate_sample
360390 def log_prob (self , value : ArrayLike ) -> ArrayLike :
391+ r"""Calculates the log of the probability density function.
392+
393+ .. math::
394+ \log f(x; x_0, \gamma) = -\log(\pi) - \log(\gamma)
395+ - \log\!\left[1 + \left(\frac{x - x_0}{\gamma}\right)^2\right]
396+
397+ :param value: Values at which to evaluate the log density.
398+ :type value: ArrayLike
399+ :return: Log probability density.
400+ :rtype: ArrayLike
401+ """
361402 return (
362403 - jnp .log (jnp .pi )
363404 - jnp .log (self .scale )
@@ -366,20 +407,51 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
366407
367408 @property
368409 def mean (self ) -> ArrayLike :
410+ r"""The mean of the Cauchy distribution is undefined.
411+
412+ Returns ``NaN`` for all batch elements.
413+ """
369414 return jnp .full (self .batch_shape , jnp .nan )
370415
371416 @property
372417 def variance (self ) -> ArrayLike :
418+ r"""The variance of the Cauchy distribution is undefined.
419+
420+ Returns ``NaN`` for all batch elements.
421+ """
373422 return jnp .full (self .batch_shape , jnp .nan )
374423
375424 def cdf (self , value : ArrayLike ) -> ArrayLike :
425+ r"""Cumulative distribution function.
426+
427+ .. math::
428+ F(x; x_0, \gamma) = \frac{1}{\pi}\arctan\!\left(\frac{x - x_0}{\gamma}\right)
429+ + \frac{1}{2}
430+
431+ :param value: Value to evaluate.
432+ :type value: ArrayLike
433+ """
376434 scaled = (value - self .loc ) / self .scale
377435 return jnp .arctan (scaled ) / jnp .pi + 0.5
378436
379437 def icdf (self , q : ArrayLike ) -> ArrayLike :
438+ r"""Inverse cumulative distribution function (Quantile function).
439+
440+ .. math::
441+ F^{-1}(q; x_0, \gamma) = x_0 + \gamma \tan\!\left[\pi\!\left(q
442+ - \tfrac{1}{2}\right)\right]
443+
444+ :param q: Probability value in :math:`[0,1]`.
445+ :type q: ArrayLike
446+ """
380447 return self .loc + self .scale * jnp .tan (jnp .pi * (q - 0.5 ))
381448
382449 def entropy (self ) -> ArrayLike :
450+ r"""Entropy of the Cauchy distribution.
451+
452+ .. math::
453+ H(X) = \log(4\pi\gamma)
454+ """
383455 return jnp .broadcast_to (jnp .log (4 * np .pi * self .scale ), self .batch_shape )
384456
385457
@@ -589,6 +661,21 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
589661
590662
591663class Exponential (Distribution ):
664+ r"""Exponential distribution parameterized by rate (:attr:`rate`).
665+
666+ The probability density function (PDF) is defined as:
667+
668+ .. math::
669+ f(x; \lambda) = \lambda e^{-\lambda x}
670+
671+ where :math:`x \geq 0` and :math:`\lambda > 0` is the rate parameter.
672+
673+ :param rate: Rate parameter (:math:`\lambda`), the inverse of the mean.
674+ :type rate: ArrayLike
675+ :param validate_args: Whether to validate input constraints, defaults to None.
676+ :type validate_args: bool, optional
677+ """
678+
592679 reparametrized_params = ["rate" ]
593680 arg_constraints = {"rate" : constraints .positive }
594681 support = constraints .positive
@@ -605,30 +692,79 @@ def __init__(
605692 )
606693
607694 def sample (self , key : jax .Array , sample_shape : tuple [int , ...] = ()) -> ArrayLike :
695+ r"""Generates samples by scaling standard exponential draws by the
696+ inverse rate: :math:`X = E / \lambda`, where :math:`E \sim \mathrm{Exp}(1)`.
697+
698+ :param key: JAX PRNGKey for reproducibility.
699+ :type key: jax.Array
700+ :param sample_shape: The shape of the samples to be generated.
701+ :type sample_shape: tuple[int, ...]
702+ :return: Samples from the Exponential distribution of shape ``sample_shape + batch_shape``.
703+ :rtype: ArrayLike
704+ """
608705 assert is_prng_key (key )
609706 return (
610707 random .exponential (key , shape = sample_shape + self .batch_shape ) / self .rate
611708 )
612709
613710 @validate_sample
614711 def log_prob (self , value : ArrayLike ) -> ArrayLike :
712+ r"""Calculates the log of the probability density function.
713+
714+ .. math::
715+ \log f(x; \lambda) = \log \lambda - \lambda x
716+
717+ :param value: Values at which to evaluate the log density.
718+ :type value: ArrayLike
719+ :return: Log probability density.
720+ :rtype: ArrayLike
721+ """
615722 return jnp .log (self .rate ) - self .rate * value
616723
617724 @property
618725 def mean (self ) -> ArrayLike :
726+ r"""Calculates the analytical mean.
727+
728+ .. math:: E[X] = \frac{1}{\lambda}
729+ """
619730 return jnp .reciprocal (self .rate )
620731
621732 @property
622733 def variance (self ) -> ArrayLike :
734+ r"""Calculates the analytical variance.
735+
736+ .. math:: \mathrm{Var}(X) = \frac{1}{\lambda^2}
737+ """
623738 return jnp .reciprocal (self .rate ** 2 )
624739
625740 def cdf (self , value : ArrayLike ) -> ArrayLike :
741+ r"""Cumulative distribution function.
742+
743+ .. math::
744+ F(x; \lambda) = 1 - e^{-\lambda x}
745+
746+ :param value: Value to evaluate.
747+ :type value: ArrayLike
748+ """
626749 return - jnp .expm1 (- self .rate * value )
627750
628751 def icdf (self , q : ArrayLike ) -> ArrayLike :
752+ r"""Inverse cumulative distribution function (Quantile function).
753+
754+ .. math::
755+ F^{-1}(q; \lambda) = -\frac{\ln(1 - q)}{\lambda}
756+
757+ :param q: Probability value in :math:`[0,1]`.
758+ :type q: ArrayLike
759+ """
629760 return - jnp .log1p (- q ) / self .rate
630761
631762 def entropy (self ) -> ArrayLike :
763+ r"""Entropy of the Exponential distribution.
764+
765+ .. math::
766+ H(X) = 1 - \ln \lambda
767+ """
632768 return 1 - jnp .log (self .rate )
633769
634770
@@ -2513,6 +2649,26 @@ def infer_shapes(loc, cov_factor, cov_diag):
25132649
25142650
25152651class Normal (Distribution ):
2652+ r"""Normal (Gaussian) distribution parameterized by mean (:attr:`loc`) and
2653+ standard deviation (:attr:`scale`).
2654+
2655+ The probability density function (PDF) is defined as:
2656+
2657+ .. math::
2658+ f(x; \mu, \sigma) = \frac{1}{\sigma \sqrt{2\pi}}
2659+ \exp\!\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)
2660+
2661+ where :math:`x \in \mathbb{R}`, :math:`\mu \in \mathbb{R}` is the mean,
2662+ and :math:`\sigma > 0` is the standard deviation.
2663+
2664+ :param loc: Mean of the distribution (:math:`\mu`).
2665+ :type loc: ArrayLike
2666+ :param scale: Standard deviation of the distribution (:math:`\sigma`).
2667+ :type scale: ArrayLike
2668+ :param validate_args: Whether to validate input constraints, defaults to None.
2669+ :type validate_args: bool, optional
2670+ """
2671+
25162672 arg_constraints = {"loc" : constraints .real , "scale" : constraints .positive }
25172673 support = constraints .real
25182674 reparametrized_params = ["loc" , "scale" ]
@@ -2531,6 +2687,16 @@ def __init__(
25312687 )
25322688
25332689 def sample (self , key : jax .Array , sample_shape : tuple [int , ...] = ()) -> ArrayLike :
2690+ r"""Generates samples via the reparameterization trick:
2691+ :math:`X = \mu + \sigma \epsilon`, where :math:`\epsilon \sim \mathcal{N}(0,1)`.
2692+
2693+ :param key: JAX PRNGKey for reproducibility.
2694+ :type key: jax.Array
2695+ :param sample_shape: The shape of the samples to be generated.
2696+ :type sample_shape: tuple[int, ...]
2697+ :return: Samples from the Normal distribution of shape ``sample_shape + batch_shape``.
2698+ :rtype: ArrayLike
2699+ """
25342700 assert is_prng_key (key )
25352701 eps = random .normal (
25362702 key , shape = sample_shape + self .batch_shape + self .event_shape
@@ -2539,29 +2705,72 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
25392705
25402706 @validate_sample
25412707 def log_prob (self , value : ArrayLike ) -> ArrayLike :
2708+ r"""Calculates the log of the probability density function.
2709+
2710+ .. math::
2711+ \log f(x; \mu, \sigma) = -\frac{(x - \mu)^2}{2\sigma^2}
2712+ - \log(\sigma \sqrt{2\pi})
2713+
2714+ :param value: Values at which to evaluate the log density.
2715+ :type value: ArrayLike
2716+ :return: Log probability density.
2717+ :rtype: ArrayLike
2718+ """
25422719 normalize_term = jnp .log (jnp .sqrt (2 * jnp .pi ) * self .scale )
25432720 value_scaled = (value - self .loc ) / self .scale
25442721 return - 0.5 * value_scaled ** 2 - normalize_term
25452722
25462723 def cdf (self , value : ArrayLike ) -> ArrayLike :
2724+ r"""Cumulative distribution function.
2725+
2726+ .. math::
2727+ F(x; \mu, \sigma) = \frac{1}{2}\left[1 + \mathrm{erf}\!\left(
2728+ \frac{x - \mu}{\sigma\sqrt{2}}\right)\right]
2729+
2730+ :param value: Value to evaluate.
2731+ :type value: ArrayLike
2732+ """
25472733 scaled = (value - self .loc ) / self .scale
25482734 return ndtr (scaled )
25492735
25502736 def log_cdf (self , value : ArrayLike ) -> ArrayLike :
2737+ r"""Log of the cumulative distribution function.
2738+
2739+ :param value: Value to evaluate.
2740+ :type value: ArrayLike
2741+ """
25512742 return jax_norm .logcdf (value , loc = self .loc , scale = self .scale )
25522743
25532744 def icdf (self , q : ArrayLike ) -> ArrayLike :
2745+ r"""Inverse cumulative distribution function (Quantile function).
2746+
2747+ :param q: Probability value in :math:`[0,1]`.
2748+ :type q: ArrayLike
2749+ """
25542750 return self .loc + self .scale * ndtri (q )
25552751
25562752 @property
25572753 def mean (self ) -> ArrayLike :
2754+ r"""Calculates the analytical mean.
2755+
2756+ .. math:: E[X] = \mu
2757+ """
25582758 return jnp .broadcast_to (self .loc , self .batch_shape )
25592759
25602760 @property
25612761 def variance (self ) -> ArrayLike :
2762+ r"""Calculates the analytical variance.
2763+
2764+ .. math:: \mathrm{Var}(X) = \sigma^2
2765+ """
25622766 return jnp .broadcast_to (self .scale ** 2 , self .batch_shape )
25632767
25642768 def entropy (self ) -> ArrayLike :
2769+ r"""Entropy of the Normal distribution.
2770+
2771+ .. math::
2772+ H(X) = \frac{1}{2} \ln(2\pi e \sigma^2)
2773+ """
25652774 return jnp .broadcast_to (
25662775 (jnp .log (2 * np .pi * self .scale ** 2 ) + 1 ) / 2 , self .batch_shape
25672776 )
0 commit comments