2323import timeit
2424from collections .abc import Callable
2525
26- import diffrax
2726import jax
2827import jax .numpy as jnp
2928import matplotlib .pyplot as plt
29+ import numba
3030import numpy as np
3131import scipy .integrate
3232import tqdm
3737jax .config .update ("jax_debug_nans" , True )
3838
3939
40- def main (start = 2 .0 , stop = 8.0 , step = 1.0 , repeats = 2 , use_diffrax : bool = False ):
40+ def main (start = 3 .0 , stop = 8.0 , step = 1.0 , repeats = 2 ):
4141 """Run the script."""
4242 # Set up all the configs
4343 jax .config .update ("jax_enable_x64" , True )
@@ -60,21 +60,15 @@ def main(start=2.0, stop=8.0, step=1.0, repeats=2, use_diffrax: bool = False):
6060
6161 # Assemble algorithms
6262 algorithms = {
63- "SciPy: 'Radau'" : solver_scipy (method = "Radau" ),
64- "SciPy: 'LSODA'" : solver_scipy (method = "LSODA" ),
63+ "SciPy + numba : 'Radau'" : solver_scipy (method = "Radau" ),
64+ "SciPy + numba : 'LSODA'" : solver_scipy (method = "LSODA" ),
6565 r"ProbDiffEq: TS1($3$)" : solver_probdiffeq (num_derivatives = 3 ),
6666 r"ProbDiffEq: TS1($4$)" : solver_probdiffeq (num_derivatives = 4 ),
6767 r"ProbDiffEq: TS1($5$)" : solver_probdiffeq (num_derivatives = 5 ),
6868 }
69- if use_diffrax :
70- # TODO: this is a temporary fix because Diffrax doesn't work with JAX >= 0.7.0
71- # Revisit in the near future.
72- algorithms ["Diffrax: Kvaerno5()" ] = solver_diffrax (solver = diffrax .Kvaerno5 ())
73- else :
74- print ("\n Skipped Diffrax.\n " )
7569
7670 # Compute a reference solution
77- reference = solver_probdiffeq ( num_derivatives = 4 )( 1e-10 )
71+ reference = solver_scipy ( method = "Radau" )( 0.1 * tolerances [ - 1 ] )
7872 precision_fun = rmse_absolute (reference )
7973
8074 # Compute all work-precision diagrams
@@ -100,6 +94,7 @@ def main(start=2.0, stop=8.0, step=1.0, repeats=2, use_diffrax: bool = False):
10094def solve_ivp_once ():
10195 """Compute plotting-values for the IVP."""
10296
97+ @numba .jit (nopython = True )
10398 def vf_scipy (_t , u ):
10499 """Van-der-Pol dynamics as a first-order differential equation."""
105100 return np .asarray ([u [1 ], 1e5 * ((1.0 - u [0 ] ** 2 ) * u [1 ] - u [0 ])])
@@ -171,42 +166,10 @@ def param_to_solution(tol):
171166 return param_to_solution
172167
173168
174- def solver_diffrax (* , solver ) -> Callable :
175- """Construct a solver that wraps Diffrax' solution routines."""
176-
177- @diffrax .ODETerm
178- @jax .jit
179- def vf_diffrax (_t , u , _args ):
180- """Van-der-Pol dynamics as a first-order differential equation."""
181- return jnp .asarray ([u [1 ], 1e5 * ((1.0 - u [0 ] ** 2 ) * u [1 ] - u [0 ])])
182-
183- t0 , t1 = 0.0 , 3.0
184- u0 = jnp .concatenate ((jnp .atleast_1d (2.0 ), jnp .atleast_1d (0.0 )))
185- t0 , t1 = (0.0 , 6.3 )
186-
187- @jax .jit
188- def param_to_solution (tol ):
189- controller = diffrax .PIDController (atol = 1e-3 * tol , rtol = tol )
190- saveat = diffrax .SaveAt (t0 = False , t1 = True , ts = None )
191- solution = diffrax .diffeqsolve (
192- vf_diffrax ,
193- y0 = u0 ,
194- t0 = t0 ,
195- t1 = t1 ,
196- saveat = saveat ,
197- stepsize_controller = controller ,
198- dt0 = None ,
199- max_steps = 10_000 ,
200- solver = solver ,
201- )
202- return jax .block_until_ready (solution .ys [0 , 0 ])
203-
204- return param_to_solution
205-
206-
207169def solver_scipy (method : str ) -> Callable :
208170 """Construct a solver that wraps SciPy's solution routines."""
209171
172+ @numba .jit (nopython = True )
210173 def vf_scipy (_t , u ):
211174 """Van-der-Pol dynamics as a first-order differential equation."""
212175 return np .asarray ([u [1 ], 1e5 * ((1.0 - u [0 ] ** 2 ) * u [1 ] - u [0 ])])
0 commit comments