Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/examples_advanced/equinox_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def vf(y, /, *, t): # noqa: ARG001

strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts0, prior=ibm, ssm=ssm)
solve_adaptive = ivpsolve.solve_adaptive_terminal_values(
solver=solver, errorest=errorest, while_loop=while_loop
solver=solver, error=error, while_loop=while_loop
)

def simulate(init_val):
Expand Down
4 changes: 2 additions & 2 deletions docs/examples_advanced/parameter_estimation_blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ def solve_adaptive(theta, *, save_at):
# Create a probabilistic solver
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2)
init, _ibm, _ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
error = probdiffeq.error_residual_std(constraint=ts0, prior=ibm, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)
return solve(init, save_at=save_at, dt0=0.1, atol=1e-4, rtol=1e-2)


Expand Down
4 changes: 2 additions & 2 deletions docs/examples_basic/conditioning_on_zero_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def vector_field(y, /, *, t):
ts1 = probdiffeq.constraint_ode_ts1(vector_field, ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts1, prior=ibm, ssm=ssm)

dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)
sol = solve(init, save_at=ts, dt0=dt0, atol=1e-1, rtol=1e-1)
markov_seq_posterior = sol.solution_full

Expand Down
12 changes: 6 additions & 6 deletions docs/examples_basic/custom_information_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def hamiltonian_2nd(u, du):
solver_1st = probdiffeq.solver_mle(
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, errorest=errorest)
error = probdiffeq.error_residual_std(constraint=ts1, prior=ibm, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, error=error)

# -

Expand Down Expand Up @@ -165,11 +165,11 @@ def root(u, du, ddu, /, *, t):

# +

error_norm = probdiffeq.errorest_error_norm_rms_then_scale()
errorest = probdiffeq.errorest_local_residual_cached(
prior=ibm, ssm=ssm, error_norm=error_norm
error_norm = probdiffeq.error_norm_rms_then_scale()
error = probdiffeq.error_residual_std(
constraint=ts1, prior=ibm, ssm=ssm, error_norm=error_norm
)
solve = ivpsolve.solve_adaptive_save_at(solver=solver_2nd, errorest=errorest)
solve = ivpsolve.solve_adaptive_save_at(solver=solver_2nd, error=error)

sol_2 = jax.jit(solve)(init, save_at=save_at, atol=1e-3, rtol=1e-1)
ham_2 = jax.vmap(hamiltonian_2nd)(sol_2.u.mean[0], sol_2.u.mean[1])
Expand Down
4 changes: 2 additions & 2 deletions docs/examples_basic/posterior_uncertainties.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def vf(y, /, *, t): # noqa: ARG001
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)

solver = probdiffeq.solver_mle(strategy=strategy, prior=ibm, constraint=ts, ssm=ssm)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)


# -
Expand Down
8 changes: 4 additions & 4 deletions docs/examples_basic/second_order_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def vf_1(y, /, *, t):
solver_1st = probdiffeq.solver_mle(
strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
)
errorest_1st = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error_1st = probdiffeq.error_residual_std(constraint=ts0, prior=ibm, ssm=ssm)


# -
Expand All @@ -71,7 +71,7 @@ def vf_1(y, /, *, t):


save_at = jnp.linspace(t0, t1, endpoint=True, num=250)
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, errorest=errorest_1st)
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, error=error_1st)
solution = jax.jit(solve)(init, save_at=save_at, atol=1e-5, rtol=1e-5)
plt.plot(solution.u.mean[0][:, 0], solution.u.mean[0][:, 1], marker=".")
plt.show()
Expand Down Expand Up @@ -104,15 +104,15 @@ def vf_2(y, dy, /, *, t):
solver_2nd = probdiffeq.solver_mle(
strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
)
errorest_2nd = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error_2nd = probdiffeq.error_residual_std(constraint=ts0, prior=ibm, ssm=ssm)

# -


# +


