Skip to content

Commit bc4e8d6

Browse files
committed
All tests run in single precision now
1 parent 0447311 commit bc4e8d6

10 files changed

Lines changed: 42 additions & 53 deletions

File tree

probdiffeq/backend/numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def flip(arr, /, axis=None):
5252
return jnp.flip(arr, axis=axis)
5353

5454

55-
def asarray(x, /):
56-
return jnp.asarray(x)
55+
def asarray(x, /, dtype=None):
56+
return jnp.asarray(x, dtype=dtype)
5757

5858

5959
def squeeze(arr, /):

probdiffeq/backend/testing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,11 @@ def _allclose(a, b, /, *, atol: float | None, rtol: float | None):
6767
a = jnp.asarray(1.0 * a)
6868
b = jnp.asarray(1.0 * b)
6969

70+
# numpy.allclose uses defaults atol=1e-8 and rtol=1e-5;
71+
# we mirror this as atol=sqrt(tol) and rtol slightly larger.
7072
tol = jnp.sqrt(jnp.finfo(b.dtype).eps)
7173
if atol is None:
7274
atol = tol
7375
if rtol is None:
74-
rtol = tol
76+
rtol = 10 * tol
7577
return jnp.allclose(a, b, atol=atol, rtol=rtol)

probdiffeq/impl/_stats.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,16 @@ def mahalanobis_norm_relative(self, u, /, rv):
6363
return np.reshape(np.abs(mahalanobis) / np.sqrt(rv.mean.size), ())
6464

6565
def logpdf(self, u, /, rv):
66-
# The cholesky factor is triangular, so we compute a cheap slogdet.
67-
diagonal = linalg.diagonal_along_axis(rv.cholesky, axis1=-1, axis2=-2)
66+
cholesky = linalg.qr_r(rv.cholesky.T).T
67+
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
6868
slogdet = np.sum(np.log(np.abs(diagonal)))
6969

7070
dx = u - rv.mean
71-
residual_white = linalg.solve_triangular(rv.cholesky.T, dx, trans="T")
72-
x1 = linalg.vector_dot(residual_white, residual_white)
73-
x2 = 2.0 * slogdet
74-
x3 = u.size * np.log(np.pi() * 2)
75-
return -0.5 * (x1 + x2 + x3)
71+
residual_white = linalg.solve_triangular(cholesky, dx, lower=True, trans=0)
72+
sqrnorm = linalg.vector_dot(residual_white, residual_white)
73+
74+
const = np.log(np.pi() * 2)
75+
return -1 / 2 * sqrnorm - u.size / 2 * const - slogdet
7676

7777
def mean(self, rv):
7878
return rv.mean
@@ -128,12 +128,14 @@ def logpdf(self, u, /, rv):
128128
u = u[None, :]
129129

130130
def logpdf_scalar(x, r):
131+
cholesky = linalg.qr_r(r.cholesky.T).T
132+
131133
dx = x - r.mean
132-
w = linalg.solve_triangular(r.cholesky.T, dx, trans="T")
134+
w = linalg.solve_triangular(cholesky.T, dx, trans="T")
133135

134136
maha_term = linalg.vector_dot(w, w)
135137

136-
diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2)
138+
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
137139
slogdet = np.sum(np.log(np.abs(diagonal)))
138140
logdet_term = 2.0 * slogdet
139141
return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2))
@@ -195,12 +197,14 @@ def mahalanobis_norm_relative(self, u, /, rv):
195197

196198
def logpdf(self, u, /, rv):
197199
def logpdf_scalar(x, r):
200+
cholesky = linalg.qr_r(r.cholesky.T).T
201+
198202
dx = x - r.mean
199-
w = linalg.solve_triangular(r.cholesky.T, dx, trans="T")
203+
w = linalg.solve_triangular(cholesky.T, dx, trans="T")
200204

201205
maha_term = linalg.vector_dot(w, w)
202206

203-
diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2)
207+
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
204208
slogdet = np.sum(np.log(np.abs(diagonal)))
205209
logdet_term = 2.0 * slogdet
206210
return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2))

