2020from distrax ._src .distributions import distribution as base_distribution
2121from distrax ._src .utils import conversion
2222from distrax ._src .utils import math
23+ import jax
2324import jax .numpy as jnp
2425from tensorflow_probability .substrates import jax as tfp
2526
@@ -48,7 +49,8 @@ class Quantized(
4849 def __init__ (self ,
4950 distribution : DistributionLike ,
5051 low : Optional [Numeric ] = None ,
51- high : Optional [Numeric ] = None ):
52+ high : Optional [Numeric ] = None ,
53+ eps : Optional [Numeric ] = None ):
5254 """Initializes a Quantized distribution.
5355
5456 Args:
@@ -61,9 +63,13 @@ def __init__(self,
6163 floor(high)`. Its shape must broadcast with the shape of samples from
6264 `distribution` and must not result in additional batch dimensions after
6365 broadcasting.
66+ eps: An optional gap to enforce between "big" and "small". Useful for
67+ avoiding NANs in computing log_probs, when "big" and "small"
68+ are too close.
6469 """
6570 self ._dist : base_distribution .Distribution [Array , Tuple [
6671 int , ...], jnp .dtype ] = conversion .as_distribution (distribution )
72+ self ._eps = eps
6773 if self ._dist .event_shape :
6874 raise ValueError (f'The base distribution must be univariate, but its '
6975 f'`event_shape` is { self ._dist .event_shape } .' )
@@ -180,6 +186,10 @@ def log_prob(self, value: EventT) -> Array:
180186 # which happens to the right of the median of the distribution.
181187 big = jnp .where (log_sf < log_cdf , log_sf_m1 , log_cdf )
182188 small = jnp .where (log_sf < log_cdf , log_sf , log_cdf_m1 )
189+ if self ._eps is not None :
190+ # use stop_gradient to block updating in this case
191+ big = jnp .where (big - small > self ._eps , big ,
192+ jax .lax .stop_gradient (small ) + self ._eps )
183193 log_probs = math .log_expbig_minus_expsmall (big , small )
184194
185195 # Return -inf when evaluating on non-integer value.
0 commit comments