solve = ivpsolve.solve_adaptive_save_at(solver=solver_2nd, errorest=errorest_2nd)
solve = ivpsolve.solve_adaptive_save_at(solver=solver_2nd, error=error_2nd)
solution = jax.jit(solve)(init, save_at=save_at, atol=1e-5, rtol=1e-5)

plt.plot(solution.u.mean[0][:, 0], solution.u.mean[0][:, 1], marker=".")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def vf(y, /, *, t): # noqa: ARG001
solver = probdiffeq.solver_dynamic(
ssm=ssm, strategy=strategy, prior=ibm, constraint=ts
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)

# Solve the ODE
save_at = jnp.linspace(t0, t1, num=5, endpoint=True)
simulate = simulator(save_at=save_at, errorest=errorest, solver=solver)
simulate = simulator(save_at=save_at, error=error, solver=solver)
(u, u_std) = simulate(init)

_fig, axes = plt.subplots(
Expand Down Expand Up @@ -86,12 +86,12 @@ def vf(y, /, *, t): # noqa: ARG001
# +


def simulator(save_at, errorest, solver):
def simulator(save_at, error, solver):
"""Simulate a PDE."""

@jax.jit
def solve(init):
solve = ivpsolve.solve_adaptive_save_at(errorest=errorest, solver=solver)
solve = ivpsolve.solve_adaptive_save_at(error=error, solver=solver)
solution = solve(init, save_at=save_at, atol=1e-4, rtol=1e-2)
return (solution.u.mean[0], solution.u.std[0])

Expand Down
4 changes: 2 additions & 2 deletions docs/examples_basic/taylor_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def solve(tc):
strategy=strategy, prior=prior, constraint=ts0, ssm=ssm
)
ts = jnp.linspace(t0, t1, endpoint=True, num=10)
errorest = probdiffeq.errorest_local_residual_cached(prior=prior, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
error = probdiffeq.error_residual_std(constraint=ts0, prior=prior, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)
return solve(init, save_at=ts, atol=1e-2, rtol=1e-2)


Expand Down
10 changes: 4 additions & 6 deletions docs/examples_benchmarks/convergence-rates-lotka-volterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,13 @@ def param_to_solution(tol):
# Build a solver
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
strategy = probdiffeq.strategy_filter(ssm=ssm)
corr = probdiffeq.constraint_ode_ts1(vf_probdiffeq, ssm=ssm)
solver = probdiffeq.solver(
strategy=strategy, prior=ibm, constraint=corr, ssm=ssm
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
ts = probdiffeq.constraint_ode_ts1(vf_probdiffeq, ssm=ssm)
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)

control = ivpsolve.control_proportional_integral()
solve = ivpsolve.solve_adaptive_terminal_values(
solver=solver, errorest=errorest, control=control
solver=solver, error=error, control=control
)

# Solve
Expand Down
4 changes: 2 additions & 2 deletions docs/examples_benchmarks/work-precision-hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ def param_to_solution(tol):
solver = probdiffeq.solver_dynamic(
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts1, prior=ibm, ssm=ssm)

control = ivpsolve.control_proportional_integral()
solve = ivpsolve.solve_adaptive_terminal_values(
solver=solver, clip_dt=True, control=control, errorest=errorest
solver=solver, clip_dt=True, control=control, error=error
)

# Solve
Expand Down
8 changes: 4 additions & 4 deletions docs/examples_benchmarks/work-precision-lotka-volterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,15 @@ def param_to_solution(tol):
tcoeffs, ssm_fact=implementation
)
strategy = probdiffeq.strategy_filter(ssm=ssm)
corr = constraint(vf_probdiffeq, ssm=ssm)
ts = constraint(vf_probdiffeq, ssm=ssm)
solver = probdiffeq.solver_mle(
strategy=strategy, prior=ibm, constraint=corr, ssm=ssm
strategy=strategy, prior=ibm, constraint=ts, ssm=ssm
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)

