
Commit 93c54a8

DistraxDev authored and committed
Add an optional epsilon to avoid NANs when big and small are too close in computing log_prob.
PiperOrigin-RevId: 558787502
1 parent 09c0ce1 commit 93c54a8


distrax/_src/distributions/quantized.py

Lines changed: 11 additions & 1 deletion
@@ -20,6 +20,7 @@
 from distrax._src.distributions import distribution as base_distribution
 from distrax._src.utils import conversion
 from distrax._src.utils import math
+import jax
 import jax.numpy as jnp
 from tensorflow_probability.substrates import jax as tfp
 
@@ -48,7 +49,8 @@ class Quantized(
   def __init__(self,
                distribution: DistributionLike,
                low: Optional[Numeric] = None,
-               high: Optional[Numeric] = None):
+               high: Optional[Numeric] = None,
+               eps: Optional[Numeric] = None):
     """Initializes a Quantized distribution.
 
     Args:
@@ -61,9 +63,13 @@ def __init__(self,
         floor(high)`. Its shape must broadcast with the shape of samples from
         `distribution` and must not result in additional batch dimensions after
         broadcasting.
+      eps: An optional gap to enforce between "big" and "small". Useful for
+        avoiding NANs in computing log_probs, when "big" and "small"
+        are too close.
     """
     self._dist: base_distribution.Distribution[Array, Tuple[
         int, ...], jnp.dtype] = conversion.as_distribution(distribution)
+    self._eps = eps
     if self._dist.event_shape:
       raise ValueError(f'The base distribution must be univariate, but its '
                        f'`event_shape` is {self._dist.event_shape}.')
@@ -180,6 +186,10 @@ def log_prob(self, value: EventT) -> Array:
     # which happens to the right of the median of the distribution.
     big = jnp.where(log_sf < log_cdf, log_sf_m1, log_cdf)
     small = jnp.where(log_sf < log_cdf, log_sf, log_cdf_m1)
+    if self._eps is not None:
+      # use stop_gradient to block updating in this case
+      big = jnp.where(big - small > self._eps, big,
+                      jax.lax.stop_gradient(small) + self._eps)
     log_probs = math.log_expbig_minus_expsmall(big, small)
 
     # Return -inf when evaluating on non-integer value.
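For context on why the guard helps: when `big` and `small` agree to floating-point precision, `log(exp(big) - exp(small))` evaluates to `log(0) = -inf` and its gradient is non-finite; the commit lifts `big` to `small + eps` in that case, with `stop_gradient` blocking the gradient through `small`. The sketch below reproduces the mechanism in isolation. It is illustrative only: `EPS`, `unguarded`, and `guarded` are hypothetical names, and the helper mirrors what `distrax._src.utils.math.log_expbig_minus_expsmall` computes, assuming it is implemented as `big + log1p(-exp(small - big))`.

import jax
import jax.numpy as jnp

EPS = 1e-6  # hypothetical gap; the commit leaves the choice of eps to the caller


def log_expbig_minus_expsmall(big, small):
  # Stable form of log(exp(big) - exp(small)), assuming big >= small.
  return big + jnp.log1p(-jnp.exp(small - big))


def unguarded(big, small):
  # No guard: identical big/small gives log(0) = -inf with a non-finite gradient.
  return log_expbig_minus_expsmall(big, small)


def guarded(big, small):
  # The commit's guard: if the gap is below EPS, lift big to small + EPS,
  # blocking the gradient through small with stop_gradient.
  big = jnp.where(big - small > EPS, big, jax.lax.stop_gradient(small) + EPS)
  return log_expbig_minus_expsmall(big, small)


# A bin with numerically zero probability mass: big == small.
big = jnp.float32(-20.0)
small = jnp.float32(-20.0)
print(unguarded(big, small), jax.grad(unguarded)(big, small))  # -inf, non-finite gradient
print(guarded(big, small), jax.grad(guarded)(big, small))      # finite value and gradient

With the new argument, a caller can opt in directly on the distribution, e.g. `distrax.Quantized(distrax.Normal(0., 1.), low=0., high=10., eps=1e-6)`; leaving `eps=None` keeps the previous behaviour.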
