|
15 | 15 | from .types import PRNGKeyType |
16 | 16 | from .utils import merge_dictionaries |
17 | 17 |
|
18 | | -Identity = dx.Lambda(lambda x: x) |
| 18 | +Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) |
19 | 19 |
|
20 | 20 |
|
21 | 21 | ################################ |
@@ -246,15 +246,16 @@ def prior_checks(priors: dict) -> dict: |
246 | 246 | """Run checks on th parameters' prior distributions. This checks that for Gaussian processes that are constructed with non-conjugate likelihoods, the prior distribution on the function's latent values is a unit Gaussian.""" |
247 | 247 | if "latent" in priors.keys(): |
248 | 248 | latent_prior = priors["latent"] |
249 | | - if latent_prior.name != "Normal": |
250 | | - warnings.warn( |
251 | | - f"A {latent_prior.name} distribution prior has been placed on" |
252 | | - " the latent function. It is strongly advised that a" |
253 | | - " unit-Gaussian prior is used." |
254 | | - ) |
| 249 | + if latent_prior is not None: |
| 250 | + if latent_prior.name != "Normal": |
| 251 | + warnings.warn( |
| 252 | + f"A {latent_prior.name} distribution prior has been placed on" |
| 253 | + " the latent function. It is strongly advised that a" |
| 254 | + " unit Gaussian prior is used." |
| 255 | + ) |
255 | 256 | else: |
256 | | - if not latent_prior: |
257 | | - priors["latent"] = dx.Normal(loc=0.0, scale=1.0) |
| 257 | + warnings.warn("Placing unit Gaussian prior on latent function.") |
| 258 | + priors["latent"] = dx.Normal(loc=0.0, scale=1.0) |
258 | 259 | else: |
259 | 260 | priors["latent"] = dx.Normal(loc=0.0, scale=1.0) |
260 | 261 |
|
@@ -284,6 +285,6 @@ def stop_grad(param: tp.Dict, trainable: tp.Dict): |
284 | 285 |
|
285 | 286 | def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict: |
286 | 287 | """Stop the gradients flowing through parameters whose trainable status is False""" |
287 | | - return jax.tree_map( |
| 288 | + return jax.tree_util.tree_map( |
288 | 289 | lambda param, trainable: stop_grad(param, trainable), params, trainables |
289 | 290 | ) |
0 commit comments