control = ivpsolve.control_proportional_integral()
solve = ivpsolve.solve_adaptive_terminal_values(
errorest=errorest, solver=solver, control=control
error=error, solver=solver, control=control
)
dt0 = ivpsolve.dt0(vf_auto, (u0,))

Expand Down
8 changes: 4 additions & 4 deletions docs/examples_benchmarks/work-precision-pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,16 @@ def param_to_solution(tol):
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
tcoeffs, ssm_fact="isotropic"
)
ts0_or_ts1 = constraint_ode_fun(vf_probdiffeq, ssm=ssm)
ts = constraint_ode_fun(vf_probdiffeq, ssm=ssm)
strategy = probdiffeq.strategy_filter(ssm=ssm)
solver = probdiffeq.solver_dynamic(
strategy=strategy, prior=ibm, constraint=ts0_or_ts1, ssm=ssm
strategy=strategy, prior=ibm, constraint=ts, ssm=ssm
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)

control = ivpsolve.control_proportional_integral()
solve = ivpsolve.solve_adaptive_terminal_values(
solver=solver, errorest=errorest, control=control
solver=solver, error=error, control=control
)

# Solve
Expand Down
59 changes: 11 additions & 48 deletions docs/examples_benchmarks/work-precision-vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import timeit
from collections.abc import Callable

import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numba
import numpy as np
import scipy.integrate
import tqdm
Expand All @@ -37,7 +37,7 @@
jax.config.update("jax_debug_nans", True)


def main(start=2.0, stop=8.0, step=1.0, repeats=2, use_diffrax: bool = False):
def main(start=3.0, stop=8.0, step=1.0, repeats=2):
"""Run the script."""
# Set up all the configs
jax.config.update("jax_enable_x64", True)
Expand All @@ -60,21 +60,15 @@ def main(start=2.0, stop=8.0, step=1.0, repeats=2, use_diffrax: bool = False):

# Assemble algorithms
algorithms = {
"SciPy: 'Radau'": solver_scipy(method="Radau"),
"SciPy: 'LSODA'": solver_scipy(method="LSODA"),
"SciPy + numba: 'Radau'": solver_scipy(method="Radau"),
"SciPy + numba: 'LSODA'": solver_scipy(method="LSODA"),
r"ProbDiffEq: TS1($3$)": solver_probdiffeq(num_derivatives=3),
r"ProbDiffEq: TS1($4$)": solver_probdiffeq(num_derivatives=4),
r"ProbDiffEq: TS1($5$)": solver_probdiffeq(num_derivatives=5),
}
if use_diffrax:
# TODO: this is a temporary fix because Diffrax doesn't work with JAX >= 0.7.0
# Revisit in the near future.
algorithms["Diffrax: Kvaerno5()"] = solver_diffrax(solver=diffrax.Kvaerno5())
else:
print("\nSkipped Diffrax.\n")

# Compute a reference solution
reference = solver_probdiffeq(num_derivatives=4)(1e-10)
reference = solver_scipy(method="Radau")(0.1 * tolerances[-1])
precision_fun = rmse_absolute(reference)

# Compute all work-precision diagrams
Expand All @@ -100,6 +94,7 @@ def main(start=2.0, stop=8.0, step=1.0, repeats=2, use_diffrax: bool = False):
def solve_ivp_once():
"""Compute plotting-values for the IVP."""

@numba.jit(nopython=True)
def vf_scipy(_t, u):
"""Van-der-Pol dynamics as a first-order differential equation."""
return np.asarray([u[1], 1e5 * ((1.0 - u[0] ** 2) * u[1] - u[0])])
Expand Down Expand Up @@ -151,62 +146,30 @@ def param_to_solution(tol):
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1)

init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
ts0_or_ts1 = probdiffeq.constraint_root_ts1(root, ssm=ssm)
ts = probdiffeq.constraint_root_ts1(root, ssm=ssm)
strategy = probdiffeq.strategy_filter(ssm=ssm)

