99This is not good for extendability of the test suite.
1010"""
1111
12+ import jax
1213import jax .numpy as jnp
1314import jax .tree_util
1415import pytest
@@ -31,22 +32,46 @@ def fixture(name=None, scope="function"):
3132 return pytest_cases .fixture (name = name , scope = scope )
3233
3334
34- def tree_all_allclose (tree1 , tree2 , ** kwargs ):
35- trees_is_allclose = tree_allclose (tree1 , tree2 , ** kwargs )
35+ def allclose (tree1 , tree2 , / , * , atol : float | None = None , rtol : float | None = None ):
36+ """Check whether two pytrees are 'close' to each other.
37+
38+ In contrast to jax.numpy.allclose, this version:
39+ - Works with pytrees (by comparing the structure and leaves)
40+ - Uses different tolerances for single precision than for double precision.
41+ (It adjusts atol and rtol to the floating-point precision of the leaves.)
42+ """
43+ trees_is_allclose = _tree_allclose (tree1 , tree2 , atol = atol , rtol = rtol )
3644 return jax .tree_util .tree_all (trees_is_allclose )
3745
3846
39- def tree_allclose (tree1 , tree2 , ** kwargs ):
40- def allclose_partial (* args ):
41- return jnp . allclose ( * args , ** kwargs )
47+ def _tree_allclose (tree1 , tree2 , / , * , atol , rtol ):
48+ def allclose_partial (t1 , t2 , / ):
49+ return _allclose ( t1 , t2 , atol = atol , rtol = rtol )
4250
4351 return jax .tree_util .tree_map (allclose_partial , tree1 , tree2 )
4452
4553
46- def marginals_allclose (m1 , m2 , / , * , ssm ):
54+ def marginals_allclose (
55+ m1 , m2 , / , * , ssm , atol : float | None = None , rtol : float | None = None
56+ ):
4757 m1 , c1 = ssm .stats .to_multivariate_normal (m1 )
4858 m2 , c2 = ssm .stats .to_multivariate_normal (m2 )
4959
50- means_allclose = jnp . allclose (m1 , m2 )
51- covs_allclose = jnp . allclose (c1 , c2 )
60+ means_allclose = _allclose (m1 , m2 , atol = atol , rtol = rtol )
61+ covs_allclose = _allclose (c1 , c2 , atol = atol , rtol = rtol )
5262 return jnp .logical_and (means_allclose , covs_allclose )
63+
64+
65+ def _allclose (a , b , / , * , atol : float | None , rtol : float | None ):
66+ # promote to float-type to enable finfo.eps
67+ a = jnp .asarray (1.0 * a )
68+ b = jnp .asarray (1.0 * b )
69+
70+ # numpy.allclose uses defaults atol=1e-8 and rtol=1e-5;
71+ # we mirror this as atol=sqrt(tol) and rtol slightly larger.
72+ tol = jnp .sqrt (jnp .finfo (b .dtype ).eps )
73+ if atol is None :
74+ atol = tol
75+ if rtol is None :
76+ rtol = 10 * tol
77+ return jnp .allclose (a , b , atol = atol , rtol = rtol )
0 commit comments