Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified docs/benchmarks/hires/plot_ts.npy
Binary file not shown.
Binary file modified docs/benchmarks/hires/plot_ys.npy
Binary file not shown.
Binary file modified docs/benchmarks/hires/results.npy
Binary file not shown.
10 changes: 2 additions & 8 deletions docs/benchmarks/hires/run_hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def param_to_solution(tol):
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
ts1 = ivpsolvers.correction_ts1(vf_probdiffeq, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
control = ivpsolvers.control_proportional_integral()
Expand All @@ -101,13 +101,7 @@ def param_to_solution(tol):
# Solve
dt0 = ivpsolve.dt0(vf_auto, (u0,))
solution = ivpsolve.solve_adaptive_terminal_values(
vf_probdiffeq,
init,
t0=t0,
t1=t1,
dt0=dt0,
adaptive_solver=adaptive_solver,
ssm=ssm,
init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm
)

# Return the terminal value
Expand Down
Binary file modified docs/benchmarks/lotkavolterra/results.npy
Binary file not shown.
10 changes: 2 additions & 8 deletions docs/benchmarks/lotkavolterra/run_lotkavolterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def param_to_solution(tol):
tcoeffs, ssm_fact=implementation
)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
corr = correction(ssm=ssm)
corr = correction(vf_probdiffeq, ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm)
control = ivpsolvers.control_proportional_integral()
adaptive_solver = ivpsolvers.adaptive(
Expand All @@ -94,13 +94,7 @@ def param_to_solution(tol):
# Solve
dt0 = ivpsolve.dt0(vf_auto, (u0,))
solution = ivpsolve.solve_adaptive_terminal_values(
vf_probdiffeq,
init,
t0=t0,
t1=t1,
dt0=dt0,
adaptive_solver=adaptive_solver,
ssm=ssm,
init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm
)

# Return the terminal value
Expand Down
Binary file modified docs/benchmarks/pleiades/plot_ts.npy
Binary file not shown.
Binary file modified docs/benchmarks/pleiades/plot_ys.npy
Binary file not shown.
Binary file modified docs/benchmarks/pleiades/results.npy
Binary file not shown.
10 changes: 2 additions & 8 deletions docs/benchmarks/pleiades/run_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def param_to_solution(tol):
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, ssm_fact="isotropic"
)
ts0_or_ts1 = correction_fun(ssm=ssm, ode_order=2)
ts0_or_ts1 = correction_fun(vf_probdiffeq, ssm=ssm, ode_order=2)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_dynamic(
strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm
Expand All @@ -115,13 +115,7 @@ def param_to_solution(tol):
# Solve
dt0 = ivpsolve.dt0(vf_auto, (u0, du0))
solution = ivpsolve.solve_adaptive_terminal_values(
vf_probdiffeq,
init,
t0=t0,
t1=t1,
dt0=dt0,
adaptive_solver=adaptive_solver,
ssm=ssm,
init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm
)

# Return the terminal value
Expand Down
Binary file modified docs/benchmarks/taylor_fitzhughnagumo/results.npy
Binary file not shown.
Binary file modified docs/benchmarks/taylor_node/results.npy
Binary file not shown.
Binary file modified docs/benchmarks/taylor_pleiades/results.npy
Binary file not shown.
Binary file modified docs/benchmarks/vanderpol/plot_ts.npy
Binary file not shown.
Binary file modified docs/benchmarks/vanderpol/plot_ys.npy
Binary file not shown.
Binary file modified docs/benchmarks/vanderpol/results.npy
Binary file not shown.
10 changes: 2 additions & 8 deletions docs/benchmarks/vanderpol/run_vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def param_to_solution(tol):
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1)

init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
ts0_or_ts1 = ivpsolvers.correction_ts1(ode_order=2, ssm=ssm)
ts0_or_ts1 = ivpsolvers.correction_ts1(vf_probdiffeq, ode_order=2, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)

