Skip to content

Commit efb91ea

Browse files
authored
Improve error estimation (API *and* implementation) (#852)
* Merge local-residual error estimates and add the suffix *std * Avoid testing the product space of configs in test_adaptive_save_at to save minutes of test time * Update the benchmarks to the new API * Move PDE example to examples_basic because it does not have an external dependency * Fix some tests * Fix remaining tests * Explain the improved normalisation * Update the VdP benchmarK
1 parent 6e7acf0 commit efb91ea

31 files changed

Lines changed: 263 additions & 324 deletions

docs/examples_advanced/equinox_while_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ def vf(y, /, *, t): # noqa: ARG001
4747

4848
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
4949
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
50-
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
50+
error = probdiffeq.error_residual_std(constraint=ts0, prior=ibm, ssm=ssm)
5151
solve_adaptive = ivpsolve.solve_adaptive_terminal_values(
52-
solver=solver, errorest=errorest, while_loop=while_loop
52+
solver=solver, error=error, while_loop=while_loop
5353
)
5454

5555
def simulate(init_val):

docs/examples_advanced/parameter_estimation_blackjax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ def solve_adaptive(theta, *, save_at):
201201
# Create a probabilistic solver
202202
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2)
203203
init, _ibm, _ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
204-
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
205-
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
204+
error = probdiffeq.error_residual_std(constraint=ts0, prior=ibm, ssm=ssm)
205+
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)
206206
return solve(init, save_at=save_at, dt0=0.1, atol=1e-4, rtol=1e-2)
207207

208208

docs/examples_basic/conditioning_on_zero_residual.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ def vector_field(y, /, *, t):
8989
ts1 = probdiffeq.constraint_ode_ts1(vector_field, ssm=ssm)
9090
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
9191
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm)
92-
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
92+
error = probdiffeq.error_residual_std(constraint=ts1, prior=ibm, ssm=ssm)
9393

9494
dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))
95-
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
95+
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)
9696
sol = solve(init, save_at=ts, dt0=dt0, atol=1e-1, rtol=1e-1)
9797
markov_seq_posterior = sol.solution_full
9898

docs/examples_basic/custom_information_operator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def hamiltonian_2nd(u, du):
9999
solver_1st = probdiffeq.solver_mle(
100100
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
101101
)
102-
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
103-
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, errorest=errorest)
102+
error = probdiffeq.error_residual_std(constraint=ts1, prior=ibm, ssm=ssm)
103+
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, error=error)
104104

105105
# -
106106

@@ -165,11 +165,11 @@ def root(u, du, ddu, /, *, t):
165165

166166
# +
167167

168-
error_norm = probdiffeq.errorest_error_norm_rms_then_scale()
169-
errorest = probdiffeq.errorest_local_residual_cached(
170-
prior=ibm, ssm=ssm, error_norm=error_norm
168+
error_norm = probdiffeq.error_norm_rms_then_scale()
169+
error = probdiffeq.error_residual_std(
170+
constraint=ts1, prior=ibm, ssm=ssm, error_norm=error_norm
171171
)
172-
solve = ivpsolve.solve_adaptive_save_at(solver=solver_2nd, errorest=errorest)
172+
solve = ivpsolve.solve_adaptive_save_at(solver=solver_2nd, error=error)
173173

174174
sol_2 = jax.jit(solve)(init, save_at=save_at, atol=1e-3, rtol=1e-1)
175175
ham_2 = jax.vmap(hamiltonian_2nd)(sol_2.u.mean[0], sol_2.u.mean[1])

docs/examples_basic/posterior_uncertainties.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def vf(y, /, *, t): # noqa: ARG001
6565
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
6666

6767
solver = probdiffeq.solver_mle(strategy=strategy, prior=ibm, constraint=ts, ssm=ssm)
68-
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
69-
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
68+
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)
69+
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)
7070

7171

7272
# -

docs/examples_basic/second_order_problems.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def vf_1(y, /, *, t):
6262
solver_1st = probdiffeq.solver_mle(
6363
strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
6464
)
65-
errorest_1st = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
65+
error_1st = probdiffeq.error_residual_std(constraint=ts0, prior=ibm, ssm=ssm)
6666

6767

6868
# -
@@ -71,7 +71,7 @@ def vf_1(y, /, *, t):
7171

7272

