Skip to content

Commit af27f0a

Browse files
Fix tests with jax 0.5.1 (#494)
* Bump tolerance on expected improvement utility * Relax get_batch uniqueness test * Surpress docstrings * Bound Flax * Deprecate QuadContourSet collections * Bump version * Bound mkdocstrings --------- Co-authored-by: Thomas Pinder <[email protected]>
1 parent 01087f5 commit af27f0a

File tree

5 files changed

+7
-9
lines changed

5 files changed

+7
-9
lines changed

examples/intro_to_gps.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,7 @@
218218
d_prob = d.prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape(
219219
xx.shape
220220
)
221-
cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap)
222-
for c in cntf.collections:
223-
c.set_edgecolor("face")
221+
cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap, edgecolor="face")
224222
a.set_xlim(-2.75, 2.75)
225223
a.set_ylim(-2.75, 2.75)
226224
samples = d.sample(seed=key, sample_shape=(5000,))

gpjax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
__description__ = "Didactic Gaussian processes in JAX"
4141
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
4242
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43-
__version__ = "0.9.3"
43+
__version__ = "0.9.4"
4444

4545
__all__ = [
4646
"base",

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ dependencies = [
3030
"beartype>0.16.1",
3131
"cola-ml==0.0.5",
3232
"jaxopt==0.8.2",
33-
"flax>=0.8.5",
33+
"flax<0.10.0",
3434
"numpy<2.0.0",
3535
]
3636

@@ -42,7 +42,7 @@ python = "3.10"
4242
dependencies = [
4343
"mkdocs>=1.5.3",
4444
"mkdocs-material>=9.5.12",
45-
"mkdocstrings[python]>=0.25.1",
45+
"mkdocstrings[python]<0.28.0",
4646
"mkdocs-jupyter>=0.24.3",
4747
"mkdocs-gen-files>=0.5.0",
4848
"mkdocs-literate-nav>=0.6.0",

tests/test_decision_making/test_utility_functions/test_expected_improvement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,4 @@ def test_expected_improvement_utility_function_correct_values(
6464
eta = get_best_latent_observation_val(posterior, dataset)
6565
mc_ei = jnp.expand_dims(jnp.mean(jnp.maximum(eta - samples, 0), 0), -1)
6666
assert jnp.all(ei >= 0)
67-
assert jnp.allclose(ei, mc_ei, rtol=0.01)
67+
assert jnp.allclose(ei, mc_ei, rtol=0.08, atol=1e-6)

tests/test_fit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,5 +322,5 @@ def test_get_batch(n_data: int, n_dim: int, batch_size: int):
322322
assert New.n == batch_size
323323
assert New.X.shape[1:] == x.shape[1:]
324324
assert New.y.shape[1:] == y.shape[1:]
325-
assert (New.X != B.X).all()
326-
assert (New.y != B.y).all()
325+
assert jnp.sum(New.X == B.X) <= n_dim * batch_size / n_data
326+
assert jnp.sum(New.y == B.y) <= n_dim * batch_size / n_data

0 commit comments

Comments
 (0)