Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/examples_advanced/equinox_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.backend import control_flow

jax.config.update("jax_platform_name", "cpu")


# -

# Overwrite the while-loop (via a context manager):
Expand Down
5 changes: 0 additions & 5 deletions docs/examples_advanced/parameter_estimation_blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,6 @@
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor

# +
# x64 precision
jax.config.update("jax_enable_x64", True)

# CPU
jax.config.update("jax_platform_name", "cpu")

# IVP examples in JAX
if not backend.has_been_selected:
Expand Down
2 changes: 0 additions & 2 deletions docs/examples_advanced/parameter_estimation_optax.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
if not backend.has_been_selected:
backend.select("jax") # ivp examples in jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
# -


Expand Down
3 changes: 0 additions & 3 deletions docs/examples_basic/conditioning_on_zero_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
if not backend.has_been_selected:
backend.select("jax") # ivp examples in jax

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)


# +
# Create an ODE problem
Expand Down
1 change: 0 additions & 1 deletion docs/examples_basic/dynamic_output_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
backend.select("jax") # ivp examples in jax


jax.config.update("jax_platform_name", "cpu")
# -


Expand Down
1 change: 0 additions & 1 deletion docs/examples_basic/second_order_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
if not backend.has_been_selected:
backend.select("jax") # ivp examples in jax

jax.config.update("jax_platform_name", "cpu")
# -

# Quick refresher: first-order ODEs
Expand Down
1 change: 0 additions & 1 deletion docs/examples_basic/taylor_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
if not backend.has_been_selected:
backend.select("jax") # ivp examples in jax

jax.config.update("jax_platform_name", "cpu")
# -

# We start by defining an ODE.
Expand Down
8 changes: 2 additions & 6 deletions probdiffeq/backend/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def flip(arr, /, axis=None):
return jnp.flip(arr, axis=axis)


def asarray(x, /):
return jnp.asarray(x)
def asarray(x, /, dtype=None):
return jnp.asarray(x, dtype=dtype)


def squeeze(arr, /):
Expand Down Expand Up @@ -140,10 +140,6 @@ def load(path, /):
return jnp.load(path, allow_pickle=True)


def allclose(a, b, *, atol=1e-8, rtol=1e-5):
return jnp.allclose(a, b, atol=atol, rtol=rtol)


def stack(list_of_arrays, /):
return jnp.stack(list_of_arrays)

Expand Down
41 changes: 33 additions & 8 deletions probdiffeq/backend/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
This is not good for extendability of the test suite.
"""

import jax
import jax.numpy as jnp
import jax.tree_util
import pytest
Expand All @@ -31,22 +32,46 @@ def fixture(name=None, scope="function"):
return pytest_cases.fixture(name=name, scope=scope)


def tree_all_allclose(tree1, tree2, **kwargs):
trees_is_allclose = tree_allclose(tree1, tree2, **kwargs)
def allclose(tree1, tree2, /, *, atol: float | None = None, rtol: float | None = None):
"""Check whether two pytrees are 'close' to each other.