solver = ivpsolvers.solver_dynamic(
Expand All @@ -96,13 +96,7 @@ def param_to_solution(tol):
# Solve
dt0 = ivpsolve.dt0(vf_auto, (u0, du0))
solution = ivpsolve.solve_adaptive_terminal_values(
vf_probdiffeq,
init,
t0=t0,
t1=t1,
dt0=dt0,
adaptive_solver=adaptive_solver,
ssm=ssm,
init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm
)

# Return the terminal value
Expand Down
10 changes: 2 additions & 8 deletions docs/examples_advanced/equinox_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def vf(y, *, t): # noqa: ARG001

tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm)
ts0 = ivpsolvers.correction_ts0(vf, ode_order=1, ssm=ssm)

strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
Expand All @@ -71,13 +71,7 @@ def vf(y, *, t): # noqa: ARG001
def simulate(init_val):
"""Evaluate the parameter-to-solution function."""
sol = ivpsolve.solve_adaptive_terminal_values(
vf,
init_val,
t0=t0,
t1=t1,
dt0=0.1,
adaptive_solver=adaptive_solver,
ssm=ssm,
init_val, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
)

# Any scalar function of the IVP solution would do
Expand Down
10 changes: 2 additions & 8 deletions docs/examples_advanced/neural_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,12 @@ def loss(
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
ts0 = ivpsolvers.correction_ts0(lambda *a, **kw: vf(*a, **kw, p=p), ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)

# Solve
sol = ivpsolve.solve_fixed_grid(
lambda *a, **kw: vf(*a, **kw, p=p),
init,
grid=grid,
solver=solver_ts0,
ssm=ssm,
)
sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver_ts0, ssm=ssm)

# Evaluate loss
marginal_likelihood = stats.log_marginal_likelihood(
Expand Down
8 changes: 4 additions & 4 deletions docs/examples_advanced/parameter_estimation_blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,10 @@ def solve_fixed(theta, *, ts):
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
return ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm)
return ivpsolve.solve_fixed_grid(init, grid=ts, solver=solver, ssm=ssm)


@jax.jit
Expand All @@ -201,12 +201,12 @@ def solve_adaptive(theta, *, save_at):
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
return ivpsolve.solve_adaptive_save_at(
vf, init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
)


Expand Down
6 changes: 2 additions & 4 deletions docs/examples_advanced/parameter_estimation_optax.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ def solve(p):
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
ts0 = ivpsolvers.correction_ts0(lambda y, t: vf(y, t, p=p), ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
return ivpsolve.solve_fixed_grid(
lambda y, t: vf(y, t, p=p), init, grid=ts, solver=solver, ssm=ssm
)
return ivpsolve.solve_fixed_grid(init, grid=ts, solver=solver, ssm=ssm)


parameter_true = f_args + 0.05
Expand Down
4 changes: 2 additions & 2 deletions docs/examples_basic/conditioning_on_zero_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def vector_field(y, t): # noqa: ARG001
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="dense"
)
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
ts1 = ivpsolvers.correction_ts1(vector_field, ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts1, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-2, ssm=ssm)

dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))
sol = ivpsolve.solve_adaptive_save_at(
vector_field, init, save_at=ts, dt0=1.0, adaptive_solver=adaptive_solver, ssm=ssm
init, save_at=ts, dt0=1.0, adaptive_solver=adaptive_solver, ssm=ssm
)
markov_seq_posterior = stats.markov_select_terminal(sol.posterior)

Expand Down
6 changes: 3 additions & 3 deletions docs/examples_basic/dynamic_output_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def vf(*ys, t): # noqa: ARG001
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="dense"
)
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
ts1 = ivpsolvers.correction_ts1(vf, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
dynamic = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
mle = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts1, ssm=ssm)
Expand All @@ -79,8 +79,8 @@ def vf(*ys, t): # noqa: ARG001
ts = jnp.linspace(t0, t1, num=num_pts, endpoint=True)


