
Commit 154f0d0

Fix typo
1 parent 899f6f7 commit 154f0d0

1 file changed

gpjax/parameters.py

Lines changed: 11 additions & 10 deletions
@@ -15,7 +15,7 @@
 from .types import PRNGKeyType
 from .utils import merge_dictionaries
 
-Identity = dx.Lambda(lambda x: x)
+Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
 
 
 ################################
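This hunk gives the Identity bijector an explicit inverse rather than only a forward map. A minimal sketch, not part of the commit, of how such a bijector is applied in both directions, assuming distrax is available and imported as dx:

import distrax as dx
import jax.numpy as jnp

# Identity bijector with both the forward and inverse maps given explicitly.
Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)

x = jnp.array([0.5, -1.2])
y = Identity.forward(x)       # forward transform (a no-op here)
x_back = Identity.inverse(y)  # inverse transform (also a no-op)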
@@ -246,15 +246,16 @@ def prior_checks(priors: dict) -> dict:
     """Run checks on th parameters' prior distributions. This checks that for Gaussian processes that are constructed with non-conjugate likelihoods, the prior distribution on the function's latent values is a unit Gaussian."""
     if "latent" in priors.keys():
         latent_prior = priors["latent"]
-        if latent_prior.name != "Normal":
-            warnings.warn(
-                f"A {latent_prior.name} distribution prior has been placed on"
-                " the latent function. It is strongly advised that a"
-                " unit-Gaussian prior is used."
-            )
+        if latent_prior is not None:
+            if latent_prior.name != "Normal":
+                warnings.warn(
+                    f"A {latent_prior.name} distribution prior has been placed on"
+                    " the latent function. It is strongly advised that a"
+                    " unit Gaussian prior is used."
+                )
         else:
-            if not latent_prior:
-                priors["latent"] = dx.Normal(loc=0.0, scale=1.0)
+            warnings.warn("Placing unit Gaussian prior on latent function.")
+            priors["latent"] = dx.Normal(loc=0.0, scale=1.0)
     else:
         priors["latent"] = dx.Normal(loc=0.0, scale=1.0)
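With this hunk, a priors dictionary whose "latent" entry is None now falls into the else branch, which warns and substitutes a unit Gaussian, while a non-Gaussian latent prior still only triggers a warning and is kept. A rough usage sketch, assuming prior_checks is importable from gpjax.parameters and returns the (possibly updated) dictionary:

import distrax as dx
from gpjax.parameters import prior_checks  # assumed import path

# "latent" key present but unset: warns and fills in a unit Gaussian.
priors = prior_checks({"latent": None})

# Non-Gaussian latent prior: warns but keeps the supplied distribution.
priors = prior_checks({"latent": dx.Laplace(loc=0.0, scale=1.0)})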

@@ -284,6 +285,6 @@ def stop_grad(param: tp.Dict, trainable: tp.Dict):
 
 def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict:
     """Stop the gradients flowing through parameters whose trainable status is False"""
-    return jax.tree_map(
+    return jax.tree_util.tree_map(
         lambda param, trainable: stop_grad(param, trainable), params, trainables
     )
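The final hunk swaps jax.tree_map for jax.tree_util.tree_map; both names refer to the same leaf-wise map over matching pytrees. A small sketch of the call pattern, using a hypothetical stand-in for gpjax's stop_grad:

import jax
import jax.numpy as jnp

def stop_grad(param, trainable):
    # Hypothetical stand-in: block gradients for non-trainable parameters.
    return jax.lax.cond(trainable, lambda p: p, jax.lax.stop_gradient, param)

params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(2.0)}
trainables = {"lengthscale": True, "variance": False}

# Apply stop_grad leaf-by-leaf across the two dictionaries.
frozen = jax.tree_util.tree_map(
    lambda param, trainable: stop_grad(param, trainable), params, trainables
)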
