Skip to content

Commit c27a024

Browse files
committed
Update example notebooks
1 parent cbc7c93 commit c27a024

10 files changed

Lines changed: 25 additions & 39 deletions

docs/examples_advanced/equinox_while_loop.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def vf(y, *, t): # noqa: ARG001
6262

6363
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
6464
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
65-
ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm)
65+
ts0 = ivpsolvers.correction_ts0(vf, ode_order=1, ssm=ssm)
6666

6767
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
6868
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
@@ -71,13 +71,7 @@ def vf(y, *, t): # noqa: ARG001
7171
def simulate(init_val):
7272
"""Evaluate the parameter-to-solution function."""
7373
sol = ivpsolve.solve_adaptive_terminal_values(
74-
vf,
75-
init_val,
76-
t0=t0,
77-
t1=t1,
78-
dt0=0.1,
79-
adaptive_solver=adaptive_solver,
80-
ssm=ssm,
74+
init_val, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
8175
)
8276

8377
# Any scalar function of the IVP solution would do

docs/examples_advanced/neural_ode.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,18 +152,12 @@ def loss(
152152
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
153153
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
154154
)
155-
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
155+
ts0 = ivpsolvers.correction_ts0(lambda *a, **kw: vf(*a, **kw, p=p), ssm=ssm)
156156
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
157157
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
158158

159159
# Solve
160-
sol = ivpsolve.solve_fixed_grid(
161-
lambda *a, **kw: vf(*a, **kw, p=p),
162-
init,
163-
grid=grid,
164-
solver=solver_ts0,
165-
ssm=ssm,
166-
)
160+
sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver_ts0, ssm=ssm)
167161

168162
# Evaluate loss
169163
marginal_likelihood = stats.log_marginal_likelihood(

docs/examples_advanced/parameter_estimation_blackjax.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,10 @@ def solve_fixed(theta, *, ts):
186186
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
187187
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
188188
)
189-
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
189+
ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm)
190190
strategy = ivpsolvers.strategy_filter(ssm=ssm)
191191
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
192-
return ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm)
192+
return ivpsolve.solve_fixed_grid(init, grid=ts, solver=solver, ssm=ssm)
193193

194194

195195
@jax.jit
@@ -201,12 +201,12 @@ def solve_adaptive(theta, *, save_at):
201201
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
202202
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
203203
)
204-
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
204+
ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm)
205205
strategy = ivpsolvers.strategy_filter(ssm=ssm)
206206
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
207207
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
208208
return ivpsolve.solve_adaptive_save_at(
209-
vf, init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
209+
init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
210210
)
211211

212212

docs/examples_advanced/parameter_estimation_optax.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,10 @@ def solve(p):
6262
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
6363
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
6464
)
65-
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
65+
ts0 = ivpsolvers.correction_ts0(lambda y, t: vf(y, t, p=p), ssm=ssm)
6666
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
6767
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
68-
return ivpsolve.solve_fixed_grid(
69-
lambda y, t: vf(y, t, p=p), init, grid=ts, solver=solver, ssm=ssm
70-
)
68+
return ivpsolve.solve_fixed_grid(init, grid=ts, solver=solver, ssm=ssm)
7169

7270

7371
parameter_true = f_args + 0.05

docs/examples_basic/conditioning_on_zero_residual.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ def vector_field(y, t): # noqa: ARG001
7474
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
7575
tcoeffs, output_scale=1.0, ssm_fact="dense"
7676
)
77-
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
77+
ts1 = ivpsolvers.correction_ts1(vector_field, ssm=ssm)
7878
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
7979
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts1, ssm=ssm)
8080
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-2, ssm=ssm)
8181

8282
dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))
8383
sol = ivpsolve.solve_adaptive_save_at(
84-
vector_field, init, save_at=ts, dt0=1.0, adaptive_solver=adaptive_solver, ssm=ssm
84+
init, save_at=ts, dt0=1.0, adaptive_solver=adaptive_solver, ssm=ssm
8585
)
8686
markov_seq_posterior = stats.markov_select_terminal(sol.posterior)
8787

docs/examples_basic/dynamic_output_scales.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def vf(*ys, t): # noqa: ARG001
6767
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
6868
tcoeffs, output_scale=1.0, ssm_fact="dense"
6969
)
70-
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
70+
ts1 = ivpsolvers.correction_ts1(vf, ssm=ssm)
7171
strategy = ivpsolvers.strategy_filter(ssm=ssm)
7272
dynamic = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
7373
mle = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts1, ssm=ssm)
@@ -79,8 +79,8 @@ def vf(*ys, t): # noqa: ARG001
7979
ts = jnp.linspace(t0, t1, num=num_pts, endpoint=True)
8080