solution_dynamic = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=dynamic, ssm=ssm)
solution_mle = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=mle, ssm=ssm)
solution_dynamic = ivpsolve.solve_fixed_grid(init, grid=ts, solver=dynamic, ssm=ssm)
solution_mle = ivpsolve.solve_fixed_grid(init, grid=ts, solver=mle, ssm=ssm)
# -

# Plot the solution.
Expand Down
4 changes: 2 additions & 2 deletions docs/examples_basic/posterior_uncertainties.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def vf(y, *, t): # noqa: ARG001
# To all users: Try replacing the fixedpoint-smoother with a filter!
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3)
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
ts = ivpsolvers.correction_ts1(ssm=ssm)
ts = ivpsolvers.correction_ts1(vf, ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-1, ssm=ssm)

# Solve the ODE
ts = jnp.linspace(t0, t1, endpoint=True, num=50)
sol = ivpsolve.solve_adaptive_save_at(
vf, init, save_at=ts, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
init, save_at=ts, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
)

# Calibrate
Expand Down
8 changes: 4 additions & 4 deletions docs/examples_basic/second_order_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def vf_1(y, t): # noqa: ARG001
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
ts0 = ivpsolvers.correction_ts0(vf_1, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver_1st = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver_1st = ivpsolvers.adaptive(solver_1st, atol=1e-5, rtol=1e-5, ssm=ssm)
Expand All @@ -56,7 +56,7 @@ def vf_1(y, t): # noqa: ARG001
# -

solution = ivpsolve.solve_adaptive_save_every_step(
vf_1, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_1st, ssm=ssm
init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_1st, ssm=ssm
)

norm = jnp.linalg.norm((solution.u[0][-1] - u0) / jnp.abs(1.0 + u0))
Expand All @@ -82,15 +82,15 @@ def vf_2(y, dy, t): # noqa: ARG001
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ode_order=2, ssm=ssm)
ts0 = ivpsolvers.correction_ts0(vf_2, ode_order=2, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver_2nd = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver_2nd = ivpsolvers.adaptive(solver_2nd, atol=1e-5, rtol=1e-5, ssm=ssm)

# -

solution = ivpsolve.solve_adaptive_save_every_step(
vf_2, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_2nd, ssm=ssm
init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_2nd, ssm=ssm
)

norm = jnp.linalg.norm((solution.u[0][-1, ...] - u0) / jnp.abs(1.0 + u0))
Expand Down
4 changes: 2 additions & 2 deletions docs/examples_basic/taylor_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def vf(*y, t): # noqa: ARG001
def solve(tc):
"""Solve the ODE."""
init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=ts0, ssm=ssm)
ts = jnp.linspace(t0, t1, endpoint=True, num=10)
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
return ivpsolve.solve_adaptive_save_at(
vf, init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
)


Expand Down
4 changes: 2 additions & 2 deletions docs/examples_quickstart/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def vf(y, *, t): # noqa: ARG001


# Build a solver
ts = ivpsolvers.correction_ts1(ssm=ssm, ode_order=1)
ts = ivpsolvers.correction_ts1(vf, ssm=ssm, ode_order=1)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_mle(ssm=ssm, strategy=strategy, prior=ibm, correction=ts)
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
Expand All @@ -52,7 +52,7 @@ def vf(y, *, t): # noqa: ARG001
# Solve the ODE
# To all users: Try different solution routines.
solution = ivpsolve.solve_adaptive_save_every_step(
vf, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
)

# Look at the solution
Expand Down
9 changes: 9 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,14 @@ doc:
make benchmarks-plot-results
JUPYTER_PLATFORM_DIRS=1 mkdocs build

doc-serve:
# The readme is the landing page of the docs:
cp README.md docs/index.md
# Execute the examples manually and not via mkdocs-jupyter
# to gain clear error messages.
make example
make benchmarks-plot-results
JUPYTER_PLATFORM_DIRS=1 mkdocs serve

find-dead-code:
vulture . --ignore-names case*,fixture*,*jvp --exclude probdiffeq/_version.py
Loading