Skip to content

Commit edc374f

Browse files
Update benchmarks - fix #146
1 parent 2a18660 commit edc374f

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

benchmarks/gmres_fails_safely.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,21 @@
1313
# limitations under the License.
1414

1515
import functools as ft
16-
import sys
1716

17+
import equinox as eqx
18+
import equinox.internal as eqxi
1819
import jax
1920
import jax.numpy as jnp
2021
import jax.random as jr
2122
import jax.scipy as jsp
2223
import lineax as lx
2324

2425

25-
sys.path.append("../tests")
26-
from helpers import getkey, shaped_allclose # pyright: ignore
26+
getkey = eqxi.GetKey()
27+
28+
29+
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
30+
return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)
2731

2832

2933
jax.config.update("jax_enable_x64", True)
@@ -48,7 +52,7 @@ def benchmark_jax(mat_size: int, *, key):
4852

4953
# info == 0.0 implies that the solve has succeeded.
5054
returned_failed = jnp.all(info != 0.0)
51-
actually_failed = not shaped_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4)
55+
actually_failed = not tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4)
5256

5357
assert actually_failed
5458

@@ -62,7 +66,7 @@ def benchmark_lx(mat_size: int, *, key):
6266
lx_soln = lx.linear_solve(op, b, lx.GMRES(atol=1e-5, rtol=1e-5), throw=False)
6367

6468
returned_failed = jnp.all(lx_soln.result != lx.RESULTS.successful)
65-
actually_failed = not shaped_allclose(lx_soln.value, true_x, atol=1e-4, rtol=1e-4)
69+
actually_failed = not tree_allclose(lx_soln.value, true_x, atol=1e-4, rtol=1e-4)
6670

6771
assert actually_failed
6872

benchmarks/solver_speeds.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import timeit
1818

1919
import equinox as eqx
20+
import equinox.internal as eqxi
2021
import jax
2122
import jax.numpy as jnp
2223
import jax.random as jr
@@ -25,12 +26,14 @@
2526

2627

2728
sys.path.append("../tests")
28-
from helpers import ( # pyright: ignore
29-
construct_matrix,
30-
getkey,
31-
has_tag,
32-
shaped_allclose,
33-
)
29+
from helpers import construct_matrix, has_tag # pyright: ignore[reportMissingImports]
30+
31+
32+
getkey = eqxi.GetKey()
33+
34+
35+
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8):
36+
return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)
3437

3538

3639
jax.config.update("jax_enable_x64", True)
@@ -171,7 +174,7 @@ def test_solvers(vmap_size, mat_size):
171174
batch_msg = f"batch of {vmap_size} problems"
172175

173176
lx_soln = bench_lx()
174-
if shaped_allclose(lx_soln, true_x, atol=1e-4, rtol=1e-4):
177+
if tree_allclose(lx_soln, true_x, atol=1e-4, rtol=1e-4):
175178
lx_solve_time = timeit.timeit(bench_lx, number=1)
176179

177180
print(
@@ -192,7 +195,7 @@ def test_solvers(vmap_size, mat_size):
192195
jax_solver = jax.jit(jax_solver)
193196
bench_jax = ft.partial(jax_solver, matrix, b)
194197
jax_soln = bench_jax()
195-
if shaped_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4):
198+
if tree_allclose(jax_soln, true_x, atol=1e-4, rtol=1e-4):
196199
jax_solve_time = timeit.timeit(bench_jax, number=1)
197200
print(
198201
f"JAX's {jax_name} solved {batch_msg} of "

0 commit comments

Comments
 (0)