8181

82-
solution_dynamic = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=dynamic, ssm=ssm)
83-
solution_mle = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=mle, ssm=ssm)
82+
solution_dynamic = ivpsolve.solve_fixed_grid(init, grid=ts, solver=dynamic, ssm=ssm)
83+
solution_mle = ivpsolve.solve_fixed_grid(init, grid=ts, solver=mle, ssm=ssm)
8484
# -
8585

8686
# Plot the solution.

docs/examples_basic/posterior_uncertainties.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ def vf(y, *, t): # noqa: ARG001
4343
# To all users: Try replacing the fixedpoint-smoother with a filter!
4444
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3)
4545
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
46-
ts = ivpsolvers.correction_ts1(ssm=ssm)
46+
ts = ivpsolvers.correction_ts1(vf, ssm=ssm)
4747
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
4848
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts, ssm=ssm)
4949
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-1, ssm=ssm)
5050

5151
# Solve the ODE
5252
ts = jnp.linspace(t0, t1, endpoint=True, num=50)
5353
sol = ivpsolve.solve_adaptive_save_at(
54-
vf, init, save_at=ts, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
54+
init, save_at=ts, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
5555
)
5656

5757
# Calibrate

docs/examples_basic/second_order_problems.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def vf_1(y, t): # noqa: ARG001
4747
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
4848
tcoeffs, output_scale=1.0, ssm_fact="isotropic"
4949
)
50-
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
50+
ts0 = ivpsolvers.correction_ts0(vf_1, ssm=ssm)
5151
strategy = ivpsolvers.strategy_filter(ssm=ssm)
5252
solver_1st = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
5353
adaptive_solver_1st = ivpsolvers.adaptive(solver_1st, atol=1e-5, rtol=1e-5, ssm=ssm)
@@ -56,7 +56,7 @@ def vf_1(y, t): # noqa: ARG001
5656
# -
5757

5858
solution = ivpsolve.solve_adaptive_save_every_step(
59-
vf_1, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_1st, ssm=ssm
59+
init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_1st, ssm=ssm
6060
)
6161

6262
norm = jnp.linalg.norm((solution.u[0][-1] - u0) / jnp.abs(1.0 + u0))
@@ -82,15 +82,15 @@ def vf_2(y, dy, t): # noqa: ARG001
8282
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
8383
tcoeffs, output_scale=1.0, ssm_fact="isotropic"
8484
)
85-
ts0 = ivpsolvers.correction_ts0(ode_order=2, ssm=ssm)
85+
ts0 = ivpsolvers.correction_ts0(vf_2, ode_order=2, ssm=ssm)
8686
strategy = ivpsolvers.strategy_filter(ssm=ssm)
8787
solver_2nd = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
8888
adaptive_solver_2nd = ivpsolvers.adaptive(solver_2nd, atol=1e-5, rtol=1e-5, ssm=ssm)
8989

9090
# -
9191

9292
solution = ivpsolve.solve_adaptive_save_every_step(
93-
vf_2, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_2nd, ssm=ssm
93+
init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_2nd, ssm=ssm
9494
)
9595

9696
norm = jnp.linalg.norm((solution.u[0][-1, ...] - u0) / jnp.abs(1.0 + u0))

docs/examples_basic/taylor_coefficients.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ def vf(*y, t): # noqa: ARG001
5858
def solve(tc):
5959
"""Solve the ODE."""
6060
init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense")
61-
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
61+
ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm)
6262
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
6363
solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=ts0, ssm=ssm)
6464
ts = jnp.linspace(t0, t1, endpoint=True, num=10)
6565
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
6666
return ivpsolve.solve_adaptive_save_at(
67-
vf, init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
67+
init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
6868
)
6969

7070

docs/examples_quickstart/quickstart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def vf(y, *, t): # noqa: ARG001
4343

4444

4545
# Build a solver
46-
ts = ivpsolvers.correction_ts1(ssm=ssm, ode_order=1)
46+
ts = ivpsolvers.correction_ts1(vf, ssm=ssm, ode_order=1)
4747
strategy = ivpsolvers.strategy_filter(ssm=ssm)
4848
solver = ivpsolvers.solver_mle(ssm=ssm, strategy=strategy, prior=ibm, correction=ts)
4949
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
@@ -52,7 +52,7 @@ def vf(y, *, t): # noqa: ARG001
5252
# Solve the ODE
5353
# To all users: Try different solution routines.
5454
solution = ivpsolve.solve_adaptive_save_every_step(
55-
vf, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
55+
init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
5656
)
5757

5858
# Look at the solution

0 commit comments

Comments
 (0)