import jax.random as jr
from chex import dataclass
from jaxtyping import f64
-from tensorflow_probability.substrates.jax import distributions as tfd

from .config import get_defaults
from .types import PRNGKeyType
from .utils import merge_dictionaries

-Identity = dx.Lambda(lambda x: x)
+Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)


################################
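For context, `dx` is Distrax, and `dx.Lambda` builds a bijector from plain callables. Supplying only `forward` leaves Distrax to work out the inverse itself, so the commit passes both `forward` and `inverse` explicitly. A minimal sketch of the updated bijector, assuming Distrax is installed:

```python
# A minimal sketch of the updated Identity bijector; the alias dx
# matches this module's Distrax import.
import distrax as dx
import jax.numpy as jnp

Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)

x = jnp.array([0.5, -1.2])
print(Identity.forward(x))  # [ 0.5 -1.2]
print(Identity.inverse(x))  # [ 0.5 -1.2]
```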
@@ -163,11 +162,13 @@ def inverse(bijector):

    bijectors = build_bijectors(params)

-    constrainers = jax.tree_map(lambda _: forward, deepcopy(params))
-    unconstrainers = jax.tree_map(lambda _: inverse, deepcopy(params))
+    constrainers = jax.tree_util.tree_map(lambda _: forward, deepcopy(params))
+    unconstrainers = jax.tree_util.tree_map(lambda _: inverse, deepcopy(params))

-    constrainers = jax.tree_map(lambda f, b: f(b), constrainers, bijectors)
-    unconstrainers = jax.tree_map(lambda f, b: f(b), unconstrainers, bijectors)
+    constrainers = jax.tree_util.tree_map(lambda f, b: f(b), constrainers, bijectors)
+    unconstrainers = jax.tree_util.tree_map(
+        lambda f, b: f(b), unconstrainers, bijectors
+    )

    return constrainers, unconstrainers

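The recurring change in this commit replaces `jax.tree_map` with `jax.tree_util.tree_map`; the former is a deprecated alias of the latter. Both apply a function to every leaf of a pytree, and the multi-tree form pairs leaves across trees of identical structure, which is how the constrainers above are matched with their bijectors. A small illustration with a hypothetical parameter dictionary:

```python
import jax

params = {"kernel": {"lengthscale": 1.0, "variance": 2.0}, "noise": 0.1}

# Single-tree map: replace every leaf.
labels = jax.tree_util.tree_map(lambda _: "positive", params)

# Multi-tree map: leaves are zipped across structurally identical trees.
pairs = jax.tree_util.tree_map(lambda p, l: (p, l), params, labels)
```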
@@ -182,7 +183,9 @@ def transform(params: tp.Dict, transform_map: tp.Dict) -> tp.Dict:
    Returns:
        tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary.
    """
-    return jax.tree_map(lambda param, trans: trans(param), params, transform_map)
+    return jax.tree_util.tree_map(
+        lambda param, trans: trans(param), params, transform_map
+    )


################################
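A hedged usage sketch of `transform`: the transform map mirrors the parameter dictionary, with a callable at each leaf. The values below are illustrative, with `jnp.exp` standing in for a positivity constrainer, and assume the `transform` function above is in scope:

```python
import jax.numpy as jnp

params = {"lengthscale": jnp.array(-1.0)}
constrain_map = {"lengthscale": jnp.exp}  # illustrative constrainer

constrained = transform(params, constrain_map)
# {'lengthscale': ~0.368}, same structure as params
```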
@@ -200,7 +203,7 @@ def copy_dict_structure(params: dict) -> dict:
    # Copy dictionary structure
    prior_container = deepcopy(params)
    # Set all leaf values to None
-    prior_container = jax.tree_map(lambda _: None, prior_container)
+    prior_container = jax.tree_util.tree_map(lambda _: None, prior_container)
    return prior_container


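`copy_dict_structure` keeps the nesting of `params` but blanks every leaf, producing an empty container with matching structure. A sketch of the expected behaviour, assuming the function above is in scope:

```python
params = {"kernel": {"lengthscale": 1.0}, "noise": 0.1}
container = copy_dict_structure(params)
# {'kernel': {'lengthscale': None}, 'noise': None}
```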
@@ -243,23 +246,16 @@ def prior_checks(priors: dict) -> dict:
    """Run checks on the parameters' prior distributions. For Gaussian processes constructed with non-conjugate likelihoods, this checks that the prior distribution on the function's latent values is a unit Gaussian."""
    if "latent" in priors.keys():
        latent_prior = priors["latent"]
-        if isinstance(latent_prior, dx.Distribution) and latent_prior.name != "Normal":
-            warnings.warn(
-                f"A {latent_prior.name} distribution prior has been placed on"
-                " the latent function. It is strongly advised that a"
-                " unit-Gaussian prior is used."
-            )
-        elif (
-            isinstance(latent_prior, tfd.Distribution) and latent_prior.name != "Normal"
-        ):
-            warnings.warn(
-                f"A {latent_prior.name} distribution from Tensorflow Probability has been"
-                "placed on the latent function. We advise using a unit-Gaussian prior from"
-                " Distrax."
-            )
+        if latent_prior is not None:
+            if latent_prior.name != "Normal":
+                warnings.warn(
+                    f"A {latent_prior.name} distribution prior has been placed on"
+                    " the latent function. It is strongly advised that a"
+                    " unit Gaussian prior is used."
+                )
        else:
-            if not latent_prior:
-                priors["latent"] = dx.Normal(loc=0.0, scale=1.0)
+            warnings.warn("Placing unit Gaussian prior on latent function.")
+            priors["latent"] = dx.Normal(loc=0.0, scale=1.0)
    else:
        priors["latent"] = dx.Normal(loc=0.0, scale=1.0)

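The rewrite drops the TensorFlow Probability branch (matching the removed `tfd` import at the top of the diff) and leaves three cases: a non-Normal Distrax prior on the latent function warns, a missing or `None` latent prior falls back to a unit Gaussian, and a `Normal` prior passes silently. A hedged usage sketch, assuming `prior_checks` is in scope:

```python
import distrax as dx

prior_checks({"latent": dx.Laplace(loc=0.0, scale=1.0)})
# warns: a Laplace prior has been placed on the latent function

prior_checks({"latent": None})
# warns, then defaults the latent prior to dx.Normal(loc=0.0, scale=1.0)

prior_checks({})
# silently defaults the latent prior to dx.Normal(loc=0.0, scale=1.0)
```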
@@ -278,7 +274,7 @@ def build_trainables(params: tp.Dict) -> tp.Dict:
    # Copy dictionary structure
    prior_container = deepcopy(params)
    # Set all leaf values to True
-    prior_container = jax.tree_map(lambda _: True, prior_container)
+    prior_container = jax.tree_util.tree_map(lambda _: True, prior_container)
    return prior_container


@@ -289,6 +285,6 @@ def stop_grad(param: tp.Dict, trainable: tp.Dict):

def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict:
    """Stop gradients from flowing through parameters whose trainable status is False."""
-    return jax.tree_map(
+    return jax.tree_util.tree_map(
        lambda param, trainable: stop_grad(param, trainable), params, trainables
    )
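`stop_grad` itself is not shown in this diff; the sketch below assumes an implementation that gates each leaf through `jax.lax.stop_gradient` (an inference from the docstring, not necessarily the module's exact code), then checks that frozen leaves receive zero gradient:

```python
import jax
import jax.numpy as jnp

# Assumed implementation of stop_grad, inferred from the docstring.
def stop_grad(param, trainable):
    return jax.lax.cond(trainable, lambda p: p, jax.lax.stop_gradient, param)

params = {"a": jnp.array(1.0), "b": jnp.array(2.0)}
trainables = {"a": True, "b": False}

def loss(ps):
    ps = jax.tree_util.tree_map(stop_grad, ps, trainables)
    return ps["a"] ** 2 + ps["b"] ** 2

print(jax.grad(loss)(params))  # {'a': 2.0, 'b': 0.0}
```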