Skip to content

Commit 8dccd1a

Browse files
authored
Run all tests in single precision (#835)
* Move allclose() from backend.numpy to backend.testing * There is only a single testing.allclose function now * All tests run in single precision now * Run all examples in single precision
1 parent af024d9 commit 8dccd1a

31 files changed

Lines changed: 130 additions & 142 deletions

docs/examples_advanced/equinox_while_loop.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727
from probdiffeq import ivpsolve, ivpsolvers, taylor
2828
from probdiffeq.backend import control_flow
2929

30-
jax.config.update("jax_platform_name", "cpu")
31-
32-
3330
# -
3431

3532
# Overwrite the while-loop (via a context manager):

docs/examples_advanced/parameter_estimation_blackjax.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,6 @@
135135
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
136136

137137
# +
138-
# x64 precision
139-
jax.config.update("jax_enable_x64", True)
140-
141-
# CPU
142-
jax.config.update("jax_platform_name", "cpu")
143138

144139
# IVP examples in JAX
145140
if not backend.has_been_selected:

docs/examples_advanced/parameter_estimation_optax.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737
if not backend.has_been_selected:
3838
backend.select("jax") # ivp examples in jax
3939

40-
jax.config.update("jax_enable_x64", True)
41-
jax.config.update("jax_platform_name", "cpu")
4240
# -
4341

4442

docs/examples_basic/conditioning_on_zero_residual.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@
3232
if not backend.has_been_selected:
3333
backend.select("jax") # ivp examples in jax
3434

35-
jax.config.update("jax_platform_name", "cpu")
36-
jax.config.update("jax_enable_x64", True)
37-
3835

3936
# +
4037
# Create an ODE problem

docs/examples_basic/dynamic_output_scales.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
backend.select("jax") # ivp examples in jax
4747

4848

49-
jax.config.update("jax_platform_name", "cpu")
5049
# -
5150

5251

docs/examples_basic/second_order_problems.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
if not backend.has_been_selected:
2929
backend.select("jax") # ivp examples in jax
3030

31-
jax.config.update("jax_platform_name", "cpu")
3231
# -
3332

3433
# Quick refresher: first-order ODEs

docs/examples_basic/taylor_coefficients.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
if not backend.has_been_selected:
3535
backend.select("jax") # ivp examples in jax
3636

37-
jax.config.update("jax_platform_name", "cpu")
3837
# -
3938

4039
# We start by defining an ODE.

probdiffeq/backend/numpy.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def flip(arr, /, axis=None):
5252
return jnp.flip(arr, axis=axis)
5353

5454

55-
def asarray(x, /):
56-
return jnp.asarray(x)
55+
def asarray(x, /, dtype=None):
56+
return jnp.asarray(x, dtype=dtype)
5757

5858

5959
def squeeze(arr, /):
@@ -140,10 +140,6 @@ def load(path, /):
140140
return jnp.load(path, allow_pickle=True)
141141

142142

143-
def allclose(a, b, *, atol=1e-8, rtol=1e-5):
144-
return jnp.allclose(a, b, atol=atol, rtol=rtol)
145-
146-
147143
def stack(list_of_arrays, /):
148144
return jnp.stack(list_of_arrays)
149145

probdiffeq/backend/testing.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
This is not good for extendability of the test suite.
1010
"""
1111

12+
import jax
1213
import jax.numpy as jnp
1314
import jax.tree_util
1415
import pytest
@@ -31,22 +32,46 @@ def fixture(name=None, scope="function"):
3132
return pytest_cases.fixture(name=name, scope=scope)
3233

3334

34-
def tree_all_allclose(tree1, tree2, **kwargs):
35-
trees_is_allclose = tree_allclose(tree1, tree2, **kwargs)
35+
def allclose(tree1, tree2, /, *, atol: float | None = None, rtol: float | None = None):
36+
"""Check whether two pytrees are 'close' to each other.
37+
38+
In contrast to jax.numpy.allclose, this version:
39+
- Works with pytrees (by comparing the structure and leaves)
40+
- Uses different tolerances for single precision than for double precision.
41+
(It adjusts atol and rtol to the floating-point precision of the leaves.)
42+
"""
43+
trees_is_allclose = _tree_allclose(tree1, tree2, atol=atol, rtol=rtol)
3644
return jax.tree_util.tree_all(trees_is_allclose)
3745

3846

39-
def tree_allclose(tree1, tree2, **kwargs):
40-
def allclose_partial(*args):
41-
return jnp.allclose(*args, **kwargs)
47+
def _tree_allclose(tree1, tree2, /, *, atol, rtol):
48+
def allclose_partial(t1, t2, /):
49+
return _allclose(t1, t2, atol=atol, rtol=rtol)
4250

4351
return jax.tree_util.tree_map(allclose_partial, tree1, tree2)
4452

4553

46-
def marginals_allclose(m1, m2, /, *, ssm):
54+
def marginals_allclose(
55+
m1, m2, /, *, ssm, atol: float | None = None, rtol: float | None = None
56+
):
4757
m1, c1 = ssm.stats.to_multivariate_normal(m1)
4858
m2, c2 = ssm.stats.to_multivariate_normal(m2)
4959

50-
means_allclose = jnp.allclose(m1, m2)
51-
covs_allclose = jnp.allclose(c1, c2)
60+
means_allclose = _allclose(m1, m2, atol=atol, rtol=rtol)
61+
covs_allclose = _allclose(c1, c2, atol=atol, rtol=rtol)
5262
return jnp.logical_and(means_allclose, covs_allclose)
63+
64+
65+
def _allclose(a, b, /, *, atol: float | None, rtol: float | None):
66+
# promote to float-type to enable finfo.eps
67+
a = jnp.asarray(1.0 * a)
68+
b = jnp.asarray(1.0 * b)
69+
70+
# numpy.allclose uses defaults atol=1e-8 and rtol=1e-5;
71+
# we mirror this as atol=sqrt(tol) and rtol slightly larger.
72+
tol = jnp.sqrt(jnp.finfo(b.dtype).eps)
73+
if atol is None:
74+
atol = tol
75+
if rtol is None:
76+
rtol = 10 * tol
77+
return jnp.allclose(a, b, atol=atol, rtol=rtol)

probdiffeq/impl/_stats.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,16 @@ def mahalanobis_norm_relative(self, u, /, rv):
6363
return np.reshape(np.abs(mahalanobis) / np.sqrt(rv.mean.size), ())
6464

6565
def logpdf(self, u, /, rv):
66-
# The cholesky factor is triangular, so we compute a cheap slogdet.
67-
diagonal = linalg.diagonal_along_axis(rv.cholesky, axis1=-1, axis2=-2)
66+
cholesky = linalg.qr_r(rv.cholesky.T).T
67+
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
6868
slogdet = np.sum(np.log(np.abs(diagonal)))
6969

7070
dx = u - rv.mean
71-
residual_white = linalg.solve_triangular(rv.cholesky.T, dx, trans="T")
72-
x1 = linalg.vector_dot(residual_white, residual_white)
73-
x2 = 2.0 * slogdet
74-
x3 = u.size * np.log(np.pi() * 2)
75-
return -0.5 * (x1 + x2 + x3)
71+
residual_white = linalg.solve_triangular(cholesky, dx, lower=True, trans=0)
72+
sqrnorm = linalg.vector_dot(residual_white, residual_white)
73+
74+
const = np.log(np.pi() * 2)
75+
return -1 / 2 * sqrnorm - u.size / 2 * const - slogdet
7676

7777
def mean(self, rv):
7878
return rv.mean
@@ -128,12 +128,14 @@ def logpdf(self, u, /, rv):
128128
u = u[None, :]
129129

130130
def logpdf_scalar(x, r):
131+
cholesky = linalg.qr_r(r.cholesky.T).T
132+
131133
dx = x - r.mean
132-
w = linalg.solve_triangular(r.cholesky.T, dx, trans="T")
134+
w = linalg.solve_triangular(cholesky.T, dx, trans="T")
133135

134136
maha_term = linalg.vector_dot(w, w)
135137

136-
diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2)
138+
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
137139
slogdet = np.sum(np.log(np.abs(diagonal)))
138140
logdet_term = 2.0 * slogdet
139141
return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2))
@@ -195,12 +197,14 @@ def mahalanobis_norm_relative(self, u, /, rv):
195197

196198
def logpdf(self, u, /, rv):
197199
def logpdf_scalar(x, r):
200+
cholesky = linalg.qr_r(r.cholesky.T).T
201+
198202
dx = x - r.mean
199-
w = linalg.solve_triangular(r.cholesky.T, dx, trans="T")
203+
w = linalg.solve_triangular(cholesky.T, dx, trans="T")
200204

201205
maha_term = linalg.vector_dot(w, w)
202206

203-
diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2)
207+
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
204208
slogdet = np.sum(np.log(np.abs(diagonal)))
205209
logdet_term = 2.0 * slogdet
206210
return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2))

0 commit comments

Comments
 (0)