# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Shared ELBO optimization step for Gaussian VI variants (MFVI, FRVI)."""
-from typing import Callable
+"""Shared Gaussian VI optimization step for:
+* mean field variational inference (MFVI)
+* full rank variational inference (FRVI)"""
+from dataclasses import dataclass
+from typing import Callable, Union

import jax
+import jax.numpy as jnp
+import jax.scipy as jsp
from optax import GradientTransformation, OptState


+@dataclass(frozen=True)
+class KL:
+    """Standard reverse-KL objective."""
+
+    pass
+
+
+@dataclass(frozen=True)
+class RenyiAlpha:
+    """Rényi alpha objective.
+
+    Notes
+    -----
+    A smooth interpolation from the evidence lower bound to the
+    log (marginal) likelihood, controlled by the value of alpha
+    that parametrises the divergence.
+    """
+
+    alpha: float
+
+
+Objective = Union[KL, RenyiAlpha]
+
+
+def _objective_value_from_log_ratio(
+    log_ratio: jax.Array,
+    objective: Objective,
+) -> jax.Array:
+    """Return a scalar loss to minimize from the given log-ratio array,
+    supporting two objective types:
+
+    * KL: returns the mean of the log-ratio, corresponding to the KL divergence loss.
+    * RenyiAlpha: returns the negative Monte Carlo Rényi variational bound.
+      For alpha = 1.0 it recovers the reverse-KL objective.
+      For other alpha values, it computes
+      (logsumexp((alpha - 1) * log_ratio) - log(N)) / (alpha - 1),
+      where N is the number of samples.
+
+    Parameters
+    ----------
+    log_ratio: A JAX array of log-ratio values (log q - log p).
+    objective: An instance of Objective (KL or RenyiAlpha).
+
+    Returns
+    -------
+    A scalar JAX array representing the loss value to be minimized.
+
+    """
+    if isinstance(objective, KL):
+        return jnp.mean(log_ratio)
+
+    if isinstance(objective, RenyiAlpha):
+        alpha = objective.alpha
+
+        # For alpha = 1.0 it recovers the reverse-KL objective.
+        if alpha == 1.0:
+            return jnp.mean(log_ratio)
+
+        # Negative Monte Carlo Rényi variational bound:
+        # -L_hat_alpha = (1 / (alpha - 1)) * log mean(exp((alpha - 1) * (logq - logp)))
+        scaled = (alpha - 1.0) * log_ratio
+        return (jsp.special.logsumexp(scaled) - jnp.log(log_ratio.shape[0])) / (
+            alpha - 1.0
+        )
+
+    raise TypeError(f"Unsupported objective type: {type(objective)!r}")
+
+
def _elbo_step(
    rng_key,
    parameters: tuple,
@@ -27,13 +100,15 @@ def _elbo_step(
    sample_fn: Callable,
    logq_fn: Callable,
    num_samples: int,
-    stl_estimator: bool,
+    objective: Objective = KL(),
+    stl_estimator: bool = True,
) -> tuple[tuple, OptState, float]:
32- """Single ELBO optimization step shared by Gaussian VI variants .
106+ """Single Gaussian VI optimization step shared by MFVI and FRVI .
33107
34- Computes the KL divergence ``E_q[log q - log p]`` via Monte Carlo,
35- differentiates with respect to ``parameters``, and applies one optimizer
36- update.
108+ Single step of variational optimisation (ELBO or Renyi bound)
109+ shared by Gaussian VI variants. Computes a variational loss
110+ (KL or Renyi) via Monte Carlo, differentiates with respect to
111+ ``parameters``, and applies one optimizer update.

    Parameters
    ----------
@@ -55,6 +130,8 @@ def _elbo_step(
        function of the current approximation given its parameters.
    num_samples
        Number of Monte Carlo samples used to estimate the ELBO.
+    objective
+        The variational objective (KL or Rényi). Defaults to KL.
    stl_estimator
        If ``True``, apply ``stop_gradient`` to the parameters used in
        ``logq_fn`` (stick-the-landing estimator). Gradients still flow
@@ -66,21 +143,29 @@ def _elbo_step(
        Updated variational parameters after one optimizer step.
    new_opt_state
        Updated optimizer state.
-    elbo
-        Current ELBO estimate (scalar).
+    loss
+        Current estimate of the variational loss (scalar).

    """

-    def kl_divergence_fn(parameters):
+    if stl_estimator and isinstance(objective, RenyiAlpha) and objective.alpha != 1.0:
+        raise ValueError(
+            "stl_estimator is currently only supported with KL() or "
+            "RenyiAlpha(alpha=1.0). Use stl_estimator=False for "
+            "RenyiAlpha(alpha != 1.0)."
+        )
+
+    def objective_fn(parameters):
        z = sample_fn(rng_key, parameters, num_samples)
        logq_parameters = (
            jax.lax.stop_gradient(parameters) if stl_estimator else parameters
        )
        logq = jax.vmap(logq_fn(logq_parameters))(z)
        logp = jax.vmap(logdensity_fn)(z)
-        return (logq - logp).mean()
+        log_ratio = logq - logp
+        return _objective_value_from_log_ratio(log_ratio, objective)

-    elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
-    updates, new_opt_state = optimizer.update(elbo_grad, opt_state, parameters)
+    objective_value, objective_grad = jax.value_and_grad(objective_fn)(parameters)
+    updates, new_opt_state = optimizer.update(objective_grad, opt_state, parameters)
    new_parameters = jax.tree.map(lambda p, u: p + u, parameters, updates)
-    return new_parameters, new_opt_state, elbo
+    return new_parameters, new_opt_state, objective_value
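
For readers unfamiliar with the Rényi bound used in the diff above, here is a minimal sketch (not part of the diff) that mirrors the arithmetic of _objective_value_from_log_ratio on a toy log-ratio array; the array values and variable names are made up for illustration.

import jax.numpy as jnp
import jax.scipy as jsp

log_ratio = jnp.array([-1.2, -0.3, 0.4, -0.8])  # log q(z) - log p(z) for 4 samples

# KL() objective: the loss is the plain Monte Carlo mean of the log-ratio.
kl_loss = jnp.mean(log_ratio)

# RenyiAlpha(alpha=0.5): negative Monte Carlo Rényi bound,
# (logsumexp((alpha - 1) * log_ratio) - log(N)) / (alpha - 1),
# which approaches the KL loss as alpha -> 1.
alpha = 0.5
scaled = (alpha - 1.0) * log_ratio
renyi_loss = (jsp.special.logsumexp(scaled) - jnp.log(log_ratio.shape[0])) / (alpha - 1.0)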