tests/conftest.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/test_impl/test_logpdfs.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
Necessary because the implementation has been faulty in the past. Never again.
44
"""
55

6-
from probdiffeq.backend import functools, stats, testing
6+
from probdiffeq.backend import functools, random, stats, testing, tree_util
77
from probdiffeq.backend import numpy as np
88
from probdiffeq.impl import impl
99

1010

1111
@testing.parametrize("fact", ["dense", "isotropic", "blockdiag"])
1212
def test_logpdf(fact):
1313
rv, ssm = random_variable(fact=fact)
14+
1415
u = np.ones_like(ssm.stats.mean(rv))
1516

1617
(mean_dense, cov_dense) = ssm.stats.to_multivariate_normal(rv)
@@ -37,4 +38,9 @@ def random_variable(fact):
3738
output_scale = np.ones_like(ssm.prototypes.output_scale())
3839
discretize = ssm.conditional.ibm_transitions(output_scale)
3940
rv = discretize(0.1, output_scale)
40-
return rv.noise, ssm
41+
42+
key = random.prng_key(seed=1)
43+
noise_flat, unravel = tree_util.ravel_pytree(rv.noise)
44+
noise_flat = random.normal(key, shape=noise_flat.shape)
45+
noise = unravel(noise_flat)
46+
return noise, ssm

tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,6 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm
104104
# Compare QOI and marginals
105105
marginals_allclose_func = functools.partial(testing.marginals_allclose, ssm=ssm)
106106
marginals_allclose_func = functools.vmap(marginals_allclose_func)
107+
107108
assert testing.allclose(u_fixedpoint, u_interp)
108109
assert np.all(marginals_allclose_func(marginals_fixedpoint, marginals_interp))

tests/test_taylor/data/generate_reference_solutions.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,6 @@
55
from probdiffeq.backend import numpy as np
66

77

8-
def set_environment():
9-
"""Set the environment (e.g., 64-bit precision).
10-
11-
The setup used to precompute references should match that of the other tests.
12-
"""
13-
# Test on CPU.
14-
config.update("platform_name", "cpu")
15-
16-
# Double precision
17-
config.update("enable_x64", True)
18-
19-
208
def three_body_first(num_derivatives_max=6):
219
vf, (u0,), (t0, _) = ode.ivp_three_body_1st()
2210
vf = functools.partial(vf, t=t0)
@@ -30,8 +18,8 @@ def van_der_pol_second(num_derivatives_max=6):
3018

3119

3220
if __name__ == "__main__":
33-
# 64-bit precision and the like
34-
set_environment()
21+
# Double precision
22+
config.update("enable_x64", True)
3523

3624
solution1 = three_body_first()
3725
np.save("./tests/test_taylor/data/three_body_first_solution.npy", solution1)

tests/test_taylor/test_exact_first_order.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def fixture_pb_with_solution():
3030

3131

3232
@testing.parametrize_with_cases("taylor_fun", cases=".", prefix="case_")
33-
@testing.parametrize("num", [1, 6])
33+
@testing.parametrize("num", [1, 4])
3434
def test_approximation_identical_to_reference(pb_with_solution, taylor_fun, num):
3535
(f, init), solution = pb_with_solution
3636

tests/test_taylor/test_inexact_first_order.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from probdiffeq.backend import numpy as np
66

77

8-
@testing.parametrize("num", [0, 1, 4])
8+
@testing.parametrize("num", [0, 1, 3])
99
@testing.parametrize("fact", ["isotropic", "dense", "blockdiag"])
1010
def test_initialised_correct_shape_and_values(num, fact):
1111
vf, (u0,), (t0, _) = ode.ivp_lotka_volterra()

tests/test_util/test_cholesky_util.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
These are so crucial and annoying to debug that they need their own test set.
44
"""
55

6-
from math import prod
7-
8-
from probdiffeq.backend import functools, linalg, testing, tree_util
6+
from probdiffeq.backend import functools, linalg, random, testing, tree_util
97
from probdiffeq.backend import numpy as np
108
from probdiffeq.util import cholesky_util
119

@@ -34,14 +32,16 @@ def cov(x):
3432
assert testing.allclose(cov(bw_noise), C1)
3533

3634

37-
@testing.parametrize("Cshape, HCshape", ([(3, 3), (2, 3)],))
38-
def test_revert_kernel_noisefree(Cshape, HCshape):
35+
@testing.parametrize("Cshape, Hshape", ([(3, 3), (2, 3)],))
36+
def test_revert_kernel_noisefree(Cshape, Hshape):
3937
C = _some_array(Cshape) + 1.0
40-
HC = _some_array(HCshape) + 2.0
38+
H = _some_array(Hshape) + 2.0
39+
HC = H @ C
4140

4241
S = HC @ HC.T
4342
K = C @ HC.T @ linalg.inv(S)
44-
C1 = C @ C.T - K @ S @ K.T
43+
44+
C1 = (np.eye(Cshape[0]) - K @ H) @ C @ C.T @ (np.eye(Cshape[0]) - K @ H).T
4545

4646
extra, (bw_noise, g) = cholesky_util.revert_conditional_noisefree(
4747
R_X_F=HC.T, R_X=C.T
@@ -56,7 +56,8 @@ def cov(x):
5656

5757

5858
def _some_array(shape):
59-
return np.arange(1.0, 1.0 + prod(shape)).reshape(shape)
59+
key = random.prng_key(seed=1)
60+
return random.normal(key, shape=shape)
6061

6162

6263
def test_sqrt_sum_square_scalar():

0 commit comments

Comments
 (0)