Skip to content

Commit 8c3fb96

Browse files
committed
Update the VdP benchmarK
1 parent 429a2fd commit 8c3fb96

1 file changed

Lines changed: 7 additions & 44 deletions

File tree

docs/examples_benchmarks/work-precision-vanderpol.py

Lines changed: 7 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import timeit
2424
from collections.abc import Callable
2525

26-
import diffrax
2726
import jax
2827
import jax.numpy as jnp
2928
import matplotlib.pyplot as plt
29+
import numba
3030
import numpy as np
3131
import scipy.integrate
3232
import tqdm
@@ -37,7 +37,7 @@
3737
jax.config.update("jax_debug_nans", True)
3838

3939

40-
def main(start=2.0, stop=8.0, step=1.0, repeats=2, use_diffrax: bool = False):
40+
def main(start=3.0, stop=8.0, step=1.0, repeats=2):
4141
"""Run the script."""
4242
# Set up all the configs
4343
jax.config.update("jax_enable_x64", True)
@@ -60,21 +60,15 @@ def main(start=2.0, stop=8.0, step=1.0, repeats=2, use_diffrax: bool = False):
6060

6161
# Assemble algorithms
6262
algorithms = {
63-
"SciPy: 'Radau'": solver_scipy(method="Radau"),
64-
"SciPy: 'LSODA'": solver_scipy(method="LSODA"),
63+
"SciPy + numba: 'Radau'": solver_scipy(method="Radau"),
64+
"SciPy + numba: 'LSODA'": solver_scipy(method="LSODA"),
6565
r"ProbDiffEq: TS1($3$)": solver_probdiffeq(num_derivatives=3),
6666
r"ProbDiffEq: TS1($4$)": solver_probdiffeq(num_derivatives=4),
6767
r"ProbDiffEq: TS1($5$)": solver_probdiffeq(num_derivatives=5),
6868
}
69-
if use_diffrax:
70-
# TODO: this is a temporary fix because Diffrax doesn't work with JAX >= 0.7.0
71-
# Revisit in the near future.
72-
algorithms["Diffrax: Kvaerno5()"] = solver_diffrax(solver=diffrax.Kvaerno5())
73-
else:
74-
print("\nSkipped Diffrax.\n")
7569

7670
# Compute a reference solution
77-
reference = solver_probdiffeq(num_derivatives=4)(1e-10)
71+
reference = solver_scipy(method="Radau")(0.1 * tolerances[-1])
7872
precision_fun = rmse_absolute(reference)
7973

8074
# Compute all work-precision diagrams
@@ -100,6 +94,7 @@ def main(start=2.0, stop=8.0, step=1.0, repeats=2, use_diffrax: bool = False):
10094
def solve_ivp_once():
10195
"""Compute plotting-values for the IVP."""
10296

97+
@numba.jit(nopython=True)
10398
def vf_scipy(_t, u):
10499
"""Van-der-Pol dynamics as a first-order differential equation."""
105100
return np.asarray([u[1], 1e5 * ((1.0 - u[0] ** 2) * u[1] - u[0])])
@@ -171,42 +166,10 @@ def param_to_solution(tol):
171166
return param_to_solution
172167

173168

174-
def solver_diffrax(*, solver) -> Callable:
175-
"""Construct a solver that wraps Diffrax' solution routines."""
176-
177-
@diffrax.ODETerm
178-
@jax.jit
179-
def vf_diffrax(_t, u, _args):
180-
"""Van-der-Pol dynamics as a first-order differential equation."""
181-
return jnp.asarray([u[1], 1e5 * ((1.0 - u[0] ** 2) * u[1] - u[0])])
182-
183-
t0, t1 = 0.0, 3.0
184-
u0 = jnp.concatenate((jnp.atleast_1d(2.0), jnp.atleast_1d(0.0)))
185-
t0, t1 = (0.0, 6.3)
186-
187-
@jax.jit
188-
def param_to_solution(tol):
189-
controller = diffrax.PIDController(atol=1e-3 * tol, rtol=tol)
190-
saveat = diffrax.SaveAt(t0=False, t1=True, ts=None)
191-
solution = diffrax.diffeqsolve(
192-
vf_diffrax,
193-
y0=u0,
194-
t0=t0,
195-
t1=t1,
196-
saveat=saveat,
197-
stepsize_controller=controller,
198-
dt0=None,
199-
max_steps=10_000,
200-
solver=solver,
201-
)
202-
return jax.block_until_ready(solution.ys[0, 0])
203-
204-
return param_to_solution
205-
206-
207169
def solver_scipy(method: str) -> Callable:
208170
"""Construct a solver that wraps SciPy's solution routines."""
209171

172+
@numba.jit(nopython=True)
210173
def vf_scipy(_t, u):
211174
"""Van-der-Pol dynamics as a first-order differential equation."""
212175
return np.asarray([u[1], 1e5 * ((1.0 - u[0] ** 2) * u[1] - u[0])])

0 commit comments

Comments
 (0)