@@ -47,7 +47,7 @@ def vf_1(y, t): # noqa: ARG001
4747init , ibm , ssm = ivpsolvers .prior_wiener_integrated (
4848 tcoeffs , output_scale = 1.0 , ssm_fact = "isotropic"
4949)
50- ts0 = ivpsolvers .correction_ts0 (ssm = ssm )
50+ ts0 = ivpsolvers .correction_ts0 (vf_1 , ssm = ssm )
5151strategy = ivpsolvers .strategy_filter (ssm = ssm )
5252solver_1st = ivpsolvers .solver_mle (strategy , prior = ibm , correction = ts0 , ssm = ssm )
5353adaptive_solver_1st = ivpsolvers .adaptive (solver_1st , atol = 1e-5 , rtol = 1e-5 , ssm = ssm )
@@ -56,7 +56,7 @@ def vf_1(y, t): # noqa: ARG001
5656# -
5757
5858solution = ivpsolve .solve_adaptive_save_every_step (
59- vf_1 , init , t0 = t0 , t1 = t1 , dt0 = 0.1 , adaptive_solver = adaptive_solver_1st , ssm = ssm
59+ init , t0 = t0 , t1 = t1 , dt0 = 0.1 , adaptive_solver = adaptive_solver_1st , ssm = ssm
6060)
6161
6262norm = jnp .linalg .norm ((solution .u [0 ][- 1 ] - u0 ) / jnp .abs (1.0 + u0 ))
@@ -82,15 +82,15 @@ def vf_2(y, dy, t): # noqa: ARG001
8282init , ibm , ssm = ivpsolvers .prior_wiener_integrated (
8383 tcoeffs , output_scale = 1.0 , ssm_fact = "isotropic"
8484)
85- ts0 = ivpsolvers .correction_ts0 (ode_order = 2 , ssm = ssm )
85+ ts0 = ivpsolvers .correction_ts0 (vf_2 , ode_order = 2 , ssm = ssm )
8686strategy = ivpsolvers .strategy_filter (ssm = ssm )
8787solver_2nd = ivpsolvers .solver_mle (strategy , prior = ibm , correction = ts0 , ssm = ssm )
8888adaptive_solver_2nd = ivpsolvers .adaptive (solver_2nd , atol = 1e-5 , rtol = 1e-5 , ssm = ssm )
8989
9090# -
9191
9292solution = ivpsolve .solve_adaptive_save_every_step (
93- vf_2 , init , t0 = t0 , t1 = t1 , dt0 = 0.1 , adaptive_solver = adaptive_solver_2nd , ssm = ssm
93+ init , t0 = t0 , t1 = t1 , dt0 = 0.1 , adaptive_solver = adaptive_solver_2nd , ssm = ssm
9494)
9595
9696norm = jnp .linalg .norm ((solution .u [0 ][- 1 , ...] - u0 ) / jnp .abs (1.0 + u0 ))
0 commit comments