In contrast to jax.numpy.allclose, this version:
- Works with pytrees (by comparing the structure and leaves)
- Uses different tolerances for single precision than for double precision.
(It adjusts atol and rtol to the floating-point precision of the leaves.)
"""
trees_is_allclose = _tree_allclose(tree1, tree2, atol=atol, rtol=rtol)
return jax.tree_util.tree_all(trees_is_allclose)


def tree_allclose(tree1, tree2, **kwargs):
def allclose_partial(*args):
return jnp.allclose(*args, **kwargs)
def _tree_allclose(tree1, tree2, /, *, atol, rtol):
def allclose_partial(t1, t2, /):
return _allclose(t1, t2, atol=atol, rtol=rtol)

return jax.tree_util.tree_map(allclose_partial, tree1, tree2)


def marginals_allclose(m1, m2, /, *, ssm):
def marginals_allclose(
m1, m2, /, *, ssm, atol: float | None = None, rtol: float | None = None
):
m1, c1 = ssm.stats.to_multivariate_normal(m1)
m2, c2 = ssm.stats.to_multivariate_normal(m2)

means_allclose = jnp.allclose(m1, m2)
covs_allclose = jnp.allclose(c1, c2)
means_allclose = _allclose(m1, m2, atol=atol, rtol=rtol)
covs_allclose = _allclose(c1, c2, atol=atol, rtol=rtol)
return jnp.logical_and(means_allclose, covs_allclose)


def _allclose(a, b, /, *, atol: float | None, rtol: float | None):
# promote to float-type to enable finfo.eps
a = jnp.asarray(1.0 * a)
b = jnp.asarray(1.0 * b)

# numpy.allclose uses defaults atol=1e-8 and rtol=1e-5;
# we mirror this as atol=sqrt(tol) and rtol slightly larger.
tol = jnp.sqrt(jnp.finfo(b.dtype).eps)
if atol is None:
atol = tol
if rtol is None:
rtol = 10 * tol
return jnp.allclose(a, b, atol=atol, rtol=rtol)
26 changes: 15 additions & 11 deletions probdiffeq/impl/_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ def mahalanobis_norm_relative(self, u, /, rv):
return np.reshape(np.abs(mahalanobis) / np.sqrt(rv.mean.size), ())

def logpdf(self, u, /, rv):
# The cholesky factor is triangular, so we compute a cheap slogdet.
diagonal = linalg.diagonal_along_axis(rv.cholesky, axis1=-1, axis2=-2)
cholesky = linalg.qr_r(rv.cholesky.T).T
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
slogdet = np.sum(np.log(np.abs(diagonal)))

dx = u - rv.mean
residual_white = linalg.solve_triangular(rv.cholesky.T, dx, trans="T")
x1 = linalg.vector_dot(residual_white, residual_white)
x2 = 2.0 * slogdet
x3 = u.size * np.log(np.pi() * 2)
return -0.5 * (x1 + x2 + x3)
residual_white = linalg.solve_triangular(cholesky, dx, lower=True, trans=0)
sqrnorm = linalg.vector_dot(residual_white, residual_white)

const = np.log(np.pi() * 2)
return -1 / 2 * sqrnorm - u.size / 2 * const - slogdet

def mean(self, rv):
return rv.mean
Expand Down Expand Up @@ -128,12 +128,14 @@ def logpdf(self, u, /, rv):
u = u[None, :]

def logpdf_scalar(x, r):
cholesky = linalg.qr_r(r.cholesky.T).T

dx = x - r.mean
w = linalg.solve_triangular(r.cholesky.T, dx, trans="T")
w = linalg.solve_triangular(cholesky.T, dx, trans="T")

maha_term = linalg.vector_dot(w, w)

diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2)
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
slogdet = np.sum(np.log(np.abs(diagonal)))
logdet_term = 2.0 * slogdet
return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2))
Expand Down Expand Up @@ -195,12 +197,14 @@ def mahalanobis_norm_relative(self, u, /, rv):

def logpdf(self, u, /, rv):
def logpdf_scalar(x, r):
cholesky = linalg.qr_r(r.cholesky.T).T

dx = x - r.mean
w = linalg.solve_triangular(r.cholesky.T, dx, trans="T")
w = linalg.solve_triangular(cholesky.T, dx, trans="T")

maha_term = linalg.vector_dot(w, w)

diagonal = linalg.diagonal_along_axis(r.cholesky, axis1=-1, axis2=-2)
diagonal = linalg.diagonal_along_axis(cholesky, axis1=-1, axis2=-2)
slogdet = np.sum(np.log(np.abs(diagonal)))
logdet_term = 2.0 * slogdet
return -0.5 * (logdet_term + maha_term + x.size * np.log(np.pi() * 2))
Expand Down
13 changes: 0 additions & 13 deletions tests/conftest.py

This file was deleted.

20 changes: 10 additions & 10 deletions tests/test_backend/test_overwrite_control_flow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test that the control_flow can be updated by a user."""

from probdiffeq.backend import control_flow
from probdiffeq.backend import control_flow, testing
from probdiffeq.backend import numpy as np


Expand All @@ -10,12 +10,12 @@ def cumsum_step(carry, x):
return res, res

xs = np.arange(1.0, 11.0, step=2.0)
sum_total = 25
sum_total = 25.0
cumsum_total = np.asarray([1.0, 4.0, 9.0, 16.0, 25])

final, outputs = control_flow.scan(cumsum_step, init=0.0, xs=xs)
assert np.allclose(final, sum_total)
assert np.allclose(outputs, cumsum_total)
assert testing.allclose(final, sum_total)
assert testing.allclose(outputs, cumsum_total)

# Direct import;
# Do not use probdiffeq.backend since otherwise we recurse
Expand All @@ -26,17 +26,17 @@ def scan_that_adds_1(step, init, xs, reverse, length):

with control_flow.context_overwrite_scan(scan_that_adds_1):
final, outputs = control_flow.scan(cumsum_step, init=0.0, xs=xs)
assert np.allclose(final, sum_total + 1.0)
assert np.allclose(outputs, cumsum_total + 1.0)
assert testing.allclose(final, sum_total + 1.0)
assert testing.allclose(outputs, cumsum_total + 1.0)


def test_overwrite_while_loop_func():
def counter_step(x):
return x[0] + 1, x[1]

index, value = control_flow.while_loop(lambda s: s[0] < 10, counter_step, (0, 0.0))
assert np.allclose(index, 10)
assert np.allclose(value, 0.0)
assert testing.allclose(index, 10)
assert testing.allclose(value, 0.0)

