diff --git a/docs/examples_advanced/equinox_while_loop.py b/docs/examples_advanced/equinox_while_loop.py index 4920cf4b1..43e3c7e58 100644 --- a/docs/examples_advanced/equinox_while_loop.py +++ b/docs/examples_advanced/equinox_while_loop.py @@ -27,9 +27,6 @@ from probdiffeq import ivpsolve, ivpsolvers, taylor from probdiffeq.backend import control_flow -jax.config.update("jax_platform_name", "cpu") - - # - # Overwrite the while-loop (via a context manager): diff --git a/docs/examples_advanced/parameter_estimation_blackjax.py b/docs/examples_advanced/parameter_estimation_blackjax.py index 48c8ab297..2a7c1ead5 100644 --- a/docs/examples_advanced/parameter_estimation_blackjax.py +++ b/docs/examples_advanced/parameter_estimation_blackjax.py @@ -135,11 +135,6 @@ from probdiffeq import ivpsolve, ivpsolvers, stats, taylor # + -# x64 precision -jax.config.update("jax_enable_x64", True) - -# CPU -jax.config.update("jax_platform_name", "cpu") # IVP examples in JAX if not backend.has_been_selected: diff --git a/docs/examples_advanced/parameter_estimation_optax.py b/docs/examples_advanced/parameter_estimation_optax.py index 0103a8bb4..45fe71df1 100644 --- a/docs/examples_advanced/parameter_estimation_optax.py +++ b/docs/examples_advanced/parameter_estimation_optax.py @@ -37,8 +37,6 @@ if not backend.has_been_selected: backend.select("jax") # ivp examples in jax -jax.config.update("jax_enable_x64", True) -jax.config.update("jax_platform_name", "cpu") # - diff --git a/docs/examples_basic/conditioning_on_zero_residual.py b/docs/examples_basic/conditioning_on_zero_residual.py index 16de47b00..cf3e2f901 100644 --- a/docs/examples_basic/conditioning_on_zero_residual.py +++ b/docs/examples_basic/conditioning_on_zero_residual.py @@ -32,9 +32,6 @@ if not backend.has_been_selected: backend.select("jax") # ivp examples in jax -jax.config.update("jax_platform_name", "cpu") -jax.config.update("jax_enable_x64", True) - # + # Create an ODE problem diff --git a/docs/examples_basic/dynamic_output_scales.py b/docs/examples_basic/dynamic_output_scales.py index b8e6c3321..2149360b0 100644 --- a/docs/examples_basic/dynamic_output_scales.py +++ b/docs/examples_basic/dynamic_output_scales.py @@ -46,7 +46,6 @@ backend.select("jax") # ivp examples in jax -jax.config.update("jax_platform_name", "cpu") # - diff --git a/docs/examples_basic/second_order_problems.py b/docs/examples_basic/second_order_problems.py index a96a3b629..6c5849389 100644 --- a/docs/examples_basic/second_order_problems.py +++ b/docs/examples_basic/second_order_problems.py @@ -28,7 +28,6 @@ if not backend.has_been_selected: backend.select("jax") # ivp examples in jax -jax.config.update("jax_platform_name", "cpu") # - # Quick refresher: first-order ODEs diff --git a/docs/examples_basic/taylor_coefficients.py b/docs/examples_basic/taylor_coefficients.py index 17f32190e..7101641bf 100644 --- a/docs/examples_basic/taylor_coefficients.py +++ b/docs/examples_basic/taylor_coefficients.py @@ -34,7 +34,6 @@ if not backend.has_been_selected: backend.select("jax") # ivp examples in jax -jax.config.update("jax_platform_name", "cpu") # - # We start by defining an ODE. diff --git a/probdiffeq/backend/numpy.py b/probdiffeq/backend/numpy.py index 83516ac9c..a0fd24ece 100644 --- a/probdiffeq/backend/numpy.py +++ b/probdiffeq/backend/numpy.py @@ -52,8 +52,8 @@ def flip(arr, /, axis=None): return jnp.flip(arr, axis=axis) -def asarray(x, /): - return jnp.asarray(x) +def asarray(x, /, dtype=None): + return jnp.asarray(x, dtype=dtype) def squeeze(arr, /): @@ -140,10 +140,6 @@ def load(path, /): return jnp.load(path, allow_pickle=True) -def allclose(a, b, *, atol=1e-8, rtol=1e-5): - return jnp.allclose(a, b, atol=atol, rtol=rtol) - - def stack(list_of_arrays, /): return jnp.stack(list_of_arrays) diff --git a/probdiffeq/backend/testing.py b/probdiffeq/backend/testing.py index a8004f331..d037ccea4 100644 --- a/probdiffeq/backend/testing.py +++ b/probdiffeq/backend/testing.py @@ -9,6 +9,7 @@ This is not good for extendability of the test suite. """ +import jax import jax.numpy as jnp import jax.tree_util import pytest @@ -31,22 +32,46 @@ def fixture(name=None, scope="function"): return pytest_cases.fixture(name=name, scope=scope) -def tree_all_allclose(tree1, tree2, **kwargs): - trees_is_allclose = tree_allclose(tree1, tree2, **kwargs) +def allclose(tree1, tree2, /, *, atol: float | None = None, rtol: float | None = None): + """Check whether two pytrees are 'close' to each other. + + In contrast to jax.numpy.allclose, this version: + - Works with pytrees (by comparing the structure and leaves) + - Uses different tolerances for single precision than for double precision. + (It adjusts atol and rtol to the floating-point precision of the leaves.) + """ + trees_is_allclose = _tree_allclose(tree1, tree2, atol=atol, rtol=rtol) return jax.tree_util.tree_all(trees_is_allclose) -def tree_allclose(tree1, tree2, **kwargs): - def allclose_partial(*args): - return jnp.allclose(*args, **kwargs) +def _tree_allclose(tree1, tree2, /, *, atol, rtol): + def allclose_partial(t1, t2, /): + return _allclose(t1, t2, atol=atol, rtol=rtol) return jax.tree_util.tree_map(allclose_partial, tree1, tree2) -def marginals_allclose(m1, m2, /, *, ssm): +def marginals_allclose( + m1, m2, /, *, ssm, atol: float | None = None, rtol: float | None = None +): m1, c1 = ssm.stats.to_multivariate_normal(m1) m2, c2 = ssm.stats.to_multivariate_normal(m2) - means_allclose = jnp.allclose(m1, m2) - covs_allclose = jnp.allclose(c1, c2) + means_allclose = _allclose(m1, m2, atol=atol, rtol=rtol) + covs_allclose = _allclose(c1, c2, atol=atol, rtol=rtol) return jnp.logical_and(means_allclose, covs_allclose) + + +def _allclose(a, b, /, *, atol: float | None, rtol: float | None): + # promote to float-type to enable finfo.eps + a = jnp.asarray(1.0 * a) + b = jnp.asarray(1.0 * b) + + # numpy.allclose uses defaults atol=1e-8 and rtol=1e-5; + # we mirror this as atol=sqrt(tol) and rtol slightly larger. + tol = jnp.sqrt(jnp.finfo(b.dtype).eps) + if atol is None: + atol = tol + if rtol is None: + rtol = 10 * tol + return jnp.allclose(a, b, atol=atol, rtol=rtol) diff --git a/probdiffeq/impl/_stats.py b/probdiffeq/impl/_stats.py index 5996f9521..a974bf015 100644 --- a/probdiffeq/impl/_stats.py +++ b/probdiffeq/impl/_stats.py @@ -63,16 +63,16 @@ def mahalanobis_norm_relative(self, u, /, rv): return np.reshape(np.abs(mahalanobis) / np.sqrt(rv.mean.size), ()) def logpdf(self, u, /, rv): - # The cholesky factor is triangular, so we compute a cheap slogdet. - diagonal = linalg.diagonal_along_axis(rv.cholesky, axis1=-1, axis2=-2) + cholesky = linalg.qr_r(rv.cholesky.T).T + diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2) slogdet = np.sum(np.log(np.abs(diagonal))) dx = u - rv.mean - residual_white = linalg.solve_triangular(rv.cholesky.T, dx, trans="T") - x1 = linalg.vector_dot(residual_white, residual_white) - x2 = 2.0 * slogdet - x3 = u.size * np.log(np.pi() * 2) - return -0.5 * (x1 + x2 + x3) + residual_white = linalg.solve_triangular(cholesky, dx, lower=True, trans=0) + sqrnorm = linalg.vector_dot(residual_white, residual_white) + + const = np.log(np.pi() * 2) + return -1 / 2 * sqrnorm - u.size / 2 * const - slogdet def mean(self, rv): return rv.mean @@ -128,12 +128,14 @@ def logpdf(self, u, /, rv): u = u[None, :] def logpdf_scalar(x, r): + cholesky = linalg.qr_r(r.cholesky.T).T + dx = x - r.mean - w = linalg.solve_triangular(r.cholesky.T, dx, trans="T") + w = linalg.solve_triangular(cholesky.T, dx, trans="T") maha_term = linalg.vector_dot(w, w) - diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2) + diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2) slogdet = np.sum(np.log(np.abs(diagonal))) logdet_term = 2.0 * slogdet return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2)) @@ -195,12 +197,14 @@ def mahalanobis_norm_relative(self, u, /, rv): def logpdf(self, u, /, rv): def logpdf_scalar(x, r): + cholesky = linalg.qr_r(r.cholesky.T).T + dx = x - r.mean - w = linalg.solve_triangular(r.cholesky.T, dx, trans="T") + w = linalg.solve_triangular(cholesky.T, dx, trans="T") maha_term = linalg.vector_dot(w, w) - diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2) + diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2) slogdet = np.sum(np.log(np.abs(diagonal))) logdet_term = 2.0 * slogdet return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2)) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 628b2155f..000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Test-setup.""" - -from probdiffeq.backend import config, warnings - -# All warnings shall be errors -warnings.filterwarnings("error") - -# Test on CPU. -config.update("platform_name", "cpu") - -# Double precision -# Needed for equivalence tests for smoothers. -config.update("enable_x64", True) diff --git a/tests/test_backend/test_overwrite_control_flow.py b/tests/test_backend/test_overwrite_control_flow.py index 39fa3a9f6..c15958644 100644 --- a/tests/test_backend/test_overwrite_control_flow.py +++ b/tests/test_backend/test_overwrite_control_flow.py @@ -1,6 +1,6 @@ """Test that the control_flow can be updated by a user.""" -from probdiffeq.backend import control_flow +from probdiffeq.backend import control_flow, testing from probdiffeq.backend import numpy as np @@ -10,12 +10,12 @@ def cumsum_step(carry, x): return res, res xs = np.arange(1.0, 11.0, step=2.0) - sum_total = 25 + sum_total = 25.0 cumsum_total = np.asarray([1.0, 4.0, 9.0, 16.0, 25]) final, outputs = control_flow.scan(cumsum_step, init=0.0, xs=xs) - assert np.allclose(final, sum_total) - assert np.allclose(outputs, cumsum_total) + assert testing.allclose(final, sum_total) + assert testing.allclose(outputs, cumsum_total) # Direct import; # Do not use probdiffeq.backend since otherwise we recurse @@ -26,8 +26,8 @@ def scan_that_adds_1(step, init, xs, reverse, length): with control_flow.context_overwrite_scan(scan_that_adds_1): final, outputs = control_flow.scan(cumsum_step, init=0.0, xs=xs) - assert np.allclose(final, sum_total + 1.0) - assert np.allclose(outputs, cumsum_total + 1.0) + assert testing.allclose(final, sum_total + 1.0) + assert testing.allclose(outputs, cumsum_total + 1.0) def test_overwrite_while_loop_func(): @@ -35,8 +35,8 @@ def counter_step(x): return x[0] + 1, x[1] index, value = control_flow.while_loop(lambda s: s[0] < 10, counter_step, (0, 0.0)) - assert np.allclose(index, 10) - assert np.allclose(value, 0.0) + assert testing.allclose(index, 10) + assert testing.allclose(value, 0.0) # Direct import; # Do not use probdiffeq.backend since otherwise we recurse @@ -53,5 +53,5 @@ def while_loop_that_adds_1(cond_fun, body_fun, init_val): index, value = control_flow.while_loop( lambda s: s[0] < 10, counter_step, (0, 0.0) ) - assert np.allclose(index, 10) - assert np.allclose(value, 1.0) # instead of 0. + assert testing.allclose(index, 10) + assert testing.allclose(value, 1.0) # instead of 0. diff --git a/tests/test_impl/test_logpdfs.py b/tests/test_impl/test_logpdfs.py index 866ee4eb3..66b36f9bf 100644 --- a/tests/test_impl/test_logpdfs.py +++ b/tests/test_impl/test_logpdfs.py @@ -3,7 +3,7 @@ Necessary because the implementation has been faulty in the past. Never again. """ -from probdiffeq.backend import functools, stats, testing +from probdiffeq.backend import functools, random, stats, testing, tree_util from probdiffeq.backend import numpy as np from probdiffeq.impl import impl @@ -11,6 +11,7 @@ @testing.parametrize("fact", ["dense", "isotropic", "blockdiag"]) def test_logpdf(fact): rv, ssm = random_variable(fact=fact) + u = np.ones_like(ssm.stats.mean(rv)) (mean_dense, cov_dense) = ssm.stats.to_multivariate_normal(rv) @@ -18,7 +19,7 @@ def test_logpdf(fact): pdf1 = ssm.stats.logpdf(u, rv) pdf2 = stats.multivariate_normal_logpdf(u_dense, mean_dense, cov_dense) - assert np.allclose(pdf1, pdf2) + assert testing.allclose(pdf1, pdf2) @testing.parametrize("fact", ["dense", "isotropic", "blockdiag"]) @@ -37,4 +38,9 @@ def random_variable(fact): output_scale = np.ones_like(ssm.prototypes.output_scale()) discretize = ssm.conditional.ibm_transitions(output_scale) rv = discretize(0.1, output_scale) - return rv.noise, ssm + + key = random.prng_key(seed=1) + noise_flat, unravel = tree_util.ravel_pytree(rv.noise) + noise_flat = random.normal(key, shape=noise_flat.shape) + noise = unravel(noise_flat) + return noise, ssm diff --git a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py index 193e6b4fd..4d6a029a9 100644 --- a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py @@ -32,7 +32,7 @@ class Taylor(containers.NamedTuple): solution_fixed = ivpsolve.solve_fixed_grid( init, grid=grid_adaptive, solver=solver, ssm=ssm ) - assert testing.tree_all_allclose(solution_adaptive, solution_fixed) + assert testing.allclose(solution_adaptive, solution_fixed) # Assert u and u_std have matching shapes (that was wrong before) u_shape = tree_util.tree_map(np.shape, solution_fixed.u) diff --git a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py index b4a1edff6..9650f84de 100644 --- a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py @@ -40,7 +40,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(fact): # Assert similarity for ui, us in zip(u_interp, u_save_at): - assert testing.tree_all_allclose(ui, us) + assert testing.allclose(ui, us) marginals_allclose_func = functools.partial(testing.marginals_allclose, ssm=ssm) marginals_allclose_func = functools.vmap(marginals_allclose_func) diff --git a/tests/test_ivpsolve/test_save_every_step.py b/tests/test_ivpsolve/test_save_every_step.py index 255ee034e..ccc83f8ca 100644 --- a/tests/test_ivpsolve/test_save_every_step.py +++ b/tests/test_ivpsolve/test_save_every_step.py @@ -14,7 +14,7 @@ def test_python_loop_output_matches_reference(fact, strategy): received = python_loop_solution(ivp, fact=fact, strategy_fun=strategy) expected = reference_solution(ivp, received.t) - assert testing.tree_all_allclose(received.u[0], expected, rtol=1e-2) + assert testing.allclose(received.u[0], expected, rtol=1e-2) # Assert u and u_std have matching shapes (that was wrong before) u_shape = tree_util.tree_map(np.shape, received.u) diff --git a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py index dd6c975a1..b5d2d8a3e 100644 --- a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py @@ -26,7 +26,7 @@ def test_terminal_values_identical(fact): received = ivpsolve.solve_adaptive_terminal_values( init, t0=t0, t1=t1, adaptive_solver=asolver, dt0=0.1, ssm=ssm ) - assert testing.tree_all_allclose(received, expected) + assert testing.allclose(received, expected) # Assert u and u_std have matching shapes (that was wrong before) u_shape = tree_util.tree_map(np.shape, received.u) diff --git a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py index 51e19a65e..b9b539081 100644 --- a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py +++ b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py @@ -119,9 +119,9 @@ def test_calibration_changes_the_posterior(uncalibrated_and_mle_solution): output_scale_mle = mle_solution.output_scale # Without a call to calibrate(), the posteriors are the same. - assert testing.tree_all_allclose(posterior_uncalibrated, posterior_mle) - assert not np.allclose(output_scale_uncalibrated, output_scale_mle) + assert testing.allclose(posterior_uncalibrated, posterior_mle) + assert not testing.allclose(output_scale_uncalibrated, output_scale_mle) # With a call to calibrate(), the posteriors are different. posterior_calibrated = stats.calibrate(posterior_mle, output_scale_mle, ssm=ssm) - assert not testing.tree_all_allclose(posterior_uncalibrated, posterior_calibrated) + assert not testing.allclose(posterior_uncalibrated, posterior_calibrated) diff --git a/tests/test_ivpsolvers/test_controllers.py b/tests/test_ivpsolvers/test_controllers.py index 5bce5ee3b..5ab110cda 100644 --- a/tests/test_ivpsolvers/test_controllers.py +++ b/tests/test_ivpsolvers/test_controllers.py @@ -1,7 +1,6 @@ """Test the controllers.""" from probdiffeq import ivpsolvers -from probdiffeq.backend import numpy as np from probdiffeq.backend import testing @@ -23,4 +22,4 @@ def test_equivalence_pi_vs_i(dt, error_power, num_applies): dt_i = dt for _ in range(num_applies): dt_i, x_i = ctrl_i.apply(dt_i, x_i, error_power=error_power) - assert np.allclose(dt_i, dt_pi) + assert testing.allclose(dt_i, dt_pi) diff --git a/tests/test_ivpsolvers/test_corrections.py b/tests/test_ivpsolvers/test_corrections.py index 598d31491..816e78340 100644 --- a/tests/test_ivpsolvers/test_corrections.py +++ b/tests/test_ivpsolvers/test_corrections.py @@ -57,7 +57,7 @@ def fixture_solution(correction_impl, fact): def test_terminal_value_simulation_matches_reference(solution): expected = reference_solution(solution.t) received = solution.u[0] - assert testing.tree_all_allclose(received, expected, rtol=1e-2) + assert testing.allclose(received, expected, rtol=1e-2) @functools.jit diff --git a/tests/test_ivpsolvers/test_cubature_equivalence.py b/tests/test_ivpsolvers/test_cubature_equivalence.py index c90212982..c05977775 100644 --- a/tests/test_ivpsolvers/test_cubature_equivalence.py +++ b/tests/test_ivpsolvers/test_cubature_equivalence.py @@ -1,7 +1,6 @@ """Test equivalences between cubature rules.""" from probdiffeq import ivpsolvers -from probdiffeq.backend import numpy as np from probdiffeq.backend import testing @@ -12,9 +11,9 @@ def test_third_order_spherical_vs_unscented_transform_scalar_input(): tos_points, tos_weights = tos.points, tos.weights_sqrtm ut_points, ut_weights = ut.points, ut.weights_sqrtm for x, y in [(ut_weights, tos_weights), (ut_points, tos_points)]: - assert np.allclose(x[:1], y[:1]) - assert np.allclose(x[1], 0.0) - assert np.allclose(x[2:], y[1:]) + assert testing.allclose(x[:1], y[:1]) + assert testing.allclose(x[1], 0.0) + assert testing.allclose(x[2:], y[1:]) @testing.parametrize("n", [4]) @@ -25,9 +24,9 @@ def test_third_order_spherical_vs_unscented_transform(n): tos_points, tos_weights = tos.points, tos.weights_sqrtm ut_points, ut_weights = ut.points, ut.weights_sqrtm for x, y in [(ut_weights, tos_weights), (ut_points, tos_points)]: - assert np.allclose(x[:n], y[:n]) - assert np.allclose(x[n], 0.0) - assert np.allclose(x[n + 1 :], y[n:]) + assert testing.allclose(x[:n], y[:n]) + assert testing.allclose(x[n], 0.0) + assert testing.allclose(x[n + 1 :], y[n:]) # todo: test for gauss-hermite? Do we need one? (we wrap scipy's rules anyway...) diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py index f9fc1b744..0ebffe9d6 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py @@ -44,7 +44,7 @@ def fixture_smoother_solution(solver_setup): def test_compare_filter_smoother_rmse(filter_solution, smoother_solution): - assert np.allclose(filter_solution.t, smoother_solution.t) # sanity check + assert testing.allclose(filter_solution.t, smoother_solution.t) # sanity check reference = _reference_solution(filter_solution.t) u_fi = functools.vmap(lambda s: tree_util.ravel_pytree(s)[0])(filter_solution.u[0]) @@ -58,7 +58,7 @@ def test_compare_filter_smoother_rmse(filter_solution, smoother_solution): # I would like to compare filter & smoother RMSE. but this test is too unreliable, # so we simply assert that both are "comparable". - assert np.allclose(filter_rmse, smoother_rmse, atol=0.0, rtol=1.0) + assert testing.allclose(filter_rmse, smoother_rmse, atol=0.0, rtol=1.0) # The error should be small, otherwise the test makes little sense assert filter_rmse < 0.01 diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py index d5b592b42..2914870bd 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py @@ -50,20 +50,20 @@ def test_fixedpoint_smoother_equivalent_same_grid(solver_setup, solution_smoothe ) sol_fp, sol_sm = solution_fixedpoint, solution_smoother # alias for brevity - assert testing.tree_all_allclose(sol_fp.t, sol_sm.t) - assert testing.tree_all_allclose(sol_fp.u, sol_sm.u) - assert testing.tree_all_allclose(sol_fp.u_std, sol_sm.u_std) - assert testing.tree_all_allclose(sol_fp.marginals, sol_sm.marginals) - assert testing.tree_all_allclose(sol_fp.output_scale, sol_sm.output_scale) - assert testing.tree_all_allclose(sol_fp.num_steps, sol_sm.num_steps) - assert testing.tree_all_allclose(sol_fp.posterior.init, sol_sm.posterior.init) + assert testing.allclose(sol_fp.t, sol_sm.t) + assert testing.allclose(sol_fp.u, sol_sm.u) + assert testing.allclose(sol_fp.u_std, sol_sm.u_std) + assert testing.allclose(sol_fp.marginals, sol_sm.marginals) + assert testing.allclose(sol_fp.output_scale, sol_sm.output_scale) + assert testing.allclose(sol_fp.num_steps, sol_sm.num_steps) + assert testing.allclose(sol_fp.posterior.init, sol_sm.posterior.init) # The backward conditionals use different parametrisations # but implement the same transitions cond_fp, cond_sm = sol_fp.posterior.conditional, sol_sm.posterior.conditional cond_fp = functools.vmap(sol_fp.ssm.conditional.preconditioner_apply)(cond_fp) cond_sm = functools.vmap(sol_sm.ssm.conditional.preconditioner_apply)(cond_sm) - assert testing.tree_all_allclose(cond_fp, cond_sm) + assert testing.allclose(cond_fp, cond_sm) def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_smoother): @@ -104,5 +104,6 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm # Compare QOI and marginals marginals_allclose_func = functools.partial(testing.marginals_allclose, ssm=ssm) marginals_allclose_func = functools.vmap(marginals_allclose_func) - assert testing.tree_all_allclose(u_fixedpoint, u_interp) + + assert testing.allclose(u_fixedpoint, u_interp) assert np.all(marginals_allclose_func(marginals_fixedpoint, marginals_interp)) diff --git a/tests/test_stats/test_offgrid_marginals.py b/tests/test_stats/test_offgrid_marginals.py index 24091161e..9ee7afbe4 100644 --- a/tests/test_stats/test_offgrid_marginals.py +++ b/tests/test_stats/test_offgrid_marginals.py @@ -25,11 +25,11 @@ def test_filter_marginals_close_only_to_left_boundary(fact): for u1, u2 in zip(u, sol.u): u1_ = tree_util.tree_map(lambda s: s[0], u1) u2_ = tree_util.tree_map(lambda s: s[-2], u2) - assert testing.tree_all_allclose(u1_, u2_, atol=1e-3, rtol=1e-3) + assert testing.allclose(u1_, u2_, atol=1e-3, rtol=1e-3) u1_ = tree_util.tree_map(lambda s: s[-1], u1) u2_ = tree_util.tree_map(lambda s: s[-1], u2) - assert not testing.tree_all_allclose(u1_, u2_, atol=1e-3, rtol=1e-3) + assert not testing.allclose(u1_, u2_, atol=1e-3, rtol=1e-3) @testing.parametrize("fact", ["isotropic", "dense", "blockdiag"]) @@ -53,8 +53,8 @@ def test_smoother_marginals_close_to_both_boundaries(fact): for u1, u2 in zip(u, sol.u): u1_ = tree_util.tree_map(lambda s: s[0], u1) u2_ = tree_util.tree_map(lambda s: s[-2], u2) - assert testing.tree_all_allclose(u1_, u2_, atol=1e-3, rtol=1e-3) + assert testing.allclose(u1_, u2_, atol=1e-3, rtol=1e-3) u1_ = tree_util.tree_map(lambda s: s[-1], u1) u2_ = tree_util.tree_map(lambda s: s[-1], u2) - assert testing.tree_all_allclose(u1_, u2_, atol=1e-3, rtol=1e-3) + assert testing.allclose(u1_, u2_, atol=1e-3, rtol=1e-3) diff --git a/tests/test_stats/test_sample.py b/tests/test_stats/test_sample.py index b22c20c92..d475036af 100644 --- a/tests/test_stats/test_sample.py +++ b/tests/test_stats/test_sample.py @@ -34,5 +34,5 @@ def test_sample_shape(approximation, shape): u_terminal_shape = tree_util.tree_map(lambda x: shape + x[-1].shape, u) u_inner_shape = tree_util.tree_map(lambda x: shape + x[:-1].shape, u) - assert testing.tree_all_allclose(i_shape, u_terminal_shape) - assert testing.tree_all_allclose(s_shape, u_inner_shape) + assert testing.allclose(i_shape, u_terminal_shape) + assert testing.allclose(s_shape, u_inner_shape) diff --git a/tests/test_taylor/data/generate_reference_solutions.py b/tests/test_taylor/data/generate_reference_solutions.py index 2fd2f0a60..bb874e507 100644 --- a/tests/test_taylor/data/generate_reference_solutions.py +++ b/tests/test_taylor/data/generate_reference_solutions.py @@ -5,18 +5,6 @@ from probdiffeq.backend import numpy as np -def set_environment(): - """Set the environment (e.g., 64-bit precision). - - The setup used to precompute references should match that of the other tests. - """ - # Test on CPU. - config.update("platform_name", "cpu") - - # Double precision - config.update("enable_x64", True) - - def three_body_first(num_derivatives_max=6): vf, (u0,), (t0, _) = ode.ivp_three_body_1st() vf = functools.partial(vf, t=t0) @@ -30,8 +18,8 @@ def van_der_pol_second(num_derivatives_max=6): if __name__ == "__main__": - # 64-bit precision and the like - set_environment() + # Double precision + config.update("enable_x64", True) solution1 = three_body_first() np.save("./tests/test_taylor/data/three_body_first_solution.npy", solution1) diff --git a/tests/test_taylor/test_affine_recursion.py b/tests/test_taylor/test_affine_recursion.py index 0255efcca..dfef9b195 100644 --- a/tests/test_taylor/test_affine_recursion.py +++ b/tests/test_taylor/test_affine_recursion.py @@ -12,7 +12,7 @@ def test_affine_recursion(num, num_derivatives_max): f, init, solution = _affine_problem(num_derivatives_max) derivatives = taylor.odejet_affine(f, init, num=num) assert len(derivatives) == num + 1 - assert testing.tree_all_allclose(derivatives, solution[: len(derivatives)]) + assert testing.allclose(derivatives, solution[: len(derivatives)]) def _affine_problem(n): diff --git a/tests/test_taylor/test_exact_first_order.py b/tests/test_taylor/test_exact_first_order.py index a70270729..c93818696 100644 --- a/tests/test_taylor/test_exact_first_order.py +++ b/tests/test_taylor/test_exact_first_order.py @@ -30,13 +30,13 @@ def fixture_pb_with_solution(): @testing.parametrize_with_cases("taylor_fun", cases=".", prefix="case_") -@testing.parametrize("num", [1, 6]) +@testing.parametrize("num", [1, 4]) def test_approximation_identical_to_reference(pb_with_solution, taylor_fun, num): (f, init), solution = pb_with_solution derivatives = taylor_fun(f, init, num=num) assert len(derivatives) == num + 1 - assert testing.tree_all_allclose(derivatives, list(solution[: len(derivatives)])) + assert testing.allclose(derivatives, list(solution[: len(derivatives)])) @testing.parametrize("num_doublings", [1, 2]) @@ -46,4 +46,4 @@ def test_approximation_identical_to_reference_doubling(pb_with_solution, num_dou derivatives = taylor.odejet_doubling_unroll(f, init, num_doublings=num_doublings) assert len(derivatives) == np.sum(2 ** np.arange(0, num_doublings + 1)) - assert testing.tree_all_allclose(derivatives, list(solution[: len(derivatives)])) + assert testing.allclose(derivatives, list(solution[: len(derivatives)])) diff --git a/tests/test_taylor/test_exact_higher_order.py b/tests/test_taylor/test_exact_higher_order.py index 3f466cc7e..6daa9f989 100644 --- a/tests/test_taylor/test_exact_higher_order.py +++ b/tests/test_taylor/test_exact_higher_order.py @@ -31,4 +31,4 @@ def test_approximation_identical_to_reference(pb_with_solution, taylor_fun, num) derivatives = taylor_fun(f, init, num=num) assert len(derivatives) == num + 2 - assert testing.tree_all_allclose(derivatives, list(solution[: len(derivatives)])) + assert testing.allclose(derivatives, list(solution[: len(derivatives)])) diff --git a/tests/test_taylor/test_inexact_first_order.py b/tests/test_taylor/test_inexact_first_order.py index b412abb5c..12cefd7d6 100644 --- a/tests/test_taylor/test_inexact_first_order.py +++ b/tests/test_taylor/test_inexact_first_order.py @@ -5,7 +5,7 @@ from probdiffeq.backend import numpy as np -@testing.parametrize("num", [0, 1, 4]) +@testing.parametrize("num", [0, 1, 3]) @testing.parametrize("fact", ["isotropic", "dense", "blockdiag"]) def test_initialised_correct_shape_and_values(num, fact): vf, (u0,), (t0, _) = ode.ivp_lotka_volterra() @@ -17,9 +17,7 @@ def test_initialised_correct_shape_and_values(num, fact): rk_starter = taylor.runge_kutta_starter(dt=0.01, prior=prior, ssm=ssm, num=num) derivatives = rk_starter(vf, (u0,), t=t0) assert len(derivatives) == 1 + num - assert testing.tree_all_allclose(derivatives[:1], solution[:1], rtol=1e-1) + assert testing.allclose(derivatives[:1], solution[:1], rtol=1e-1) if num > 1: - assert testing.tree_all_allclose( - derivatives[: num - 1], solution[: num - 1], rtol=1e-1 - ) + assert testing.allclose(derivatives[: num - 1], solution[: num - 1], rtol=1e-1) diff --git a/tests/test_util/test_cholesky_util.py b/tests/test_util/test_cholesky_util.py index f64483538..f70416247 100644 --- a/tests/test_util/test_cholesky_util.py +++ b/tests/test_util/test_cholesky_util.py @@ -3,9 +3,7 @@ These are so crucial and annoying to debug that they need their own test set. """ -from math import prod - -from probdiffeq.backend import functools, linalg, testing, tree_util +from probdiffeq.backend import functools, linalg, random, testing, tree_util from probdiffeq.backend import numpy as np from probdiffeq.util import cholesky_util @@ -29,19 +27,21 @@ def test_revert_conditional(HCshape, Cshape, Xshape): def cov(x): return x.T @ x - assert np.allclose(cov(extra), S) - assert np.allclose(g, K) - assert np.allclose(cov(bw_noise), C1) + assert testing.allclose(cov(extra), S) + assert testing.allclose(g, K) + assert testing.allclose(cov(bw_noise), C1) -@testing.parametrize("Cshape, HCshape", ([(3, 3), (2, 3)],)) -def test_revert_kernel_noisefree(Cshape, HCshape): +@testing.parametrize("Cshape, Hshape", ([(3, 3), (2, 3)],)) +def test_revert_kernel_noisefree(Cshape, Hshape): C = _some_array(Cshape) + 1.0 - HC = _some_array(HCshape) + 2.0 + H = _some_array(Hshape) + 2.0 + HC = H @ C S = HC @ HC.T K = C @ HC.T @ linalg.inv(S) - C1 = C @ C.T - K @ S @ K.T + + C1 = (np.eye(Cshape[0]) - K @ H) @ C @ C.T @ (np.eye(Cshape[0]) - K @ H).T extra, (bw_noise, g) = cholesky_util.revert_conditional_noisefree( R_X_F=HC.T, R_X=C.T @@ -50,13 +50,14 @@ def test_revert_kernel_noisefree(Cshape, HCshape): def cov(x): return x.T @ x - assert np.allclose(cov(extra), S) - assert np.allclose(g, K) - assert np.allclose(cov(bw_noise), C1) + assert testing.allclose(cov(extra), S) + assert testing.allclose(g, K) + assert testing.allclose(cov(bw_noise), C1) def _some_array(shape): - return np.arange(1.0, 1.0 + prod(shape)).reshape(shape) + key = random.prng_key(seed=1) + return random.normal(key, shape=shape) def test_sqrt_sum_square_scalar(): @@ -65,7 +66,7 @@ def test_sqrt_sum_square_scalar(): c = 5.0 expected = np.sqrt(a**2 + b**2 + c**2) received = cholesky_util.sqrt_sum_square_scalar(a, b, c) - assert np.allclose(expected, received) + assert testing.allclose(expected, received) def test_sqrt_sum_square_error(): @@ -121,7 +122,7 @@ def triu_via_qr_r(x, y, z): a, b, c = 3.0, 4.0, 5.0 expected = triu_via_naive_arithmetic_and_autograd(a, b, c) received = triu_via_qr_r(a, b, c) - assert np.allclose(expected, received) + assert testing.allclose(expected, received) def test_sqrt_sum_square_scalar_derivative_value_test_at_origin(): @@ -144,7 +145,7 @@ def triu_via_qr_r(x, y, z): a, b, c = 0.0, 0.0, 0.0 expected = triu_via_naive_arithmetic_and_autograd(a, b, c) received = triu_via_qr_r(a, b, c) - assert np.allclose(expected, received) + assert testing.allclose(expected, received) def _tree_is_free_of_nans(tree):