Skip to content

Commit d83f097

Browse files
authored
Move vector field from solution routines to corrections (#829)
* Move vector_field to corrections * Update tests * Update example notebooks * Update benchmarks * Rerun benchmarks
1 parent 313d711 commit d83f097

45 files changed

Lines changed: 142 additions & 221 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

docs/benchmarks/hires/plot_ts.npy

-592 Bytes
Binary file not shown.

docs/benchmarks/hires/plot_ys.npy

-4.63 KB
Binary file not shown.

docs/benchmarks/hires/results.npy

0 Bytes
Binary file not shown.

docs/benchmarks/hires/run_hires.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def param_to_solution(tol):
9090
vf_auto = functools.partial(vf_probdiffeq, t=t0)
9191
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)
9292
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
93-
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
93+
ts1 = ivpsolvers.correction_ts1(vf_probdiffeq, ssm=ssm)
9494
strategy = ivpsolvers.strategy_filter(ssm=ssm)
9595
solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
9696
control = ivpsolvers.control_proportional_integral()
@@ -101,13 +101,7 @@ def param_to_solution(tol):
101101
# Solve
102102
dt0 = ivpsolve.dt0(vf_auto, (u0,))
103103
solution = ivpsolve.solve_adaptive_terminal_values(
104-
vf_probdiffeq,
105-
init,
106-
t0=t0,
107-
t1=t1,
108-
dt0=dt0,
109-
adaptive_solver=adaptive_solver,
110-
ssm=ssm,
104+
init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm
111105
)
112106

113107
# Return the terminal value
0 Bytes
Binary file not shown.

docs/benchmarks/lotkavolterra/run_lotkavolterra.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def param_to_solution(tol):
8484
tcoeffs, ssm_fact=implementation
8585
)
8686
strategy = ivpsolvers.strategy_filter(ssm=ssm)
87-
corr = correction(ssm=ssm)
87+
corr = correction(vf_probdiffeq, ssm=ssm)
8888
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm)
8989
control = ivpsolvers.control_proportional_integral()
9090
adaptive_solver = ivpsolvers.adaptive(
@@ -94,13 +94,7 @@ def param_to_solution(tol):
9494
# Solve
9595
dt0 = ivpsolve.dt0(vf_auto, (u0,))
9696
solution = ivpsolve.solve_adaptive_terminal_values(
97-
vf_probdiffeq,
98-
init,
99-
t0=t0,
100-
t1=t1,
101-
dt0=dt0,
102-
adaptive_solver=adaptive_solver,
103-
ssm=ssm,
97+
init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm
10498
)
10599

106100
# Return the terminal value
176 Bytes
Binary file not shown.
4.81 KB
Binary file not shown.
0 Bytes
Binary file not shown.

docs/benchmarks/pleiades/run_pleiades.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def param_to_solution(tol):
102102
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
103103
tcoeffs, ssm_fact="isotropic"
104104
)
105-
ts0_or_ts1 = correction_fun(ssm=ssm, ode_order=2)
105+
ts0_or_ts1 = correction_fun(vf_probdiffeq, ssm=ssm, ode_order=2)
106106
strategy = ivpsolvers.strategy_filter(ssm=ssm)
107107
solver = ivpsolvers.solver_dynamic(
108108
strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm
@@ -115,13 +115,7 @@ def param_to_solution(tol):
115115
# Solve
116116
dt0 = ivpsolve.dt0(vf_auto, (u0, du0))
117117
solution = ivpsolve.solve_adaptive_terminal_values(
118-
vf_probdiffeq,
119-
init,
120-
t0=t0,
121-
t1=t1,
122-
dt0=dt0,
123-
adaptive_solver=adaptive_solver,
124-
ssm=ssm,
118+
init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm
125119
)
126120

127121
# Return the terminal value

0 commit comments

Comments
 (0)