7373
save_at = jnp.linspace(t0, t1, endpoint=True, num=250)
74-
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, errorest=errorest_1st)
74+
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, error=error_1st)
7575
solution = jax.jit(solve)(init, save_at=save_at, atol=1e-5, rtol=1e-5)
7676
plt.plot(solution.u.mean[0][:, 0], solution.u.mean[0][:, 1], marker=".")
7777
plt.show()
@@ -104,15 +104,15 @@ def vf_2(y, dy, /, *, t):
104104
solver_2nd = probdiffeq.solver_mle(
105105
strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
106106
)
107-
errorest_2nd = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
107+
error_2nd = probdiffeq.error_residual_std(constraint=ts0, prior=ibm, ssm=ssm)
108108

109109
# -
110110

111111

112112
# +
113113

114114

115-
solve = ivpsolve.solve_adaptive_save_at(solver=solver_2nd, errorest=errorest_2nd)
115+
solve = ivpsolve.solve_adaptive_save_at(solver=solver_2nd, error=error_2nd)
116116
solution = jax.jit(solve)(init, save_at=save_at, atol=1e-5, rtol=1e-5)
117117

118118
plt.plot(solution.u.mean[0][:, 0], solution.u.mean[0][:, 1], marker=".")
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ def vf(y, /, *, t): # noqa: ARG001
5252
solver = probdiffeq.solver_dynamic(
5353
ssm=ssm, strategy=strategy, prior=ibm, constraint=ts
5454
)
55-
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
55+
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)
5656

5757
# Solve the ODE
5858
save_at = jnp.linspace(t0, t1, num=5, endpoint=True)
59-
simulate = simulator(save_at=save_at, errorest=errorest, solver=solver)
59+
simulate = simulator(save_at=save_at, error=error, solver=solver)
6060
(u, u_std) = simulate(init)
6161

6262
_fig, axes = plt.subplots(
@@ -86,12 +86,12 @@ def vf(y, /, *, t): # noqa: ARG001
8686
# +
8787

8888

89-
def simulator(save_at, errorest, solver):
89+
def simulator(save_at, error, solver):
9090
"""Simulate a PDE."""
9191

9292
@jax.jit
9393
def solve(init):
94-
solve = ivpsolve.solve_adaptive_save_at(errorest=errorest, solver=solver)
94+
solve = ivpsolve.solve_adaptive_save_at(error=error, solver=solver)
9595
solution = solve(init, save_at=save_at, atol=1e-4, rtol=1e-2)
9696
return (solution.u.mean[0], solution.u.std[0])
9797

docs/examples_basic/taylor_coefficients.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def solve(tc):
6868
strategy=strategy, prior=prior, constraint=ts0, ssm=ssm
6969
)
7070
ts = jnp.linspace(t0, t1, endpoint=True, num=10)
71-
errorest = probdiffeq.errorest_local_residual_cached(prior=prior, ssm=ssm)
72-
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
71+
error = probdiffeq.error_residual_std(constraint=ts0, prior=prior, ssm=ssm)
72+
solve = ivpsolve.solve_adaptive_save_at(solver=solver, error=error)
7373
return solve(init, save_at=ts, atol=1e-2, rtol=1e-2)
7474

7575

docs/examples_benchmarks/convergence-rates-lotka-volterra.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,13 @@ def param_to_solution(tol):
181181
# Build a solver
182182
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
183183
strategy = probdiffeq.strategy_filter(ssm=ssm)
184-
corr = probdiffeq.constraint_ode_ts1(vf_probdiffeq, ssm=ssm)
185-
solver = probdiffeq.solver(
186-
strategy=strategy, prior=ibm, constraint=corr, ssm=ssm
187-
)
188-
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
184+
ts = probdiffeq.constraint_ode_ts1(vf_probdiffeq, ssm=ssm)
185+
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts, ssm=ssm)
186+
error = probdiffeq.error_residual_std(constraint=ts, prior=ibm, ssm=ssm)
189187

190188
control = ivpsolve.control_proportional_integral()
191189
solve = ivpsolve.solve_adaptive_terminal_values(
192-
solver=solver, errorest=errorest, control=control
190+
solver=solver, error=error, control=control
193191
)
194192

195193
# Solve

docs/examples_benchmarks/work-precision-hires.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@ def param_to_solution(tol):
170170
solver = probdiffeq.solver_dynamic(
171171
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
172172
)
173-
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
173+
error = probdiffeq.error_residual_std(constraint=ts1, prior=ibm, ssm=ssm)
174174

175175
control = ivpsolve.control_proportional_integral()
176176
solve = ivpsolve.solve_adaptive_terminal_values(
177-
solver=solver, clip_dt=True, control=control, errorest=errorest
177+
solver=solver, clip_dt=True, control=control, error=error
178178
)
179179

180180
# Solve

0 commit comments

Comments
 (0)