nathanaelbosch/parareax

Parareax: Parareal in JAX

License: MIT | Python 3.10+

A solver-agnostic JAX implementation of the parareal algorithm for time-parallel ODE solving.

Parareal in a nutshell

Parareal is a parallel-in-time algorithm that combines a cheap but inaccurate coarse propagator $G$ with an expensive but accurate fine propagator $F$. Both have the form $(y_{n-1}, t_{n-1}, t_n) \mapsto y_n$, advancing an approximate solution from time $t_{n-1}$ to time $t_n$:

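  1. Initialize with the coarse propagator alone: set $y_0^0 = y_0$ and $y_n^0 = G(y_{n-1}^0, t_{n-1}, t_n)$ for $n = 1, \dots, N$, where $N$ is the number of time intervals.
  2. For iterations $k = 1, 2, \dots$, apply the correction formula for $n = 1, \dots, N$ (keeping $y_0^k = y_0$): $$y_n^k = G(y_{n-1}^k, t_{n-1}, t_n) + F(y_{n-1}^{k-1}, t_{n-1}, t_n) - G(y_{n-1}^{k-1}, t_{n-1}, t_n)$$
  3. Stop when successive iterates agree to within a tolerance or the maximum number of iterations is reached. In exact arithmetic, after $k$ iterations the values $y_1^k, \dots, y_k^k$ coincide with sequential fine propagation, so Parareal converges in at most $N$ iterations.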
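The three steps above can be sketched in a few lines of JAX. This is a minimal illustration, not parareax's implementation: the function `parareal_sketch`, its `max_iters`/`tol` arguments, and the max-abs-change stopping rule are assumptions made for this example. The coarse sweep is sequential (`jax.lax.scan`), while the fine and coarse propagations of the previous iterate are batched over all intervals with `jax.vmap` (on a single device this vectorizes the work rather than distributing it).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def parareal_sketch(coarse_step, fine_step, y0, ts, max_iters=10, tol=1e-9):
    """Hypothetical minimal Parareal loop (illustration only, not parareax).

    coarse_step, fine_step: (y_start, t_start, t_end) -> y_end
    Returns (ys, iterations) with ys of shape (len(ts),) + y0.shape.
    """
    t_starts, t_ends = ts[:-1], ts[1:]

    def coarse_sweep(corrections):
        # Sequential sweep: y_n = G(y_{n-1}, t_{n-1}, t_n) + correction_n
        def body(y_prev, inputs):
            t_a, t_b, corr = inputs
            y_next = coarse_step(y_prev, t_a, t_b) + corr
            return y_next, y_next

        _, ys_tail = jax.lax.scan(body, y0, (t_starts, t_ends, corrections))
        return jnp.concatenate([y0[None], ys_tail], axis=0)

    # Step 1: coarse-only initial guess (all corrections zero)
    ys = coarse_sweep(jnp.zeros((len(t_starts),) + y0.shape, dtype=y0.dtype))

    iterations = 0
    for _ in range(max_iters):
        # Step 2: F and G applied to the previous iterate are independent
        # across intervals, so vmap evaluates them for all intervals at once.
        fine = jax.vmap(fine_step)(ys[:-1], t_starts, t_ends)
        coarse = jax.vmap(coarse_step)(ys[:-1], t_starts, t_ends)
        new_ys = coarse_sweep(fine - coarse)
        iterations += 1
        # Step 3: stop once successive iterates agree to within tol
        delta = float(jnp.max(jnp.abs(new_ys - ys)))
        ys = new_ys
        if delta < tol:
            break
    return ys, iterations
```

With `max_iters` equal to the number of intervals and `tol=0.0`, the result should match sequential fine propagation up to round-off, which makes a convenient sanity check for any implementation.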

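Within each iteration, the fine propagations $F(y_{n-1}^{k-1}, t_{n-1}, t_n)$ depend only on the previous iterate, so all of them (one per time interval) can run in parallel; only the cheap coarse sweep is sequential. If Parareal converges in far fewer iterations than there are intervals, the chain of sequential fine solves becomes much shorter, while the result approaches the accuracy of the fine solver.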
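A back-of-the-envelope cost model makes this trade-off concrete. Assuming one worker per interval and ignoring communication, serial fine time stepping over $N$ intervals costs $N c_F$, while Parareal with $K$ iterations costs roughly $(K+1) N c_G + K c_F$: one sequential coarse sweep for the initialization and one per iteration, plus one parallel round of fine solves per iteration. Here $c_F$ and $c_G$ are the costs of a single fine and coarse propagation. The helper below is hypothetical and only evaluates this idealized estimate.

```python
def parareal_speedup_estimate(num_intervals, iterations, cost_fine, cost_coarse):
    """Idealized Parareal speedup over serial fine time stepping (hypothetical helper).

    Assumes one worker per interval and ignores communication and overhead.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")
    # Serial: N fine propagations one after another
    serial = num_intervals * cost_fine
    # Parareal: (K + 1) sequential coarse sweeps + K parallel rounds of fine solves
    parallel = (iterations + 1) * num_intervals * cost_coarse + iterations * cost_fine
    return serial / parallel
```

For example, with $N = 100$ and $K = 6$ (the values from the Quick Start) and an assumed cost ratio $c_G / c_F = 10^{-3}$, the estimate is about 14.9. It can never exceed $N / K \approx 16.7$, and it drops quickly if the coarse sweep is not cheap relative to a fine solve.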

Installation

The package is not yet available on PyPI. To install:

git clone https://github.com/nathanaelbosch/parareax.git
cd parareax
pip install -e .

Quick Start

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.experimental.ode import odeint
from parareax import run_parareal

# Define the differential equation: logistic equation dy/dt = y(1-y)
def f(y, t):
    return y * (1 - y)

t0, t1 = 0, 1
y0 = jnp.array([1e-2])

# Define the coarse step function (simple Euler method)
@jax.jit
def coarse_step(y0, t0, t1):
    dy = f(y0, t0)
    return y0 + dy * (t1 - t0)

# Define the fine step function (JAX's adaptive odeint with its default tolerances)
@jax.jit
def fine_step(y0, t0, t1):
    return odeint(f, y0, jnp.array([t0, t1]))[-1]

# Solve using Parareal
ts = jnp.linspace(t0, t1, 101)  # 101 time points, i.e. 100 time intervals
solution, info = run_parareal(coarse_step, fine_step, y0=y0, ts=ts, tol=1e-14)
print("Parareal iterations:", info["iterations"])

# Comparison: one sequential odeint run over all time points (tighter tolerances)
sol_fine = odeint(f, y0, ts, rtol=1e-10, atol=1e-13)

# High-accuracy reference solution (even tighter tolerances); only ref[-1] is used below
ts_ref = jnp.linspace(t0, t1, 1001)
ref = odeint(f, y0, ts_ref, rtol=1e-12, atol=1e-15)

# Compare errors
print("Fine solver error:", sol_fine[-1] - ref[-1])
print("Parareal error:", solution[-1] - ref[-1])

Output:

Parareal iterations: 6
Fine solver error: [1.57224928e-12]
Parareal error: [5.78604178e-12]

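Parareal converged in 6 iterations. Each iteration runs the 100 fine solves in parallel, followed by a sequential sweep of cheap Euler steps, so the critical path contains about 6 fine solves instead of the 100 sequential fine solves of standard time stepping. The Parareal error stays at the $10^{-12}$ level, comparable to the sequential odeint comparison run. Actual wall-clock speedups require parallel hardware and a fine propagator that is much more expensive than the coarse one.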
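Because the logistic equation has a closed-form solution, $y(t) = \frac{y_0 e^t}{1 - y_0 + y_0 e^t}$, the errors can also be measured against the exact value instead of a high-accuracy odeint run. The helper name `logistic_exact` is introduced here for illustration and is not part of parareax.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def logistic_exact(y0, t):
    """Closed-form solution of dy/dt = y(1 - y) with y(0) = y0 (hypothetical helper)."""
    growth = jnp.exp(t)
    return y0 * growth / (1.0 - y0 + y0 * growth)
```

In the Quick Start, `solution[-1] - logistic_exact(y0, t1)` could then replace `solution[-1] - ref[-1]`.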

API Reference

run_parareal(
    coarse_step,   # Function: (y_start, t_start, t_end) -> y_end
    fine_step,     # Function: (y_start, t_start, t_end) -> y_end
    y0,            # Initial state array
    ts,            # Time points array
    maxiters=1000, # Maximum iterations
    tol=1e-9       # Convergence tolerance
)

Returns:

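  • ys: Solution array of shape (len(ts), len(y0)); row i approximates the solution at ts[i]
  • info: Dictionary with convergence information, including info["iterations"], the number of Parareal iterations performed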
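Since `run_parareal` only requires propagators with the signature `(y_start, t_start, t_end) -> y_end`, the fine (or coarse) step does not have to wrap `odeint`. As an illustration, the hypothetical helper `make_rk4_step` below builds a fixed-step classical Runge-Kutta (RK4) propagator with `jax.lax.fori_loop`; it is not part of parareax, and the number of substeps is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def make_rk4_step(f, num_substeps):
    """Build a propagator (y_start, t_start, t_end) -> y_end using fixed-step RK4.

    Hypothetical helper: f has the signature f(y, t), as in the Quick Start.
    """

    def step(y_start, t_start, t_end):
        h = (t_end - t_start) / num_substeps

        def body(i, y):
            # One classical RK4 step from t to t + h
            t = t_start + i * h
            k1 = f(y, t)
            k2 = f(y + 0.5 * h * k1, t + 0.5 * h)
            k3 = f(y + 0.5 * h * k2, t + 0.5 * h)
            k4 = f(y + h * k3, t + h)
            return y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

        return jax.lax.fori_loop(0, num_substeps, body, y_start)

    return jax.jit(step)
```

It could then stand in for the Quick Start's `fine_step`, e.g. `run_parareal(coarse_step, make_rk4_step(f, 20), y0=y0, ts=ts)`. A fixed step count gives every fine solve the same cost, which can simplify load balancing across parallel intervals, at the price of losing adaptive error control.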
