Skip to content

Commit b7d457d

Browse files
DistraxDevDistraxDev
authored andcommitted
Handle local measures in TransformedDistribution.
This change continues to set up the framework for tracking base measures and computing corrections on transformed densities. In `TransformedDistribution` we update `log_prob` to call a version of `experimental_local_measure` that keeps track of the base measure. We introduce a backwards-compatibility argument to control this rollout. PiperOrigin-RevId: 385616650
1 parent a1c5d43 commit b7d457d

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

distrax/_src/bijectors/tfp_compatible_bijector.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
from distrax._src.utils import math
2222
import jax
2323
import jax.numpy as jnp
24+
from tensorflow_probability.python.experimental import tangent_spaces
2425
from tensorflow_probability.substrates import jax as tfp
2526

2627
tfb = tfp.bijectors
2728
tfd = tfp.distributions
2829

2930
Array = chex.Array
3031
Bijector = bijector.Bijector
32+
TangentSpace = tangent_spaces.TangentSpace
3133

3234

3335
def tfp_compatible_bijector(
@@ -175,4 +177,31 @@ def _check_shape(
175177
f"{event_shape} which has only {len(event_shape)} "
176178
f"dimensions instead.")
177179

180+
def experimental_compute_density_correction(
181+
self,
182+
x: Array,
183+
tangent_space: TangentSpace,
184+
backward_compat: bool = True,
185+
**kwargs):
186+
"""Density correction for this transform wrt the tangent space, at x.
187+
188+
See `tfb.bijector.Bijector.experimental_compute_density_correction`, and
189+
Radul and Alexeev, AISTATS 2021, “The Base Measure Problem and its
190+
Solution”, https://arxiv.org/abs/2010.09647.
191+
192+
Args:
193+
x: `float` or `double` `Array`.
194+
tangent_space: `TangentSpace` or one of its subclasses. The tangent to
195+
the support manifold at `x`.
196+
backward_compat: unused
197+
**kwargs: Optional keyword arguments forwarded to tangent space methods.
198+
199+
Returns:
200+
density_correction: `Array` representing the density correction---in log
201+
space---under the transformation that this Bijector denotes. Assumes
202+
the Bijector is dimension-preserving.
203+
"""
204+
del backward_compat
205+
return tangent_space.transform_dimension_preserving(x, self, **kwargs)
206+
178207
return TFPCompatibleBijector()

distrax/_src/distributions/tfp_compatible_distribution.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
# ==============================================================================
1515
"""Wrapper to adapt a Distrax distribution for use in TFP."""
1616

17-
from typing import Dict, Optional, Sequence, Union
17+
from typing import Dict, Optional, Sequence, Tuple, Union
1818

1919
import chex
2020
from distrax._src.distributions import distribution
2121
import jax.numpy as jnp
2222
import numpy as np
23+
from tensorflow_probability.python.experimental import tangent_spaces
2324
from tensorflow_probability.substrates import jax as tfp
2425

2526
tfd = tfp.distributions
@@ -29,6 +30,7 @@
2930
Distribution = distribution.Distribution
3031
IntLike = distribution.IntLike
3132
PRNGKey = chex.PRNGKey
33+
TangentSpace = tangent_spaces.TangentSpace
3234

3335

3436
def tfp_compatible_distribution(
@@ -136,4 +138,29 @@ def sample(self,
136138
sample_shape = tuple(sample_shape)
137139
return base_distribution.sample(sample_shape=sample_shape, seed=seed)
138140

141+
def experimental_local_measure(
142+
self,
143+
value: Array,
144+
backward_compat: bool = True,
145+
**unused_kwargs) -> Tuple[Array, TangentSpace]:
146+
"""Returns a log probability density together with a `TangentSpace`.
147+
148+
See `tfd.distribution.Distribution.experimental_local_measure`, and
149+
Radul and Alexeev, AISTATS 2021, “The Base Measure Problem and its
150+
Solution”, https://arxiv.org/abs/2010.09647.
151+
152+
Args:
153+
value: `float` or `double` `Array`.
154+
backward_compat: unused
155+
**unused_kwargs: unused
156+
157+
Returns:
158+
log_prob: see `log_prob`.
159+
tangent_space: `tangent_spaces.FullSpace()`, representing R^n with the
160+
standard basis.
161+
"""
162+
del backward_compat
163+
del unused_kwargs
164+
return self.log_prob(value), tangent_spaces.FullSpace()
165+
139166
return TFPCompatibleDistribution()

0 commit comments

Comments
 (0)