diff --git a/docs/benchmarks/hires/plot_ts.npy b/docs/benchmarks/hires/plot_ts.npy index a88a369c4..9ef0d5d71 100644 Binary files a/docs/benchmarks/hires/plot_ts.npy and b/docs/benchmarks/hires/plot_ts.npy differ diff --git a/docs/benchmarks/hires/plot_ys.npy b/docs/benchmarks/hires/plot_ys.npy index 4e84ca51f..9b2c924b3 100644 Binary files a/docs/benchmarks/hires/plot_ys.npy and b/docs/benchmarks/hires/plot_ys.npy differ diff --git a/docs/benchmarks/hires/results.npy b/docs/benchmarks/hires/results.npy index 4619fed3e..40b435eb3 100644 Binary files a/docs/benchmarks/hires/results.npy and b/docs/benchmarks/hires/results.npy differ diff --git a/docs/benchmarks/hires/run_hires.py b/docs/benchmarks/hires/run_hires.py index f9b2dce39..317829d6e 100644 --- a/docs/benchmarks/hires/run_hires.py +++ b/docs/benchmarks/hires/run_hires.py @@ -90,7 +90,7 @@ def param_to_solution(tol): vf_auto = functools.partial(vf_probdiffeq, t=t0) tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense") - ts1 = ivpsolvers.correction_ts1(ssm=ssm) + ts1 = ivpsolvers.correction_ts1(vf_probdiffeq, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm) control = ivpsolvers.control_proportional_integral() @@ -101,13 +101,7 @@ def param_to_solution(tol): # Solve dt0 = ivpsolve.dt0(vf_auto, (u0,)) solution = ivpsolve.solve_adaptive_terminal_values( - vf_probdiffeq, - init, - t0=t0, - t1=t1, - dt0=dt0, - adaptive_solver=adaptive_solver, - ssm=ssm, + init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm ) # Return the terminal value diff --git a/docs/benchmarks/lotkavolterra/results.npy b/docs/benchmarks/lotkavolterra/results.npy index f4cafd98d..7f7f9a586 100644 Binary files a/docs/benchmarks/lotkavolterra/results.npy and b/docs/benchmarks/lotkavolterra/results.npy differ diff --git a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py index 36fa4654b..9e23f6a5a 100644 --- a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py +++ b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py @@ -84,7 +84,7 @@ def param_to_solution(tol): tcoeffs, ssm_fact=implementation ) strategy = ivpsolvers.strategy_filter(ssm=ssm) - corr = correction(ssm=ssm) + corr = correction(vf_probdiffeq, ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm) control = ivpsolvers.control_proportional_integral() adaptive_solver = ivpsolvers.adaptive( @@ -94,13 +94,7 @@ def param_to_solution(tol): # Solve dt0 = ivpsolve.dt0(vf_auto, (u0,)) solution = ivpsolve.solve_adaptive_terminal_values( - vf_probdiffeq, - init, - t0=t0, - t1=t1, - dt0=dt0, - adaptive_solver=adaptive_solver, - ssm=ssm, + init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm ) # Return the terminal value diff --git a/docs/benchmarks/pleiades/plot_ts.npy b/docs/benchmarks/pleiades/plot_ts.npy index 47c4f0111..db3b3d59b 100644 Binary files a/docs/benchmarks/pleiades/plot_ts.npy and b/docs/benchmarks/pleiades/plot_ts.npy differ diff --git a/docs/benchmarks/pleiades/plot_ys.npy b/docs/benchmarks/pleiades/plot_ys.npy index 624fa267b..be42ae476 100644 Binary files a/docs/benchmarks/pleiades/plot_ys.npy and b/docs/benchmarks/pleiades/plot_ys.npy differ diff --git a/docs/benchmarks/pleiades/results.npy b/docs/benchmarks/pleiades/results.npy index 6a96fd782..ef89e7f8a 100644 Binary files a/docs/benchmarks/pleiades/results.npy and b/docs/benchmarks/pleiades/results.npy differ diff --git a/docs/benchmarks/pleiades/run_pleiades.py b/docs/benchmarks/pleiades/run_pleiades.py index ac5f4f7cb..7baebba08 100644 --- a/docs/benchmarks/pleiades/run_pleiades.py +++ b/docs/benchmarks/pleiades/run_pleiades.py @@ -102,7 +102,7 @@ def param_to_solution(tol): init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, ssm_fact="isotropic" ) - ts0_or_ts1 = correction_fun(ssm=ssm, ode_order=2) + ts0_or_ts1 = correction_fun(vf_probdiffeq, ssm=ssm, ode_order=2) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_dynamic( strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm @@ -115,13 +115,7 @@ def param_to_solution(tol): # Solve dt0 = ivpsolve.dt0(vf_auto, (u0, du0)) solution = ivpsolve.solve_adaptive_terminal_values( - vf_probdiffeq, - init, - t0=t0, - t1=t1, - dt0=dt0, - adaptive_solver=adaptive_solver, - ssm=ssm, + init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm ) # Return the terminal value diff --git a/docs/benchmarks/taylor_fitzhughnagumo/results.npy b/docs/benchmarks/taylor_fitzhughnagumo/results.npy index 8cdd2801e..0acd68bc9 100644 Binary files a/docs/benchmarks/taylor_fitzhughnagumo/results.npy and b/docs/benchmarks/taylor_fitzhughnagumo/results.npy differ diff --git a/docs/benchmarks/taylor_node/results.npy b/docs/benchmarks/taylor_node/results.npy index 946d6aafc..884e65ee2 100644 Binary files a/docs/benchmarks/taylor_node/results.npy and b/docs/benchmarks/taylor_node/results.npy differ diff --git a/docs/benchmarks/taylor_pleiades/results.npy b/docs/benchmarks/taylor_pleiades/results.npy index d4f90364b..6c3c71d3a 100644 Binary files a/docs/benchmarks/taylor_pleiades/results.npy and b/docs/benchmarks/taylor_pleiades/results.npy differ diff --git a/docs/benchmarks/vanderpol/plot_ts.npy b/docs/benchmarks/vanderpol/plot_ts.npy index 2b015c365..491f284a2 100644 Binary files a/docs/benchmarks/vanderpol/plot_ts.npy and b/docs/benchmarks/vanderpol/plot_ts.npy differ diff --git a/docs/benchmarks/vanderpol/plot_ys.npy b/docs/benchmarks/vanderpol/plot_ys.npy index 6d0e1fd03..1bfecf88e 100644 Binary files a/docs/benchmarks/vanderpol/plot_ys.npy and b/docs/benchmarks/vanderpol/plot_ys.npy differ diff --git a/docs/benchmarks/vanderpol/results.npy b/docs/benchmarks/vanderpol/results.npy index 8b8cff46a..2007926e3 100644 Binary files a/docs/benchmarks/vanderpol/results.npy and b/docs/benchmarks/vanderpol/results.npy differ diff --git a/docs/benchmarks/vanderpol/run_vanderpol.py b/docs/benchmarks/vanderpol/run_vanderpol.py index 9e4b18b59..c4cf9c277 100644 --- a/docs/benchmarks/vanderpol/run_vanderpol.py +++ b/docs/benchmarks/vanderpol/run_vanderpol.py @@ -82,7 +82,7 @@ def param_to_solution(tol): tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense") - ts0_or_ts1 = ivpsolvers.correction_ts1(ode_order=2, ssm=ssm) + ts0_or_ts1 = ivpsolvers.correction_ts1(vf_probdiffeq, ode_order=2, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_dynamic( @@ -96,13 +96,7 @@ def param_to_solution(tol): # Solve dt0 = ivpsolve.dt0(vf_auto, (u0, du0)) solution = ivpsolve.solve_adaptive_terminal_values( - vf_probdiffeq, - init, - t0=t0, - t1=t1, - dt0=dt0, - adaptive_solver=adaptive_solver, - ssm=ssm, + init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm ) # Return the terminal value diff --git a/docs/examples_advanced/equinox_while_loop.py b/docs/examples_advanced/equinox_while_loop.py index 2d07178df..4920cf4b1 100644 --- a/docs/examples_advanced/equinox_while_loop.py +++ b/docs/examples_advanced/equinox_while_loop.py @@ -62,7 +62,7 @@ def vf(y, *, t): # noqa: ARG001 tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic") - ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ode_order=1, ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) @@ -71,13 +71,7 @@ def vf(y, *, t): # noqa: ARG001 def simulate(init_val): """Evaluate the parameter-to-solution function.""" sol = ivpsolve.solve_adaptive_terminal_values( - vf, - init_val, - t0=t0, - t1=t1, - dt0=0.1, - adaptive_solver=adaptive_solver, - ssm=ssm, + init_val, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm ) # Any scalar function of the IVP solution would do diff --git a/docs/examples_advanced/neural_ode.py b/docs/examples_advanced/neural_ode.py index 94db36b04..8160b5e21 100644 --- a/docs/examples_advanced/neural_ode.py +++ b/docs/examples_advanced/neural_ode.py @@ -152,18 +152,12 @@ def loss( init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(lambda *a, **kw: vf(*a, **kw, p=p), ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) # Solve - sol = ivpsolve.solve_fixed_grid( - lambda *a, **kw: vf(*a, **kw, p=p), - init, - grid=grid, - solver=solver_ts0, - ssm=ssm, - ) + sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver_ts0, ssm=ssm) # Evaluate loss marginal_likelihood = stats.log_marginal_likelihood( diff --git a/docs/examples_advanced/parameter_estimation_blackjax.py b/docs/examples_advanced/parameter_estimation_blackjax.py index afecaf9b8..48c8ab297 100644 --- a/docs/examples_advanced/parameter_estimation_blackjax.py +++ b/docs/examples_advanced/parameter_estimation_blackjax.py @@ -186,10 +186,10 @@ def solve_fixed(theta, *, ts): init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) - return ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm) + return ivpsolve.solve_fixed_grid(init, grid=ts, solver=solver, ssm=ssm) @jax.jit @@ -201,12 +201,12 @@ def solve_adaptive(theta, *, save_at): init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm) return ivpsolve.solve_adaptive_save_at( - vf, init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm + init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) diff --git a/docs/examples_advanced/parameter_estimation_optax.py b/docs/examples_advanced/parameter_estimation_optax.py index a742cacde..0103a8bb4 100644 --- a/docs/examples_advanced/parameter_estimation_optax.py +++ b/docs/examples_advanced/parameter_estimation_optax.py @@ -62,12 +62,10 @@ def solve(p): init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(lambda y, t: vf(y, t, p=p), ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) - return ivpsolve.solve_fixed_grid( - lambda y, t: vf(y, t, p=p), init, grid=ts, solver=solver, ssm=ssm - ) + return ivpsolve.solve_fixed_grid(init, grid=ts, solver=solver, ssm=ssm) parameter_true = f_args + 0.05 diff --git a/docs/examples_basic/conditioning_on_zero_residual.py b/docs/examples_basic/conditioning_on_zero_residual.py index 939aec4f8..ce82168b4 100644 --- a/docs/examples_basic/conditioning_on_zero_residual.py +++ b/docs/examples_basic/conditioning_on_zero_residual.py @@ -74,14 +74,14 @@ def vector_field(y, t): # noqa: ARG001 init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, output_scale=1.0, ssm_fact="dense" ) -ts1 = ivpsolvers.correction_ts1(ssm=ssm) +ts1 = ivpsolvers.correction_ts1(vector_field, ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts1, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-2, ssm=ssm) dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,)) sol = ivpsolve.solve_adaptive_save_at( - vector_field, init, save_at=ts, dt0=1.0, adaptive_solver=adaptive_solver, ssm=ssm + init, save_at=ts, dt0=1.0, adaptive_solver=adaptive_solver, ssm=ssm ) markov_seq_posterior = stats.markov_select_terminal(sol.posterior) diff --git a/docs/examples_basic/dynamic_output_scales.py b/docs/examples_basic/dynamic_output_scales.py index f4a1192c2..b8e6c3321 100644 --- a/docs/examples_basic/dynamic_output_scales.py +++ b/docs/examples_basic/dynamic_output_scales.py @@ -67,7 +67,7 @@ def vf(*ys, t): # noqa: ARG001 init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, output_scale=1.0, ssm_fact="dense" ) -ts1 = ivpsolvers.correction_ts1(ssm=ssm) +ts1 = ivpsolvers.correction_ts1(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) dynamic = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm) mle = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts1, ssm=ssm) @@ -79,8 +79,8 @@ def vf(*ys, t): # noqa: ARG001 ts = jnp.linspace(t0, t1, num=num_pts, endpoint=True) -solution_dynamic = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=dynamic, ssm=ssm) -solution_mle = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=mle, ssm=ssm) +solution_dynamic = ivpsolve.solve_fixed_grid(init, grid=ts, solver=dynamic, ssm=ssm) +solution_mle = ivpsolve.solve_fixed_grid(init, grid=ts, solver=mle, ssm=ssm) # - # Plot the solution. diff --git a/docs/examples_basic/posterior_uncertainties.py b/docs/examples_basic/posterior_uncertainties.py index 3cc77e329..c1341bd69 100644 --- a/docs/examples_basic/posterior_uncertainties.py +++ b/docs/examples_basic/posterior_uncertainties.py @@ -43,7 +43,7 @@ def vf(y, *, t): # noqa: ARG001 # To all users: Try replacing the fixedpoint-smoother with a filter! tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense") -ts = ivpsolvers.correction_ts1(ssm=ssm) +ts = ivpsolvers.correction_ts1(vf, ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-1, ssm=ssm) @@ -51,7 +51,7 @@ def vf(y, *, t): # noqa: ARG001 # Solve the ODE ts = jnp.linspace(t0, t1, endpoint=True, num=50) sol = ivpsolve.solve_adaptive_save_at( - vf, init, save_at=ts, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm + init, save_at=ts, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm ) # Calibrate diff --git a/docs/examples_basic/second_order_problems.py b/docs/examples_basic/second_order_problems.py index fda834099..a96a3b629 100644 --- a/docs/examples_basic/second_order_problems.py +++ b/docs/examples_basic/second_order_problems.py @@ -47,7 +47,7 @@ def vf_1(y, t): # noqa: ARG001 init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, output_scale=1.0, ssm_fact="isotropic" ) -ts0 = ivpsolvers.correction_ts0(ssm=ssm) +ts0 = ivpsolvers.correction_ts0(vf_1, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver_1st = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver_1st = ivpsolvers.adaptive(solver_1st, atol=1e-5, rtol=1e-5, ssm=ssm) @@ -56,7 +56,7 @@ def vf_1(y, t): # noqa: ARG001 # - solution = ivpsolve.solve_adaptive_save_every_step( - vf_1, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_1st, ssm=ssm + init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_1st, ssm=ssm ) norm = jnp.linalg.norm((solution.u[0][-1] - u0) / jnp.abs(1.0 + u0)) @@ -82,7 +82,7 @@ def vf_2(y, dy, t): # noqa: ARG001 init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, output_scale=1.0, ssm_fact="isotropic" ) -ts0 = ivpsolvers.correction_ts0(ode_order=2, ssm=ssm) +ts0 = ivpsolvers.correction_ts0(vf_2, ode_order=2, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver_2nd = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver_2nd = ivpsolvers.adaptive(solver_2nd, atol=1e-5, rtol=1e-5, ssm=ssm) @@ -90,7 +90,7 @@ def vf_2(y, dy, t): # noqa: ARG001 # - solution = ivpsolve.solve_adaptive_save_every_step( - vf_2, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_2nd, ssm=ssm + init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_2nd, ssm=ssm ) norm = jnp.linalg.norm((solution.u[0][-1, ...] - u0) / jnp.abs(1.0 + u0)) diff --git a/docs/examples_basic/taylor_coefficients.py b/docs/examples_basic/taylor_coefficients.py index 05211b986..d18cbcf01 100644 --- a/docs/examples_basic/taylor_coefficients.py +++ b/docs/examples_basic/taylor_coefficients.py @@ -58,13 +58,13 @@ def vf(*y, t): # noqa: ARG001 def solve(tc): """Solve the ODE.""" init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense") - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=ts0, ssm=ssm) ts = jnp.linspace(t0, t1, endpoint=True, num=10) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) return ivpsolve.solve_adaptive_save_at( - vf, init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm + init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) diff --git a/docs/examples_quickstart/quickstart.py b/docs/examples_quickstart/quickstart.py index 8fd4c01a8..d02f5cea9 100644 --- a/docs/examples_quickstart/quickstart.py +++ b/docs/examples_quickstart/quickstart.py @@ -43,7 +43,7 @@ def vf(y, *, t): # noqa: ARG001 # Build a solver -ts = ivpsolvers.correction_ts1(ssm=ssm, ode_order=1) +ts = ivpsolvers.correction_ts1(vf, ssm=ssm, ode_order=1) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_mle(ssm=ssm, strategy=strategy, prior=ibm, correction=ts) adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm) @@ -52,7 +52,7 @@ def vf(y, *, t): # noqa: ARG001 # Solve the ODE # To all users: Try different solution routines. solution = ivpsolve.solve_adaptive_save_every_step( - vf, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm + init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm ) # Look at the solution diff --git a/makefile b/makefile index b065c84ce..3de1189b0 100644 --- a/makefile +++ b/makefile @@ -54,5 +54,14 @@ doc: make benchmarks-plot-results JUPYTER_PLATFORM_DIRS=1 mkdocs build +doc-serve: + # The readme is the landing page of the docs: + cp README.md docs/index.md + # Execute the examples manually and not via mkdocs-jupyter + # to gain clear error messages. + make example + make benchmarks-plot-results + JUPYTER_PLATFORM_DIRS=1 mkdocs serve + find-dead-code: vulture . --ignore-names case*,fixture*,*jvp --exclude probdiffeq/_version.py diff --git a/probdiffeq/ivpsolve.py b/probdiffeq/ivpsolve.py index e0123164c..35dfea0ad 100644 --- a/probdiffeq/ivpsolve.py +++ b/probdiffeq/ivpsolve.py @@ -85,12 +85,11 @@ def _sol_unflatten(aux, children): def solve_adaptive_terminal_values( - vector_field, ssm_init, t0, t1, adaptive_solver, dt0, *, ssm + ssm_init, /, *, t0, t1, adaptive_solver, dt0, ssm ) -> IVPSolution: """Simulate the terminal values of an initial value problem.""" save_at = np.asarray([t0, t1]) solution = solve_adaptive_save_at( - vector_field, ssm_init, save_at=save_at, adaptive_solver=adaptive_solver, @@ -102,7 +101,7 @@ def solve_adaptive_terminal_values( def solve_adaptive_save_at( - vector_field, ssm_init, save_at, adaptive_solver, dt0, *, ssm, warn=True + ssm_init, /, *, save_at, adaptive_solver, dt0, ssm, warn=True ) -> IVPSolution: r"""Solve an initial value problem and return the solution at a pre-determined grid. @@ -138,7 +137,6 @@ def solve_adaptive_save_at( warnings.warn(msg, stacklevel=1) (_t, solution_save_at), _, num_steps = _solve_adaptive_save_at( - tree_util.Partial(vector_field), save_at[0], ssm_init, save_at=save_at[1:], @@ -166,9 +164,7 @@ def solve_adaptive_save_at( ) -def _solve_adaptive_save_at( - vector_field, t, ssm_init, *, save_at, adaptive_solver, dt0 -): +def _solve_adaptive_save_at(t, ssm_init, *, save_at, adaptive_solver, dt0): def advance(state, t_next): # Advance until accepted.t >= t_next. # Note: This could already be the case and we may not loop (just interpolate) @@ -180,9 +176,7 @@ def cond_fun(s): return s.step_from.t + adaptive_solver.eps < t_next def body_fun(s): - return adaptive_solver.rejection_loop( - s, vector_field=vector_field, t1=t_next - ) + return adaptive_solver.rejection_loop(s, t1=t_next) state = control_flow.while_loop(cond_fun, body_fun, init=state) @@ -203,7 +197,7 @@ def body_fun(s): def solve_adaptive_save_every_step( - vector_field, ssm_init, t0, t1, adaptive_solver, dt0, *, ssm + ssm_init, /, *, t0, t1, adaptive_solver, dt0, ssm ) -> IVPSolution: """Solve an initial value problem and save every step. @@ -220,12 +214,7 @@ def solve_adaptive_save_every_step( warnings.warn(msg, stacklevel=1) generator = _solution_generator( - tree_util.Partial(vector_field), - t0, - ssm_init, - t1=t1, - adaptive_solver=adaptive_solver, - dt0=dt0, + t0, ssm_init, t1=t1, adaptive_solver=adaptive_solver, dt0=dt0 ) tmp = tree_array_util.tree_stack(list(generator)) (t, solution_every_step), _dt, num_steps = tmp @@ -254,11 +243,11 @@ def solve_adaptive_save_every_step( ) -def _solution_generator(vector_field, t, ssm_init, *, dt0, t1, adaptive_solver): +def _solution_generator(t, ssm_init, *, dt0, t1, adaptive_solver): state = adaptive_solver.init(t, ssm_init, dt=dt0, num_steps=0) while state.step_from.t < t1: - state = adaptive_solver.rejection_loop(state, vector_field=vector_field, t1=t1) + state = adaptive_solver.rejection_loop(state, t1=t1) if state.step_from.t + adaptive_solver.eps < t1: _, solution = adaptive_solver.extract_before_t1(state, t=t1) @@ -274,12 +263,12 @@ def _solution_generator(vector_field, t, ssm_init, *, dt0, t1, adaptive_solver): yield solution -def solve_fixed_grid(vector_field, ssm_init, grid, solver, *, ssm) -> IVPSolution: +def solve_fixed_grid(ssm_init, /, *, grid, solver, ssm) -> IVPSolution: """Solve an initial value problem on a fixed, pre-determined grid.""" # Compute the solution def body_fn(s, dt): - _error, s_new = solver.step(state=s, vector_field=vector_field, dt=dt) + _error, s_new = solver.step(state=s, dt=dt) return s_new, s_new t0 = grid[0] diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 3cb09b59f..e900fd27b 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -500,6 +500,7 @@ class _Correction: ode_order: int ssm: Any linearize: Callable + vector_field: Callable use_re_linearize: bool can_handle_higher_order: bool @@ -509,15 +510,9 @@ def init(self, x, /): y = self.ssm.prototypes.observed() return x, y - def estimate_error(self, rv, /, vector_field, t): + def estimate_error(self, rv, /, t): """Perform all elements of the correction until the error estimate.""" - if self.can_handle_higher_order: - - def f_wrapped(s): - return vector_field(*s, t=t) - else: - f_wrapped = functools.partial(vector_field, t=t) - + f_wrapped = self._parametrize_vector_field(t=t) A, b = self.linearize(f_wrapped, rv) observed = self.ssm.conditional.marginalise(rv, (A, b)) @@ -532,6 +527,16 @@ def f_wrapped(s): error_estimate = output_scale * error_estimate_unscaled return error_estimate, observed, (A, b, f_wrapped) + def _parametrize_vector_field(self, *, t): + if self.can_handle_higher_order: + + def f_wrapped(s): + return self.vector_field(*s, t=t) + + return f_wrapped + + return functools.partial(self.vector_field, t=t) + def complete(self, rv, cache, /): """Complete what has been left out by `estimate_error`.""" A, b, f_wrapped = cache @@ -541,11 +546,12 @@ def complete(self, rv, cache, /): return corrected, observed -def correction_ts0(*, ssm, ode_order=1, damp: float = 0.0) -> _Correction: +def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction: """Zeroth-order Taylor linearisation.""" linearize = ssm.linearise.ode_taylor_0th(ode_order=ode_order, damp=damp) return _Correction( name="TS0", + vector_field=vector_field, ode_order=ode_order, ssm=ssm, linearize=linearize, @@ -554,11 +560,12 @@ def correction_ts0(*, ssm, ode_order=1, damp: float = 0.0) -> _Correction: ) -def correction_ts1(*, ssm, ode_order=1, damp: float = 0.0) -> _Correction: +def correction_ts1(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction: """First-order Taylor linearisation.""" linearize = ssm.linearise.ode_taylor_1st(ode_order=ode_order, damp=damp) return _Correction( name="TS1", + vector_field=vector_field, ode_order=ode_order, ssm=ssm, linearize=linearize, @@ -568,12 +575,13 @@ def correction_ts1(*, ssm, ode_order=1, damp: float = 0.0) -> _Correction: def correction_slr0( - *, ssm, cubature_fun=cubature_third_order_spherical, damp: float = 0.0 + vector_field, *, ssm, cubature_fun=cubature_third_order_spherical, damp: float = 0.0 ) -> _Correction: """Zeroth-order statistical linear regression.""" linearize = ssm.linearise.ode_statistical_0th(cubature_fun, damp=damp) return _Correction( ssm=ssm, + vector_field=vector_field, ode_order=1, linearize=linearize, name="SLR0", @@ -583,12 +591,13 @@ def correction_slr0( def correction_slr1( - *, ssm, cubature_fun=cubature_third_order_spherical, damp: float = 0.0 + vector_field, *, ssm, cubature_fun=cubature_third_order_spherical, damp: float = 0.0 ) -> _Correction: """First-order statistical linear regression.""" linearize = ssm.linearise.ode_statistical_1st(cubature_fun, damp=damp) return _Correction( ssm=ssm, + vector_field=vector_field, ode_order=1, linearize=linearize, name="SLR1", @@ -674,10 +683,8 @@ def init(self, t, init) -> _State: calib_state = self.calibration.init() return _State(t=t, hidden=rv, aux_extra=extra, output_scale=calib_state) - def step(self, state: _State, *, vector_field, dt): - return self.step_implementation( - state, vector_field=vector_field, dt=dt, calibration=self.calibration - ) + def step(self, state: _State, *, dt): + return self.step_implementation(state, dt=dt, calibration=self.calibration) def extract(self, state: _State, /): hidden = state.hidden @@ -745,7 +752,7 @@ def solver_mle(strategy, *, correction, prior, ssm): after solving if the MLE-calibration shall be *used*. """ - def step_mle(state, /, *, dt, vector_field, calibration): + def step_mle(state, /, *, dt, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) prior_discretized = prior(dt) @@ -753,9 +760,7 @@ def step_mle(state, /, *, dt, vector_field, calibration): state.hidden, state.aux_extra, prior_discretized=prior_discretized ) t = state.t + dt - error, _, corr = correction.estimate_error( - hidden, vector_field=vector_field, t=t - ) + error, _, corr = correction.estimate_error(hidden, t=t) hidden, extra = strategy.complete( hidden, extra, output_scale=output_scale_prior @@ -804,15 +809,13 @@ def extract(state, /): def solver_dynamic(strategy, *, correction, prior, ssm): """Create a solver that calibrates the output scale dynamically.""" - def step_dynamic(state, /, *, dt, vector_field, calibration): + def step_dynamic(state, /, *, dt, calibration): prior_discretized = prior(dt) hidden, extra = strategy.begin( state.hidden, state.aux_extra, prior_discretized=prior_discretized ) t = state.t + dt - error, observed, corr = correction.estimate_error( - hidden, vector_field=vector_field, t=t - ) + error, observed, corr = correction.estimate_error(hidden, t=t) output_scale = calibration.update(state.output_scale, observed=observed) @@ -851,7 +854,7 @@ def extract(state, /): def solver(strategy, *, correction, prior, ssm): """Create a solver that does not calibrate the output scale automatically.""" - def step(state: _State, *, vector_field, dt, calibration): + def step(state: _State, *, dt, calibration): del calibration # unused prior_discretized = prior(dt) @@ -859,9 +862,7 @@ def step(state: _State, *, vector_field, dt, calibration): state.hidden, state.aux_extra, prior_discretized=prior_discretized ) t = state.t + dt - error, _, corr = correction.estimate_error( - hidden, vector_field=vector_field, t=t - ) + error, _, corr = correction.estimate_error(hidden, t=t) hidden, extra = strategy.complete( hidden, extra, output_scale=state.output_scale @@ -985,7 +986,7 @@ def init(self, t, initial_condition, dt, num_steps) -> _AdaState: ) @functools.jit - def rejection_loop(self, state0: _AdaState, *, vector_field, t1) -> _AdaState: + def rejection_loop(self, state0: _AdaState, *, t1) -> _AdaState: class _RejectionState(containers.NamedTuple): """State for rejection loops. @@ -1031,7 +1032,7 @@ def body_fn(state: _RejectionState) -> _RejectionState: # Perform the actual step. error_estimate, state_proposed = self.solver.step( - state=state.step_from, vector_field=vector_field, dt=dt + state=state.step_from, dt=dt ) # Normalise the error u_proposed = self.ssm.stats.qoi(state_proposed.hidden)[0] diff --git a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py index 7034af433..193e6b4fd 100644 --- a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py @@ -18,29 +18,20 @@ class Taylor(containers.NamedTuple): tcoeffs = Taylor(*taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2)) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) - asolver = ivpsolvers.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2, clip_dt=True) - args = (vf, init) - - adaptive_kwargs = { - "t0": t0, - "t1": t1, - "dt0": 0.1, - "adaptive_solver": asolver, - "ssm": ssm, - } solution_adaptive = ivpsolve.solve_adaptive_save_every_step( - *args, **adaptive_kwargs + init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=asolver, ssm=ssm ) assert isinstance(solution_adaptive.u, Taylor) grid_adaptive = solution_adaptive.t - fixed_kwargs = {"grid": grid_adaptive, "solver": solver, "ssm": ssm} - solution_fixed = ivpsolve.solve_fixed_grid(*args, **fixed_kwargs) + solution_fixed = ivpsolve.solve_fixed_grid( + init, grid=grid_adaptive, solver=solver, ssm=ssm + ) assert testing.tree_all_allclose(solution_adaptive, solution_fixed) # Assert u and u_std have matching shapes (that was wrong before) diff --git a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py index 235605e7b..0d4f902d4 100644 --- a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py @@ -13,18 +13,15 @@ def test_save_at_result_matches_interpolated_adaptive_result(fact): # Generate a solver tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) - problem_args = (vf, init) - adaptive_kwargs = {"adaptive_solver": adaptive_solver, "dt0": 0.1, "ssm": ssm} - # Compute an adaptive solution and interpolate ts = np.linspace(t0, t1, num=15, endpoint=True) solution_adaptive = ivpsolve.solve_adaptive_save_every_step( - *problem_args, t0=t0, t1=t1, **adaptive_kwargs + init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) u_interp, marginals_interp = stats.offgrid_marginals_searchsorted( ts=ts[1:-1], solution=solution_adaptive, solver=solver @@ -32,7 +29,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(fact): # Compute a save-at solution and remove the edge-points solution_save_at = ivpsolve.solve_adaptive_save_at( - *problem_args, save_at=ts, **adaptive_kwargs + init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) u_save_at = tree_util.tree_map(lambda s: s[1:-1], solution_save_at.u) diff --git a/tests/test_ivpsolve/test_save_every_step.py b/tests/test_ivpsolve/test_save_every_step.py index 65a41bca3..93dec9e76 100644 --- a/tests/test_ivpsolve/test_save_every_step.py +++ b/tests/test_ivpsolve/test_save_every_step.py @@ -29,7 +29,7 @@ def python_loop_solution(ivp, *, fact, strategy_fun): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) init, transition, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = strategy_fun(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=transition, correction=ts0, ssm=ssm) @@ -42,15 +42,9 @@ def python_loop_solution(ivp, *, fact, strategy_fun): dt0 = ivpsolve.dt0_adaptive( vf, u0, t0=t0, atol=1e-2, rtol=1e-2, error_contraction_rate=5 ) - args = (vf, init) - kwargs = { - "t0": t0, - "t1": t1, - "adaptive_solver": adaptive_solver, - "dt0": dt0, - "ssm": ssm, - } - return ivpsolve.solve_adaptive_save_every_step(*args, **kwargs) + return ivpsolve.solve_adaptive_save_every_step( + init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=dt0, ssm=ssm + ) def reference_solution(ivp, ts): diff --git a/tests/test_ivpsolve/test_solution_api.py b/tests/test_ivpsolve/test_solution_api.py index b24474ab9..4f74a35f8 100644 --- a/tests/test_ivpsolve/test_solution_api.py +++ b/tests/test_ivpsolve/test_solution_api.py @@ -22,12 +22,12 @@ def fixture_pn_solution(fact): tcoeffs = Taylor(*taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2)) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) asolver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) return ivpsolve.solve_adaptive_save_every_step( - vf, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=asolver, ssm=ssm + init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=asolver, ssm=ssm ) diff --git a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py index d4ac4ebc2..dd6c975a1 100644 --- a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py @@ -14,18 +14,18 @@ def test_terminal_values_identical(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) asolver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) - - args = (vf, init) - kwargs = {"t0": t0, "t1": t1, "adaptive_solver": asolver, "dt0": 0.1, "ssm": ssm} - - solution_loop = ivpsolve.solve_adaptive_save_every_step(*args, **kwargs) + solution_loop = ivpsolve.solve_adaptive_save_every_step( + init, t0=t0, t1=t1, adaptive_solver=asolver, dt0=0.1, ssm=ssm + ) expected = tree_util.tree_map(lambda s: s[-1], solution_loop) - received = ivpsolve.solve_adaptive_terminal_values(*args, **kwargs) + received = ivpsolve.solve_adaptive_terminal_values( + init, t0=t0, t1=t1, adaptive_solver=asolver, dt0=0.1, ssm=ssm + ) assert testing.tree_all_allclose(received, expected) # Assert u and u_std have matching shapes (that was wrong before) diff --git a/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py b/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py index 75cd63b85..a6fb78598 100644 --- a/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py +++ b/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py @@ -16,14 +16,12 @@ def test_exponential_approximated_well(fact): tcoeffs = (*u0, vf(*u0, t=t0)) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts0, ssm=ssm) - problem_args = (vf, init) grid = np.linspace(t0, t1, num=20) - solver_kwargs = {"grid": grid, "solver": solver, "ssm": ssm} - approximation = ivpsolve.solve_fixed_grid(*problem_args, **solver_kwargs) + approximation = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver, ssm=ssm) solution = ode.odeint_and_save_at( vf, u0, save_at=np.asarray([t0, t1]), atol=1e-5, rtol=1e-5 diff --git a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py index dc8a885e6..51e19a65e 100644 --- a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py +++ b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py @@ -17,13 +17,13 @@ def case_solve_fixed_grid(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) - kwargs = {"grid": np.linspace(t0, t1, endpoint=True, num=5), "ssm": ssm} + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) + grid = np.linspace(t0, t1, endpoint=True, num=5) def solver_to_solution(solver_fun, strategy_fun): strategy = strategy_fun(ssm=ssm) solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) - return ivpsolve.solve_fixed_grid(vf, init, solver=solver, **kwargs) + return ivpsolve.solve_fixed_grid(init, solver=solver, grid=grid, ssm=ssm) return solver_to_solution, ssm @@ -37,17 +37,15 @@ def case_solve_adaptive_save_at(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) - + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) save_at = np.linspace(t0, t1, endpoint=True, num=5) - kwargs = {"save_at": save_at, "dt0": dt0, "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): strategy = strategy_fun(ssm=ssm) solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) return ivpsolve.solve_adaptive_save_at( - vf, init, adaptive_solver=adaptive_solver, **kwargs + init, adaptive_solver=adaptive_solver, save_at=save_at, dt0=dt0, ssm=ssm ) return solver_to_solution, ssm @@ -62,15 +60,14 @@ def case_solve_adaptive_save_every_step(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) - kwargs = {"t0": t0, "t1": t1, "dt0": dt0, "ssm": ssm} + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) def solver_to_solution(solver_fun, strategy_fun): strategy = strategy_fun(ssm=ssm) solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) return ivpsolve.solve_adaptive_save_every_step( - vf, init, adaptive_solver=adaptive_solver, **kwargs + init, adaptive_solver=adaptive_solver, t0=t0, t1=t1, dt0=dt0, ssm=ssm ) return solver_to_solution, ssm @@ -84,16 +81,14 @@ def case_simulate_terminal_values(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=4) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) - - kwargs = {"t0": t0, "t1": t1, "dt0": dt0, "ssm": ssm} + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) def solver_to_solution(solver_fun, strategy_fun): strategy = strategy_fun(ssm=ssm) solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2) return ivpsolve.solve_adaptive_terminal_values( - vf, init, adaptive_solver=adaptive_solver, **kwargs + init, adaptive_solver=adaptive_solver, t0=t0, t1=t1, dt0=dt0, ssm=ssm ) return solver_to_solution, ssm diff --git a/tests/test_ivpsolvers/test_corrections.py b/tests/test_ivpsolvers/test_corrections.py index 68f381c09..6612028af 100644 --- a/tests/test_ivpsolvers/test_corrections.py +++ b/tests/test_ivpsolvers/test_corrections.py @@ -41,7 +41,7 @@ def fixture_solution(correction_impl, fact): try: tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - corr = correction_impl(ssm=ssm, damp=1e-9) + corr = correction_impl(vf, ssm=ssm, damp=1e-9) except NotImplementedError: testing.skip(reason="This type of linearisation has not been implemented.") @@ -49,11 +49,8 @@ def fixture_solution(correction_impl, fact): strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) - - adaptive_kwargs = {"adaptive_solver": adaptive_solver, "dt0": 0.1, "ssm": ssm} - return ivpsolve.solve_adaptive_terminal_values( - vf, init, t0=t0, t1=t1, **adaptive_kwargs + init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py index 0d813f36b..0108a82f6 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py @@ -21,11 +21,11 @@ def fixture_filter_solution(solver_setup): init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, ssm_fact=solver_setup["fact"] ) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(solver_setup["vf"], ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) return ivpsolve.solve_fixed_grid( - solver_setup["vf"], init, grid=solver_setup["grid"], solver=solver, ssm=ssm + init, grid=solver_setup["grid"], solver=solver, ssm=ssm ) @@ -35,11 +35,11 @@ def fixture_smoother_solution(solver_setup): init, ibm, ssm = ivpsolvers.prior_wiener_integrated( tcoeffs, ssm_fact=solver_setup["fact"] ) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(solver_setup["vf"], ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) return ivpsolve.solve_fixed_grid( - solver_setup["vf"], init, grid=solver_setup["grid"], solver=solver, ssm=ssm + init, grid=solver_setup["grid"], solver=solver, ssm=ssm ) diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py index b3e1514fc..7f7418848 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py @@ -21,12 +21,11 @@ def fixture_solver_setup(fact): def fixture_solution_smoother(solver_setup): tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(solver_setup["vf"], ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) return ivpsolve.solve_adaptive_save_every_step( - solver_setup["vf"], init, t0=solver_setup["t0"], t1=solver_setup["t1"], @@ -40,19 +39,14 @@ def test_fixedpoint_smoother_equivalent_same_grid(solver_setup, solution_smoothe """Test that with save_at=smoother_solution.t, the results should be identical.""" tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(solver_setup["vf"], ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) save_at = solution_smoother.t solution_fixedpoint = ivpsolve.solve_adaptive_save_at( - solver_setup["vf"], - init, - save_at=save_at, - adaptive_solver=adaptive_solver, - dt0=0.1, - ssm=ssm, + init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) assert testing.tree_all_allclose(solution_fixedpoint, solution_smoother) @@ -64,7 +58,7 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm # Re-generate the smoothing solver tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] _init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(solver_setup["vf"], ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver_smoother = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) @@ -77,17 +71,12 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm # Generate a fixedpoint solver and solve (saving at the interpolation points) tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(solver_setup["vf"], ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) solution_fixedpoint = ivpsolve.solve_adaptive_save_at( - solver_setup["vf"], - init, - save_at=ts, - adaptive_solver=adaptive_solver, - dt0=0.1, - ssm=ssm, + init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) # Extract the interior points of the save_at solution diff --git a/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py b/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py index 52fd69ad7..095ca8ccb 100644 --- a/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py +++ b/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py @@ -12,14 +12,14 @@ def test_warning_for_fixedpoint_in_save_every_step_mode(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) with testing.warns(): _ = ivpsolve.solve_adaptive_save_every_step( - vf, init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm + init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) @@ -29,13 +29,12 @@ def test_warning_for_smoother_in_save_at_mode(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) with testing.warns(): _ = ivpsolve.solve_adaptive_save_at( - vf, init, save_at=np.linspace(t0, t1), adaptive_solver=adaptive_solver, diff --git a/tests/test_stats/test_log_marginal_likelihood.py b/tests/test_stats/test_log_marginal_likelihood.py index 06085c6c4..c070d31e2 100644 --- a/tests/test_stats/test_log_marginal_likelihood.py +++ b/tests/test_stats/test_log_marginal_likelihood.py @@ -13,13 +13,13 @@ def fixture_solution(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) save_at = np.linspace(t0, t1, endpoint=True, num=4) sol = ivpsolve.solve_adaptive_save_at( - vf, init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm + init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) return sol, ssm @@ -92,12 +92,12 @@ def test_raises_error_for_filter(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) grid = np.linspace(t0, t1, num=3) - sol = ivpsolve.solve_fixed_grid(vf, init, grid=grid, solver=solver, ssm=ssm) + sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver, ssm=ssm) data = sol.u[0] + 0.1 std = np.ones((sol.u[0].shape[0],)) # values irrelevant diff --git a/tests/test_stats/test_log_marginal_likelihood_terminal_values.py b/tests/test_stats/test_log_marginal_likelihood_terminal_values.py index 4a325fef1..c49067f31 100644 --- a/tests/test_stats/test_log_marginal_likelihood_terminal_values.py +++ b/tests/test_stats/test_log_marginal_likelihood_terminal_values.py @@ -28,12 +28,12 @@ def fixture_solution(strategy_func, fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = strategy_func(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) sol = ivpsolve.solve_adaptive_terminal_values( - vf, init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm + init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm ) return sol, ssm diff --git a/tests/test_stats/test_offgrid_marginals.py b/tests/test_stats/test_offgrid_marginals.py index eebb819f3..5398d7a16 100644 --- a/tests/test_stats/test_offgrid_marginals.py +++ b/tests/test_stats/test_offgrid_marginals.py @@ -12,11 +12,11 @@ def test_filter_marginals_close_only_to_left_boundary(fact): tcoeffs = (u0, vf(u0, t=t0)) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) grid = np.linspace(t0, t1, endpoint=True, num=5) - sol = ivpsolve.solve_fixed_grid(vf, init, grid=grid, solver=solver, ssm=ssm) + sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver, ssm=ssm) # Extrapolate from the left: close-to-left boundary must be similar, # but close-to-right boundary needs not be similar @@ -34,12 +34,12 @@ def test_smoother_marginals_close_to_both_boundaries(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) grid = np.linspace(t0, t1, endpoint=True, num=5) - sol = ivpsolve.solve_fixed_grid(vf, init, grid=grid, solver=solver, ssm=ssm) + sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver, ssm=ssm) # Extrapolate from the left: close-to-left boundary must be similar, # and close-to-right boundary must be similar ts = np.linspace(sol.t[-2] + 1e-4, sol.t[-1] - 1e-4, num=5, endpoint=True) diff --git a/tests/test_stats/test_sample.py b/tests/test_stats/test_sample.py index cafa015b3..514e6d93e 100644 --- a/tests/test_stats/test_sample.py +++ b/tests/test_stats/test_sample.py @@ -11,12 +11,12 @@ def fixture_approximation(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) + ts0 = ivpsolvers.correction_ts0(vf, ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) return ivpsolve.solve_adaptive_save_every_step( - vf, init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm + init, t0=t0, t1=t1, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm )