
Commit 4fca81e

Merge pull request #103 from thomaspinder/fix_deps
Relax Jax version
2 parents: 2f00b11 + 2e5be70

4 files changed: +54 −36 lines

gpjax/config.py

Lines changed: 25 additions & 2 deletions

```diff
@@ -5,8 +5,31 @@

 __config = None

-Identity = dx.Lambda(lambda x: x)
-Softplus = dx.Lambda(lambda x: jnp.log(1.0 + jnp.exp(x)))
+Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
+Softplus = dx.Lambda(
+    forward=lambda x: jnp.log(1 + jnp.exp(x)),
+    inverse=lambda x: jnp.log(jnp.exp(x) - 1.0),
+)
+
+# class Softplus(dx.Bijector):
+#     def __init__(self):
+#         super().__init__(event_ndims_in=0)
+
+#     def forward_and_log_det(self, x):
+#         softplus = lambda xx: jnp.log(1 + jnp.exp(xx))
+#         y = softplus(x)
+#         logdet = softplus(-x)
+#         return y, logdet
+
+#     def inverse_and_log_det(self, y):
+#         """
+#         Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1]
+#         ==> dX/dY = exp{Y} / (exp{Y} - 1)
+#             = 1 / (1 - exp{-Y})
+#         """
+#         x = jnp.log(jnp.exp(y) - 1.0)
+#         logdet = 1 / (1 - jnp.exp(-y))
+#         return x, logdet


 def get_defaults() -> ConfigDict:
```
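The change above replaces the forward-only `dx.Lambda` bijectors with ones carrying an explicit inverse, which is what the constrain/unconstrain maps in gpjax/parameters.py rely on; the commented-out `Bijector` subclass is kept as a sketch of the same idea with hand-written log-determinants. A minimal round-trip sketch (not part of the commit, assuming only `jax` and `distrax` are installed):

```python
# Minimal sketch: the Lambda bijector defined in gpjax/config.py now
# supports both directions, so a parameter can be mapped to its
# unconstrained space and back.
import jax.numpy as jnp
import distrax as dx

Softplus = dx.Lambda(
    forward=lambda x: jnp.log(1.0 + jnp.exp(x)),   # R -> (0, inf)
    inverse=lambda x: jnp.log(jnp.exp(x) - 1.0),   # (0, inf) -> R
)

x = jnp.array([0.5])
y = Softplus.forward(x)        # constrain
x_again = Softplus.inverse(y)  # unconstrain
assert jnp.allclose(x, x_again)
```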

gpjax/parameters.py

Lines changed: 22 additions & 26 deletions

```diff
@@ -10,13 +10,12 @@
 import jax.random as jr
 from chex import dataclass
 from jaxtyping import f64
-from tensorflow_probability.substrates.jax import distributions as tfd

 from .config import get_defaults
 from .types import PRNGKeyType
 from .utils import merge_dictionaries

-Identity = dx.Lambda(lambda x: x)
+Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)


 ################################
```
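The `tensorflow_probability` import can go because the TFP-specific branch in `prior_checks` is removed further down; priors are now assumed to be Distrax distributions throughout.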
```diff
@@ -163,11 +162,13 @@ def inverse(bijector):

     bijectors = build_bijectors(params)

-    constrainers = jax.tree_map(lambda _: forward, deepcopy(params))
-    unconstrainers = jax.tree_map(lambda _: inverse, deepcopy(params))
+    constrainers = jax.tree_util.tree_map(lambda _: forward, deepcopy(params))
+    unconstrainers = jax.tree_util.tree_map(lambda _: inverse, deepcopy(params))

-    constrainers = jax.tree_map(lambda f, b: f(b), constrainers, bijectors)
-    unconstrainers = jax.tree_map(lambda f, b: f(b), unconstrainers, bijectors)
+    constrainers = jax.tree_util.tree_map(lambda f, b: f(b), constrainers, bijectors)
+    unconstrainers = jax.tree_util.tree_map(
+        lambda f, b: f(b), unconstrainers, bijectors
+    )

     return constrainers, unconstrainers

```
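This hunk (and the ones below) moves from the deprecated `jax.tree_map` alias to `jax.tree_util.tree_map`. A small sketch of the leaf-wise mapping these lines rely on, using a toy parameter dictionary (the keys are illustrative, not GPJax's):

```python
# Sketch: tree_map applies a function to every leaf of a pytree; given two
# trees of identical structure it zips corresponding leaves together,
# which is how a dict of bijectors becomes a dict of constrainers above.
import jax

params = {"lengthscale": 1.0, "variance": 2.0}   # toy parameter pytree
scales = {"lengthscale": 10.0, "variance": 0.5}  # matching structure

doubled = jax.tree_util.tree_map(lambda p: 2.0 * p, params)
rescaled = jax.tree_util.tree_map(lambda p, s: p * s, params, scales)
print(doubled, rescaled)
```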
```diff
@@ -182,7 +183,9 @@ def transform(params: tp.Dict, transform_map: tp.Dict) -> tp.Dict:
     Returns:
         tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary.
     """
-    return jax.tree_map(lambda param, trans: trans(param), params, transform_map)
+    return jax.tree_util.tree_map(
+        lambda param, trans: trans(param), params, transform_map
+    )


 ################################
```

```diff
@@ -200,7 +203,7 @@ def copy_dict_structure(params: dict) -> dict:
     # Copy dictionary structure
     prior_container = deepcopy(params)
     # Set all values to None
-    prior_container = jax.tree_map(lambda _: None, prior_container)
+    prior_container = jax.tree_util.tree_map(lambda _: None, prior_container)
     return prior_container


```
```diff
@@ -243,23 +246,16 @@ def prior_checks(priors: dict) -> dict:
     """Run checks on the parameters' prior distributions. This checks that for Gaussian processes that are constructed with non-conjugate likelihoods, the prior distribution on the function's latent values is a unit Gaussian."""
     if "latent" in priors.keys():
         latent_prior = priors["latent"]
-        if isinstance(latent_prior, dx.Distribution) and latent_prior.name != "Normal":
-            warnings.warn(
-                f"A {latent_prior.name} distribution prior has been placed on"
-                " the latent function. It is strongly advised that a"
-                " unit-Gaussian prior is used."
-            )
-        elif (
-            isinstance(latent_prior, tfd.Distribution) and latent_prior.name != "Normal"
-        ):
-            warnings.warn(
-                f"A {latent_prior.name} distribution from Tensorflow Probability has been"
-                "placed on the latent function. We advise using a unit-Gaussian prior from"
-                " Distrax."
-            )
+        if latent_prior is not None:
+            if latent_prior.name != "Normal":
+                warnings.warn(
+                    f"A {latent_prior.name} distribution prior has been placed on"
+                    " the latent function. It is strongly advised that a"
+                    " unit Gaussian prior is used."
+                )
         else:
-            if not latent_prior:
-                priors["latent"] = dx.Normal(loc=0.0, scale=1.0)
+            warnings.warn("Placing unit Gaussian prior on latent function.")
+            priors["latent"] = dx.Normal(loc=0.0, scale=1.0)
     else:
         priors["latent"] = dx.Normal(loc=0.0, scale=1.0)

```
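The rewritten branch warns on any non-Normal Distrax prior and fills in a unit Gaussian when the latent prior is absent. An illustrative sketch of the three code paths, assuming `prior_checks` is imported from `gpjax.parameters` as the diff suggests:

```python
# Illustrative sketch (assumed usage) of the three code paths above.
import distrax as dx
from gpjax.parameters import prior_checks

# 1. Non-Normal prior on the latent function -> warning, prior kept.
priors = prior_checks({"latent": dx.Laplace(loc=0.0, scale=1.0)})

# 2. "latent" key present but None -> warning, unit Gaussian filled in.
priors = prior_checks({"latent": None})

# 3. No "latent" key at all -> unit Gaussian filled in, no warning.
priors = prior_checks({})
```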
```diff
@@ -278,7 +274,7 @@ def build_trainables(params: tp.Dict) -> tp.Dict:
     # Copy dictionary structure
     prior_container = deepcopy(params)
     # Set all values to True
-    prior_container = jax.tree_map(lambda _: True, prior_container)
+    prior_container = jax.tree_util.tree_map(lambda _: True, prior_container)
     return prior_container


```

```diff
@@ -289,6 +285,6 @@ def stop_grad(param: tp.Dict, trainable: tp.Dict):

 def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict:
     """Stop the gradients flowing through parameters whose trainable status is False"""
-    return jax.tree_map(
+    return jax.tree_util.tree_map(
         lambda param, trainable: stop_grad(param, trainable), params, trainables
     )
```
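A sketch of what the trainability machinery does end-to-end, assuming `build_trainables` and `trainable_params` come from `gpjax.parameters` and that `stop_grad` applies `jax.lax.stop_gradient` to leaves whose flag is False (its body is not shown in this diff):

```python
# Sketch (assumed behaviour): gradients vanish for leaves marked False.
import jax
import jax.numpy as jnp
from gpjax.parameters import build_trainables, trainable_params

params = {"a": jnp.array(1.0), "b": jnp.array(2.0)}
trainables = build_trainables(params)  # every leaf True by default
trainables["b"] = False                # freeze one parameter

def loss(ps):
    ps = trainable_params(ps, trainables)
    return ps["a"] ** 2 + ps["b"] ** 2

grads = jax.grad(loss)(params)
# grads["a"] == 2.0, grads["b"] == 0.0 under the stop-gradient assumption
```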

setup.py

Lines changed: 5 additions & 6 deletions

```diff
@@ -12,15 +12,14 @@ def parse_requirements_file(filename):


 REQUIRES = [
-    "jax==0.3.5",
-    "jaxlib==0.3.5",
-    "optax>=0.1.0",
-    "chex==0.1.3",
+    "jax>=0.1.67",
+    "jaxlib>=0.1.47",
+    "optax",
+    "chex",
     "distrax>=0.1.2",
-    "tensorflow-probability==0.16.0",
+    "tensorflow-probability>=0.16.0",
     "tqdm>=4.0.0",
     "ml-collections==0.1.0",
-    "protobuf==3.19.0",
     "jaxtyping",
 ]

```
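Replacing the exact `jax`, `jaxlib`, and `chex` pins with lower bounds (and dropping the `protobuf` pin) is what the PR title refers to: downstream environments can now resolve newer JAX releases instead of being forced onto 0.3.5.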
tests/test_parameters.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -253,5 +253,5 @@ def test_output(num_datapoints, likelihood):
     a_constrainers, a_unconstrainers = build_transforms(augmented_params)
     assert "test_param" in list(a_constrainers.keys())
     assert "test_param" in list(a_unconstrainers.keys())
-    assert a_constrainers["test_param"](1.0) == 1.0
-    assert a_unconstrainers["test_param"](1.0) == 1.0
+    assert a_constrainers["test_param"](jnp.array([1.0])) == 1.0
+    assert a_unconstrainers["test_param"](jnp.array([1.0])) == 1.0
```
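The scalar test inputs are wrapped in `jnp.array` so the constrain/unconstrain maps are exercised on array inputs, presumably to match the explicit `forward`/`inverse` Lambda bijectors introduced above.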