# Direct import;
# Do not use probdiffeq.backend since otherwise we recurse
Expand All @@ -53,5 +53,5 @@ def while_loop_that_adds_1(cond_fun, body_fun, init_val):
index, value = control_flow.while_loop(
lambda s: s[0] < 10, counter_step, (0, 0.0)
)
assert np.allclose(index, 10)
assert np.allclose(value, 1.0) # instead of 0.
assert testing.allclose(index, 10)
assert testing.allclose(value, 1.0) # instead of 0.
12 changes: 9 additions & 3 deletions tests/test_impl/test_logpdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@
Necessary because the implementation has been faulty in the past. Never again.
"""

from probdiffeq.backend import functools, stats, testing
from probdiffeq.backend import functools, random, stats, testing, tree_util
from probdiffeq.backend import numpy as np
from probdiffeq.impl import impl


@testing.parametrize("fact", ["dense", "isotropic", "blockdiag"])
def test_logpdf(fact):
rv, ssm = random_variable(fact=fact)

u = np.ones_like(ssm.stats.mean(rv))

(mean_dense, cov_dense) = ssm.stats.to_multivariate_normal(rv)
u_dense = np.ones_like(mean_dense)

pdf1 = ssm.stats.logpdf(u, rv)
pdf2 = stats.multivariate_normal_logpdf(u_dense, mean_dense, cov_dense)
assert np.allclose(pdf1, pdf2)
assert testing.allclose(pdf1, pdf2)


@testing.parametrize("fact", ["dense", "isotropic", "blockdiag"])
Expand All @@ -37,4 +38,9 @@ def random_variable(fact):
output_scale = np.ones_like(ssm.prototypes.output_scale())
discretize = ssm.conditional.ibm_transitions(output_scale)
rv = discretize(0.1, output_scale)
return rv.noise, ssm

key = random.prng_key(seed=1)
noise_flat, unravel = tree_util.ravel_pytree(rv.noise)
noise_flat = random.normal(key, shape=noise_flat.shape)
noise = unravel(noise_flat)
return noise, ssm
2 changes: 1 addition & 1 deletion tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Taylor(containers.NamedTuple):
solution_fixed = ivpsolve.solve_fixed_grid(
init, grid=grid_adaptive, solver=solver, ssm=ssm
)
assert testing.tree_all_allclose(solution_adaptive, solution_fixed)
assert testing.allclose(solution_adaptive, solution_fixed)

# Assert u and u_std have matching shapes (that was wrong before)
u_shape = tree_util.tree_map(np.shape, solution_fixed.u)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ivpsolve/test_save_at_vs_save_every_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(fact):
# Assert similarity

for ui, us in zip(u_interp, u_save_at):
assert testing.tree_all_allclose(ui, us)
assert testing.allclose(ui, us)

marginals_allclose_func = functools.partial(testing.marginals_allclose, ssm=ssm)
marginals_allclose_func = functools.vmap(marginals_allclose_func)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ivpsolve/test_save_every_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_python_loop_output_matches_reference(fact, strategy):

received = python_loop_solution(ivp, fact=fact, strategy_fun=strategy)
expected = reference_solution(ivp, received.t)
assert testing.tree_all_allclose(received.u[0], expected, rtol=1e-2)
assert testing.allclose(received.u[0], expected, rtol=1e-2)

# Assert u and u_std have matching shapes (that was wrong before)
u_shape = tree_util.tree_map(np.shape, received.u)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_terminal_values_identical(fact):
received = ivpsolve.solve_adaptive_terminal_values(
init, t0=t0, t1=t1, adaptive_solver=asolver, dt0=0.1, ssm=ssm
)
assert testing.tree_all_allclose(received, expected)
assert testing.allclose(received, expected)

# Assert u and u_std have matching shapes (that was wrong before)
u_shape = tree_util.tree_map(np.shape, received.u)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_ivpsolvers/test_calibration_mle_vs_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ def test_calibration_changes_the_posterior(uncalibrated_and_mle_solution):
output_scale_mle = mle_solution.output_scale

# Without a call to calibrate(), the posteriors are the same.
assert testing.tree_all_allclose(posterior_uncalibrated, posterior_mle)
assert not np.allclose(output_scale_uncalibrated, output_scale_mle)
assert testing.allclose(posterior_uncalibrated, posterior_mle)
assert not testing.allclose(output_scale_uncalibrated, output_scale_mle)

# With a call to calibrate(), the posteriors are different.
posterior_calibrated = stats.calibrate(posterior_mle, output_scale_mle, ssm=ssm)
assert not testing.tree_all_allclose(posterior_uncalibrated, posterior_calibrated)
assert not testing.allclose(posterior_uncalibrated, posterior_calibrated)
3 changes: 1 addition & 2 deletions tests/test_ivpsolvers/test_controllers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test the controllers."""

from probdiffeq import ivpsolvers
from probdiffeq.backend import numpy as np
from probdiffeq.backend import testing


Expand All @@ -23,4 +22,4 @@ def test_equivalence_pi_vs_i(dt, error_power, num_applies):
dt_i = dt
for _ in range(num_applies):
dt_i, x_i = ctrl_i.apply(dt_i, x_i, error_power=error_power)
assert np.allclose(dt_i, dt_pi)
assert testing.allclose(dt_i, dt_pi)
2 changes: 1 addition & 1 deletion tests/test_ivpsolvers/test_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def fixture_solution(correction_impl, fact):
def test_terminal_value_simulation_matches_reference(solution):
expected = reference_solution(solution.t)
received = solution.u[0]
assert testing.tree_all_allclose(received, expected, rtol=1e-2)
assert testing.allclose(received, expected, rtol=1e-2)


@functools.jit
Expand Down
Loading