1717import timeit
1818
1919import equinox as eqx
20+ import equinox .internal as eqxi
2021import jax
2122import jax .numpy as jnp
2223import jax .random as jr
2526
2627
2728sys .path .append ("../tests" )
28- from helpers import ( # pyright: ignore
29- construct_matrix ,
30- getkey ,
31- has_tag ,
32- shaped_allclose ,
33- )
29+ from helpers import construct_matrix , has_tag # pyright: ignore[reportMissingImports]
30+
31+
32+ getkey = eqxi .GetKey ()
33+
34+
35+ def tree_allclose (x , y , * , rtol = 1e-5 , atol = 1e-8 ):
36+ return eqx .tree_equal (x , y , typematch = True , rtol = rtol , atol = atol )
3437
3538
3639jax .config .update ("jax_enable_x64" , True )
@@ -171,7 +174,7 @@ def test_solvers(vmap_size, mat_size):
171174 batch_msg = f"batch of { vmap_size } problems"
172175
173176 lx_soln = bench_lx ()
174- if shaped_allclose (lx_soln , true_x , atol = 1e-4 , rtol = 1e-4 ):
177+ if tree_allclose (lx_soln , true_x , atol = 1e-4 , rtol = 1e-4 ):
175178 lx_solve_time = timeit .timeit (bench_lx , number = 1 )
176179
177180 print (
@@ -192,7 +195,7 @@ def test_solvers(vmap_size, mat_size):
192195 jax_solver = jax .jit (jax_solver )
193196 bench_jax = ft .partial (jax_solver , matrix , b )
194197 jax_soln = bench_jax ()
195- if shaped_allclose (jax_soln , true_x , atol = 1e-4 , rtol = 1e-4 ):
198+ if tree_allclose (jax_soln , true_x , atol = 1e-4 , rtol = 1e-4 ):
196199 jax_solve_time = timeit .timeit (bench_jax , number = 1 )
197200 print (
198201 f"JAX's { jax_name } solved { batch_msg } of "
0 commit comments