A solver-agnostic JAX implementation of the parareal algorithm for time-parallel ODE solving.
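Parareal is a parallel-in-time algorithm that combines a cheap coarse propagator $G$ with an accurate but expensive fine propagator $F$. Both map a state at $t_{n-1}$ to a state at $t_n$. Given time points $t_0 < t_1 < \dots < t_N$ and $y_0^k = y_0$ for all $k$, it proceeds as follows:

1. Initialize with a sequential coarse sweep, $y_n^0 = G(y_{n-1}^0, t_{n-1}, t_n)$, to get $y_1^0, y_2^0, \dots, y_N^0$.
2. Iterate with the correction formula:
   $$y_n^k = G(y_{n-1}^k, t_{n-1}, t_n) + F(y_{n-1}^{k-1}, t_{n-1}, t_n) - G(y_{n-1}^{k-1}, t_{n-1}, t_n)$$
3. Stop when the solution stabilizes or the maximum number of iterations is reached.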
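A short induction explains why the iteration terminates: after $k$ iterations, the first $k$ values are exact. Write $u_n$ for the sequential fine solution, with $u_0 = y_0$ and $u_n = F(u_{n-1}, t_{n-1}, t_n)$. Suppose two things hold:

- $y_m^{k-1} = u_m$ for all $m \le k-1$;
- $y_m^{k} = u_m$ for all $m < n$, where $n \le k$.

Then both arguments in the correction formula equal $u_{n-1}$, the two $G$ terms cancel, and

$$y_n^k = G(u_{n-1}, t_{n-1}, t_n) + F(u_{n-1}, t_{n-1}, t_n) - G(u_{n-1}, t_{n-1}, t_n) = u_n.$$

Hence $y_n^k = u_n$ for all $n \le k$. In exact arithmetic, parareal therefore reproduces the fine solution after at most $N$ iterations. A speedup requires convergence in far fewer iterations than that.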
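Within iteration $k$, the fine solves $F(y_{n-1}^{k-1}, t_{n-1}, t_n)$ in step 2 depend only on the previous iterate. All $N$ of them can therefore run in parallel, and only the cheap coarse sweep remains sequential. Once the iteration has converged, the result matches the sequential fine solution up to the tolerance. This can give speedups without sacrificing the fine solver's accuracy.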
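The steps above map naturally onto JAX primitives. What follows is a minimal sketch under stated assumptions, not the parareax implementation:

- `parareal_sketch`, `coarse_euler` and `fine_euler` are hypothetical names.
- The loop runs a fixed number of iterations instead of checking a tolerance.
- The $F$ and $G$ evaluations on the previous iterate are batched with `jax.vmap`, while the sequential coarse sweep uses `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp


def parareal_sketch(coarse_step, fine_step, y0, ts, num_iters):
    """Illustrative parareal loop (hypothetical helper, not the parareax code).

    coarse_step, fine_step: (y_start, t_start, t_end) -> y_end
    Returns the iterate after `num_iters` corrections, shape (len(ts), *y0.shape).
    """
    t_starts, t_ends = ts[:-1], ts[1:]

    def coarse_sweep(corrections):
        # Sequential part: y_n = G(y_{n-1}, t_{n-1}, t_n) + correction_n
        def body(y_prev, inputs):
            t_a, t_b, corr = inputs
            y_next = coarse_step(y_prev, t_a, t_b) + corr
            return y_next, y_next

        _, ys = jax.lax.scan(body, y0, (t_starts, t_ends, corrections))
        return jnp.concatenate([y0[None], ys])

    # Initialization (k = 0): a pure coarse sweep, i.e. all corrections are zero
    ys = coarse_sweep(jnp.zeros((len(t_starts),) + y0.shape, dtype=y0.dtype))
    for _ in range(num_iters):
        # F and G on the previous iterate are independent across intervals,
        # so vmap evaluates all of them as one batched call
        fine = jax.vmap(fine_step)(ys[:-1], t_starts, t_ends)
        coarse = jax.vmap(coarse_step)(ys[:-1], t_starts, t_ends)
        ys = coarse_sweep(fine - coarse)
    return ys


def euler(y, t_start, t_end, substeps):
    # `substeps` explicit Euler steps for the test problem dy/dt = -y
    h = (t_end - t_start) / substeps
    return jax.lax.fori_loop(0, substeps, lambda i, y: y - h * y, y)


def coarse_euler(y, t_start, t_end):
    return euler(y, t_start, t_end, 1)


def fine_euler(y, t_start, t_end):
    return euler(y, t_start, t_end, 100)


ts = jnp.linspace(0.0, 1.0, 11)  # 10 intervals
y0 = jnp.array([1.0])
ys_par = parareal_sketch(coarse_euler, fine_euler, y0, ts, num_iters=len(ts) - 1)

# Sequential fine solve for comparison
ys_seq = [y0]
for t_a, t_b in zip(ts[:-1], ts[1:]):
    ys_seq.append(fine_euler(ys_seq[-1], t_a, t_b))
ys_seq = jnp.stack(ys_seq)
```

Running `len(ts) - 1` iterations reproduces the sequential fine solution up to rounding, since after $k$ iterations the first $k$ values are exact. Fewer iterations give a cheaper approximation.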
## Installation

The package is not yet available on PyPI. To install it from source:

```bash
git clone https://github.com/nathanaelbosch/parareax.git
cd parareax
pip install -e .
```

## Quick Start

```python
import jax
jax.config.update("jax_enable_x64", True)  # float64, needed for the tight tolerances below
import jax.numpy as jnp
from jax.experimental.ode import odeint
from parareax import run_parareal

# Define the differential equation: logistic equation dy/dt = y(1 - y)
def f(y, t):
    return y * (1 - y)

t0, t1 = 0, 1
y0 = jnp.array([1e-2])

# Coarse step: a single explicit Euler step
@jax.jit
def coarse_step(y0, t0, t1):
    dy = f(y0, t0)
    return y0 + dy * (t1 - t0)

# Fine step: JAX's adaptive odeint solver
@jax.jit
def fine_step(y0, t0, t1):
    return odeint(f, y0, jnp.array([t0, t1]))[-1]

# Solve using Parareal
ts = jnp.linspace(t0, t1, 101)  # 101 time points, i.e. 100 intervals
solution, info = run_parareal(coarse_step, fine_step, y0=y0, ts=ts, tol=1e-14)
print("Parareal iterations:", info["iterations"])

# Sequential odeint solution on the same grid, for comparison
sol_fine = odeint(f, y0, ts, rtol=1e-10, atol=1e-13)

# High-accuracy reference solution (tighter tolerances)
ts_ref = jnp.linspace(t0, t1, 1001)
ref = odeint(f, y0, ts_ref, rtol=1e-12, atol=1e-15)

# Compare errors at the final time t1
print("Fine solver error:", sol_fine[-1] - ref[-1])
print("Parareal error:", solution[-1] - ref[-1])
```

Output:

```
Parareal iterations: 6
Fine solver error: [1.57224928e-12]
Parareal error: [5.78604178e-12]
```
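Parareal converged in 6 iterations. Each iteration consists of 100 fine-solver calls, which are independent of one another and can run in parallel, followed by a cheap sequential coarse sweep. The expensive work therefore takes 6 sequential rounds instead of the 100 sequential fine-solver calls of standard time stepping. For problems with an expensive fine solver, and with enough parallel hardware, this structure can yield wall-clock speedups.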
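The sequential-versus-parallel comparison can be made concrete with a common back-of-the-envelope cost model. This is an idealization, not a measurement of parareax. It assumes:

- one processor per interval and negligible communication;
- a cost $c_F$ per fine step and $c_G$ per coarse step;
- $N$ intervals and $K$ iterations.

Under these assumptions, the parallel time is roughly $(K+1) N c_G + K c_F$, against $N c_F$ for a sequential fine solve. The helper `parareal_speedup` is hypothetical.

```python
def parareal_speedup(num_intervals, num_iters, cost_fine, cost_coarse):
    """Idealized speedup of parareal over a sequential fine solve.

    Assumes one processor per interval and free communication. Parallel time
    counts the initial coarse sweep plus, per iteration, one parallel round
    of fine solves (cost_fine) and one sequential coarse sweep.
    """
    sequential_time = num_intervals * cost_fine
    parallel_time = (num_iters + 1) * num_intervals * cost_coarse + num_iters * cost_fine
    return sequential_time / parallel_time


# Quick-start numbers (100 intervals, 6 iterations) with a hypothetical
# coarse step that is 1000x cheaper than the fine step
speedup = parareal_speedup(100, 6, cost_fine=1.0, cost_coarse=1e-3)
print(f"{speedup:.1f}")  # 14.9, below the N/K = 100/6 ≈ 16.7 upper bound
```

Because the coarse sweep is sequential, its cost grows with $N$. A coarse step that is too expensive erodes the speedup, and so does one that is too inaccurate, because it needs more iterations.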
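The full signature of `run_parareal`, with its default arguments:

```python
run_parareal(
    coarse_step,    # Function: (y_start, t_start, t_end) -> y_end
    fine_step,      # Function: (y_start, t_start, t_end) -> y_end
    y0,             # Initial state array
    ts,             # Time points array
    maxiters=1000,  # Maximum number of iterations
    tol=1e-9,       # Convergence tolerance
)
```

Returns a tuple `(ys, info)`:

- `ys`: solution array of shape `(len(ts), len(y0))`
- `info`: dictionary with convergence information (e.g. `info["iterations"]`, the number of iterations performed)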
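Any function with the `(y_start, t_start, t_end) -> y_end` signature can serve as `coarse_step` or `fine_step`. As an illustration, the sketch below builds step functions from a fixed number of classical RK4 substeps using `jax.lax.fori_loop`. This is an alternative to the adaptive `odeint` used in the quick start.

- `make_rk4_step` is a hypothetical helper, not part of parareax.
- It is assumed, not confirmed here, that parareax places no further requirements on step functions.

```python
import jax
import jax.numpy as jnp


def f(y, t):
    # Logistic equation from the quick start: dy/dt = y (1 - y)
    return y * (1 - y)


def make_rk4_step(vector_field, num_substeps):
    """Hypothetical helper: build a step (y_start, t_start, t_end) -> y_end
    from `num_substeps` equally sized classical RK4 steps."""

    @jax.jit
    def step(y_start, t_start, t_end):
        h = (t_end - t_start) / num_substeps

        def rk4_substep(i, y):
            t = t_start + i * h
            k1 = vector_field(y, t)
            k2 = vector_field(y + 0.5 * h * k1, t + 0.5 * h)
            k3 = vector_field(y + 0.5 * h * k2, t + 0.5 * h)
            k4 = vector_field(y + h * k3, t + h)
            return y + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

        return jax.lax.fori_loop(0, num_substeps, rk4_substep, y_start)

    return step


# A fixed-cost fine step (50 RK4 substeps) and a one-substep coarse step
fine_step = make_rk4_step(f, num_substeps=50)
coarse_step = make_rk4_step(f, num_substeps=1)

y0 = jnp.array([1e-2])
y1 = fine_step(y0, 0.0, 1.0)
# Closed-form logistic solution at t = 1 for comparison
exact = 1 / (1 + (1 / y0 - 1) * jnp.exp(-1.0))
```

The resulting `coarse_step` and `fine_step` would be passed to `run_parareal` just as in the quick start. A fixed substep count gives every interval the same, predictable cost, which suits batched parallel evaluation. The adaptive `odeint`, by contrast, chooses its step sizes from error estimates.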