Skip to content

Commit ffca7c8

Browse files
Add Renyi alpha divergence to variational inference and tests (#849)
* Add Renyi alpha divergence to variational inference and tests * Add coverage for Gaussian VI objective helper --------- Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
1 parent 08f5a42 commit ffca7c8

5 files changed

Lines changed: 305 additions & 24 deletions

File tree

blackjax/vi/_gaussian_vi.py

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,86 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Shared ELBO optimization step for Gaussian VI variants (MFVI, FRVI)."""
15-
from typing import Callable
14+
"""Shared Gaussian VI optimization step for:
15+
* mean field variational inference (MFVI)
16+
* full rank variational inference (FRVI)"""
17+
from dataclasses import dataclass
18+
from typing import Callable, Union
1619

1720
import jax
21+
import jax.numpy as jnp
22+
import jax.scipy as jsp
1823
from optax import GradientTransformation, OptState
1924

2025

26+
@dataclass(frozen=True)
class KL:
    """Standard reverse-KL (ELBO) objective."""

    pass


@dataclass(frozen=True)
class RenyiAlpha:
    """Rényi alpha objective.

    Notes
    -----
    A smooth interpolation from the evidence lower-bound to the
    log (marginal) likelihood that is controlled by the value of alpha
    that parametrises the divergence.
    """

    # Order of the Rényi divergence; alpha = 1.0 recovers reverse KL.
    alpha: float


Objective = Union[KL, RenyiAlpha]


def _objective_value_from_log_ratio(
    log_ratio: jax.Array,
    objective: Objective,
) -> jax.Array:
    """Return a scalar loss to minimize from the given log-ratio array.

    Two objective types are supported:

    * KL: returns the mean of the log-ratio, corresponding to the KL
      divergence loss.
    * RenyiAlpha: returns the negative Monte Carlo Rényi variational bound.
      For alpha = 1.0 it recovers the reverse-KL objective.
      For other alpha values, it computes::

          (logsumexp((alpha - 1) * log_ratio) - log(N)) / (alpha - 1)

      where N is the number of samples.

    Parameters
    ----------
    log_ratio
        A JAX array of log-ratio values (log q - log p).
    objective
        An instance of the objective (KL or RenyiAlpha).

    Returns
    -------
    A scalar JAX array representing the loss value to be minimized.

    Raises
    ------
    TypeError
        If ``objective`` is neither a ``KL`` nor a ``RenyiAlpha`` instance.
    """
    if isinstance(objective, KL):
        return jnp.mean(log_ratio)

    if isinstance(objective, RenyiAlpha):
        alpha = objective.alpha

        # For alpha = 1.0 it recovers the reverse-KL objective; the general
        # Renyi formula below would divide by zero at this value.
        if alpha == 1.0:
            return jnp.mean(log_ratio)

        # Negative Monte Carlo Renyi variational bound:
        # -L_hat_alpha = (1 / (alpha - 1)) * log mean(exp((alpha - 1) * (logq - logp)))
        scaled = (alpha - 1.0) * log_ratio
        return (jsp.special.logsumexp(scaled) - jnp.log(log_ratio.shape[0])) / (
            alpha - 1.0
        )

    raise TypeError(f"Unsupported objective type: {type(objective)!r}")
92+
93+
2194
def _elbo_step(
2295
rng_key,
2396
parameters: tuple,
@@ -27,13 +100,15 @@ def _elbo_step(
27100
sample_fn: Callable,
28101
logq_fn: Callable,
29102
num_samples: int,
30-
stl_estimator: bool,
103+
objective: Objective = KL(),
104+
stl_estimator: bool = True,
31105
) -> tuple[tuple, OptState, float]:
32-
"""Single ELBO optimization step shared by Gaussian VI variants.
106+
"""Single Gaussian VI optimization step shared by MFVI and FRVI.
33107
34-
Computes the KL divergence ``E_q[log q - log p]`` via Monte Carlo,
35-
differentiates with respect to ``parameters``, and applies one optimizer
36-
update.
108+
Single step of variational optimisation (ELBO or Renyi bound)
109+
shared by Gaussian VI variants. Computes a variational loss
110+
(KL or Renyi) via Monte Carlo, differentiates with respect to
111+
``parameters``, and applies one optimizer update.
37112
38113
Parameters
39114
----------
@@ -55,6 +130,8 @@ def _elbo_step(
55130
function of the current approximation given its parameters.
56131
num_samples
57132
Number of Monte Carlo samples used to estimate the ELBO.
133+
objective
134+
The variational objective (KL or Rényi). Defaults to KL.
58135
stl_estimator
59136
If ``True``, apply ``stop_gradient`` to the parameters used in
60137
``logq_fn`` (stick-the-landing estimator). Gradients still flow
@@ -66,21 +143,29 @@ def _elbo_step(
66143
Updated variational parameters after one optimizer step.
67144
new_opt_state
68145
Updated optimizer state.
69-
elbo
70-
Current ELBO estimate (scalar).
146+
loss
147+
Current estimate of the variational loss (scalar).
71148
72149
"""
73150

74-
def kl_divergence_fn(parameters):
151+
if stl_estimator and isinstance(objective, RenyiAlpha) and objective.alpha != 1.0:
152+
raise ValueError(
153+
"stl_estimator is currently only supported with KL() or "
154+
"RenyiAlpha(alpha=1.0). Use stl_estimator=False for "
155+
"RenyiAlpha(alpha != 1.0)."
156+
)
157+
158+
def objective_fn(parameters):
75159
z = sample_fn(rng_key, parameters, num_samples)
76160
logq_parameters = (
77161
jax.lax.stop_gradient(parameters) if stl_estimator else parameters
78162
)
79163
logq = jax.vmap(logq_fn(logq_parameters))(z)
80164
logp = jax.vmap(logdensity_fn)(z)
81-
return (logq - logp).mean()
165+
log_ratio = logq - logp
166+
return _objective_value_from_log_ratio(log_ratio, objective)
82167

83-
elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
84-
updates, new_opt_state = optimizer.update(elbo_grad, opt_state, parameters)
168+
objective_value, objective_grad = jax.value_and_grad(objective_fn)(parameters)
169+
updates, new_opt_state = optimizer.update(objective_grad, opt_state, parameters)
85170
new_parameters = jax.tree.map(lambda p, u: p + u, parameters, updates)
86-
return new_parameters, new_opt_state, elbo
171+
return new_parameters, new_opt_state, objective_value

blackjax/vi/fullrank_vi.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121

2222
from blackjax.base import VIAlgorithm
2323
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
24-
from blackjax.vi._gaussian_vi import _elbo_step
24+
from blackjax.vi._gaussian_vi import KL, Objective, RenyiAlpha, _elbo_step
2525

2626
__all__ = [
27+
"KL",
28+
"RenyiAlpha",
2729
"FRVIState",
2830
"FRVIInfo",
2931
"sample",
@@ -88,6 +90,7 @@ def step(
8890
logdensity_fn: Callable,
8991
optimizer: GradientTransformation,
9092
num_samples: int = 5,
93+
objective: Objective = KL(),
9194
stl_estimator: bool = True,
9295
) -> tuple[FRVIState, FRVIInfo]:
9396
"""Approximate the target density using the full-rank Gaussian approximation.
@@ -106,6 +109,9 @@ def step(
106109
The number of samples that are taken from the approximation
107110
at each step to compute the Kullback-Leibler divergence between
108111
the approximation and the target log-density.
112+
objective
113+
The variational objective to minimize. `KL()` by default or
114+
`RenyiAlpha(alpha)`. For alpha = 1, Renyi reduces to KL.
109115
stl_estimator
110116
Whether to use the stick-the-landing (STL) gradient estimator
111117
:cite:p:`roeder2017sticking`. Reduces gradient variance by removing
@@ -137,7 +143,8 @@ def logq_fn(parameters):
137143
sample_fn,
138144
logq_fn,
139145
num_samples,
140-
stl_estimator,
146+
objective=objective,
147+
stl_estimator=stl_estimator,
141148
)
142149
new_state = FRVIState(new_parameters[0], new_parameters[1], new_opt_state)
143150
return new_state, FRVIInfo(elbo)
@@ -168,6 +175,8 @@ def as_top_level_api(
168175
logdensity_fn: Callable,
169176
optimizer: GradientTransformation,
170177
num_samples: int = 100,
178+
objective: Objective = KL(),
179+
stl_estimator: bool = True,
171180
):
172181
"""High-level implementation of Full-Rank Variational Inference.
173182
@@ -180,6 +189,12 @@ def as_top_level_api(
180189
Optax optimizer to use to optimize the ELBO.
181190
num_samples
182191
Number of samples to take at each step to optimize the ELBO.
192+
objective
193+
The variational objective to minimize. `KL()` by default or
194+
`RenyiAlpha(alpha)`. For alpha = 1, Renyi reduces to KL.
195+
stl_estimator
196+
Whether to use STL gradient estimator.
197+
Only supported when `objective` is `KL()` or `RenyiAlpha(alpha=1.0)`.
183198
184199
Returns
185200
-------
@@ -191,7 +206,15 @@ def init_fn(position: ArrayLikeTree):
191206
return init(position, optimizer)
192207

193208
def step_fn(rng_key: PRNGKey, state: FRVIState) -> tuple[FRVIState, FRVIInfo]:
194-
return step(rng_key, state, logdensity_fn, optimizer, num_samples)
209+
return step(
210+
rng_key,
211+
state,
212+
logdensity_fn,
213+
optimizer,
214+
num_samples,
215+
objective=objective,
216+
stl_estimator=stl_estimator,
217+
)
195218

196219
def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int):
197220
return sample(rng_key, state, num_samples)

blackjax/vi/meanfield_vi.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020

2121
from blackjax.base import VIAlgorithm
2222
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
23-
from blackjax.vi._gaussian_vi import _elbo_step
23+
from blackjax.vi._gaussian_vi import KL, Objective, RenyiAlpha, _elbo_step
2424

2525
__all__ = [
26+
"KL",
27+
"RenyiAlpha",
2628
"MFVIState",
2729
"MFVIInfo",
2830
"sample",
@@ -74,6 +76,7 @@ def step(
7476
logdensity_fn: Callable,
7577
optimizer: GradientTransformation,
7678
num_samples: int = 5,
79+
objective: Objective = KL(),
7780
stl_estimator: bool = True,
7881
) -> tuple[MFVIState, MFVIInfo]:
7982
"""Approximate the target density using the mean-field approximation.
@@ -92,6 +95,9 @@ def step(
9295
The number of samples that are taken from the approximation
9396
at each step to compute the Kullback-Leibler divergence between
9497
the approximation and the target log-density.
98+
objective
99+
The variational objective to minimize. `KL()` by default or
100+
`RenyiAlpha(alpha)`. For alpha = 1, Renyi reduces to KL.
95101
stl_estimator
96102
Whether to use the stick-the-landing (STL) gradient estimator
97103
:cite:p:`roeder2017sticking`. The STL estimator has lower gradient
@@ -120,7 +126,8 @@ def logq_fn(parameters):
120126
sample_fn,
121127
logq_fn,
122128
num_samples,
123-
stl_estimator,
129+
objective=objective,
130+
stl_estimator=stl_estimator,
124131
)
125132
new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state)
126133
return new_state, MFVIInfo(elbo)
@@ -140,7 +147,7 @@ def sample(rng_key: PRNGKey, state: MFVIState, num_samples: int = 1):
140147
141148
Returns
142149
-------
143-
A PyTree of samples with leading dimension ``num_samples``.
150+
A PyTree of samples with leading dimension ``num_samples``
144151
"""
145152
return _sample(rng_key, state.mu, state.rho, num_samples)
146153

@@ -149,18 +156,26 @@ def as_top_level_api(
149156
logdensity_fn: Callable,
150157
optimizer: GradientTransformation,
151158
num_samples: int = 100,
159+
objective: Objective = KL(),
160+
stl_estimator: bool = True,
152161
):
153-
"""High-level implementation of Mean-Field Variational Inference.
162+
"""High-level implementation of Mean-Field Variational Inference
154163
155-
Parameters
164+
Parameters
156165
----------
157166
logdensity_fn
158167
A function that represents the log-density function associated with
159168
the distribution we want to sample from.
160169
optimizer
161-
Optax optimizer to use to optimize the ELBO.
170+
Optax optimizer to use to optimize the variational objective.
162171
num_samples
163172
Number of samples to take at each step to optimize the ELBO.
173+
objective
174+
The variational objective to minimize. `KL()` by default or
175+
`RenyiAlpha(alpha)`. For alpha = 1, Renyi reduces to KL.
176+
stl_estimator
177+
Whether to use the STL gradient estimator.
178+
Only supported when `objective` is `KL()` or `RenyiAlpha(alpha=1.0)`.
164179
165180
Returns
166181
-------
@@ -172,7 +187,15 @@ def init_fn(position: ArrayLikeTree):
172187
return init(position, optimizer)
173188

174189
def step_fn(rng_key: PRNGKey, state: MFVIState) -> tuple[MFVIState, MFVIInfo]:
175-
return step(rng_key, state, logdensity_fn, optimizer, num_samples)
190+
return step(
191+
rng_key,
192+
state,
193+
logdensity_fn,
194+
optimizer,
195+
num_samples,
196+
objective=objective,
197+
stl_estimator=stl_estimator,
198+
)
176199

177200
def sample_fn(rng_key: PRNGKey, state: MFVIState, num_samples: int):
178201
return sample(rng_key, state, num_samples)

0 commit comments

Comments
 (0)