solver = probdiffeq.solver_dynamic(
strategy=strategy, prior=ibm, constraint=ts0_or_ts1, ssm=ssm
strategy=strategy, prior=ibm, constraint=ts, ssm=ssm
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)

dt0 = ivpsolve.dt0(vf_auto, (u0, du0))
control = ivpsolve.control_proportional_integral()

solve = ivpsolve.solve_adaptive_terminal_values(
solver=solver, errorest=errorest, control=control, clip_dt=True
solver=solver, error=error, control=control, clip_dt=True
)
solution = solve(init, t0=t0, t1=t1, dt0=dt0, atol=1e-3 * tol, rtol=tol)
return jax.block_until_ready(solution.u.mean[0])

return param_to_solution


def solver_diffrax(*, solver) -> Callable:
"""Construct a solver that wraps Diffrax' solution routines."""

@diffrax.ODETerm
@jax.jit
def vf_diffrax(_t, u, _args):
"""Van-der-Pol dynamics as a first-order differential equation."""
return jnp.asarray([u[1], 1e5 * ((1.0 - u[0] ** 2) * u[1] - u[0])])

t0, t1 = 0.0, 3.0
u0 = jnp.concatenate((jnp.atleast_1d(2.0), jnp.atleast_1d(0.0)))
t0, t1 = (0.0, 6.3)

@jax.jit
def param_to_solution(tol):
controller = diffrax.PIDController(atol=1e-3 * tol, rtol=tol)
saveat = diffrax.SaveAt(t0=False, t1=True, ts=None)
solution = diffrax.diffeqsolve(
vf_diffrax,
y0=u0,
t0=t0,
t1=t1,
saveat=saveat,
stepsize_controller=controller,
dt0=None,
max_steps=10_000,
solver=solver,
)
return jax.block_until_ready(solution.ys[0, 0])

return param_to_solution


def solver_scipy(method: str) -> Callable:
"""Construct a solver that wraps SciPy's solution routines."""

@numba.jit(nopython=True)
def vf_scipy(_t, u):
"""Van-der-Pol dynamics as a first-order differential equation."""
return np.asarray([u[1], 1e5 * ((1.0 - u[0] ** 2) * u[1] - u[0])])
Expand Down
4 changes: 2 additions & 2 deletions docs/examples_quickstart/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ def vf(y, /, *, t): # noqa: ARG001
ts = probdiffeq.constraint_ode_ts1(vf, ssm=ssm)
strategy = probdiffeq.strategy_filter(ssm=ssm)
solver = probdiffeq.solver_mle(ssm=ssm, strategy=strategy, prior=iwp, constraint=ts)
errorest = probdiffeq.errorest_local_residual_cached(prior=iwp, ssm=ssm)
error = probdiffeq.error_residual_std(constraint=ts, prior=iwp, ssm=ssm)


# Solve the ODE
# To all users: Try different solution routines.
save_at = jnp.linspace(t0, t1, num=100, endpoint=True)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)
solution = jax.jit(solve)(init, save_at, atol=1e-3, rtol=1e-3)


Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ nav:
- examples_basic/second_order_problems.ipynb
- examples_basic/taylor_coefficients.ipynb
- examples_basic/custom_information_operator.ipynb
- examples_basic/solve_pde.ipynb
- EXAMPLES | ADVANCED:
- examples_advanced/parameter_estimation_optax.ipynb
- examples_advanced/parameter_estimation_blackjax.ipynb
- examples_advanced/neural_ode.ipynb
- examples_advanced/equinox_while_loop.ipynb
- examples_advanced/solve_pde.ipynb
- EXAMPLES | BENCHMARKS:
- examples_benchmarks/work-precision-lotka-volterra.ipynb
- examples_benchmarks/work-precision-pleiades.ipynb
Expand Down
Loading