Skip to content

Commit bfc8b93

Browse files
authored
Tidy up some internals in the adaptive solver (#826)
* Keyword arguments in AdaState * Rename extrapolation to strategy according to the rest of the code * All adaptive_solver.extraction methods have the same signature * Leave TODOs * Adaptive solver gets 'eps' to make 'small values' consistent * Recover old default for adaptive_solver.eps * Remove unnecessary function call nesting * Call solve_and_save_at in solve_and_save_terminal_values * Initial condition returns an IVPSolution object now * Delete outdated comments
1 parent 8cf4855 commit bfc8b93

2 files changed

Lines changed: 120 additions & 115 deletions

File tree

probdiffeq/ivpsolve.py

Lines changed: 44 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from probdiffeq.backend import (
55
containers,
66
control_flow,
7-
functools,
87
linalg,
98
tree_array_util,
109
tree_util,
@@ -89,42 +88,21 @@ def solve_adaptive_terminal_values(
8988
vector_field, initial_condition, t0, t1, adaptive_solver, dt0, *, ssm
9089
) -> IVPSolution:
9190
"""Simulate the terminal values of an initial value problem."""
92-
save_at = np.asarray([t1])
93-
(_t, solution_save_at), _, num_steps = _solve_adaptive_save_at(
94-
tree_util.Partial(vector_field),
95-
t0,
91+
save_at = np.asarray([t0, t1])
92+
solution = solve_adaptive_save_at(
93+
vector_field,
9694
initial_condition,
9795
save_at=save_at,
9896
adaptive_solver=adaptive_solver,
9997
dt0=dt0,
100-
)
101-
# "squeeze"-type functionality (there is only a single state!)
102-
squeeze_fun = functools.partial(np.squeeze_along_axis, axis=0)
103-
solution_save_at = tree_util.tree_map(squeeze_fun, solution_save_at)
104-
num_steps = tree_util.tree_map(squeeze_fun, num_steps)
105-
106-
# I think the user expects marginals, so we compute them here
107-
# todo: do this in IVPSolution.* methods?
108-
posterior, output_scale = solution_save_at
109-
marginals = posterior.init if isinstance(posterior, stats.MarkovSeq) else posterior
110-
111-
u = ssm.stats.qoi_from_sample(marginals.mean)
112-
std = ssm.stats.standard_deviation(marginals)
113-
u_std = ssm.stats.qoi_from_sample(std)
114-
return IVPSolution(
115-
t=t1,
116-
u=u,
117-
u_std=u_std,
11898
ssm=ssm,
119-
marginals=marginals,
120-
posterior=posterior,
121-
output_scale=output_scale,
122-
num_steps=num_steps,
99+
warn=False, # Turn off warnings because any solver goes for terminal values
123100
)
101+
return tree_util.tree_map(lambda s: s[-1], solution)
124102

125103

126104
def solve_adaptive_save_at(
127-
vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm
105+
vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm, warn=True
128106
) -> IVPSolution:
129107
r"""Solve an initial value problem and return the solution at a pre-determined grid.
130108
@@ -152,7 +130,7 @@ def solve_adaptive_save_at(
152130
}
153131
```
154132
"""
155-
if not adaptive_solver.solver.is_suitable_for_save_at:
133+
if not adaptive_solver.solver.is_suitable_for_save_at and warn:
156134
msg = (
157135
f"Strategy {adaptive_solver.solver} should not "
158136
f"be used in solve_adaptive_save_at. "
@@ -170,7 +148,7 @@ def solve_adaptive_save_at(
170148

171149
# I think the user expects the initial condition to be part of the state
172150
# (as well as marginals), so we compute those things here
173-
posterior_t0, *_ = initial_condition
151+
posterior_t0 = initial_condition.posterior
174152
posterior_save_at, output_scale = solution_save_at
175153
_tmp = _userfriendly_output(
176154
posterior=posterior_save_at, posterior_t0=posterior_t0, ssm=ssm
@@ -194,41 +172,37 @@ def solve_adaptive_save_at(
194172
def _solve_adaptive_save_at(
195173
vector_field, t, initial_condition, *, save_at, adaptive_solver, dt0
196174
):
197-
advance_func = functools.partial(
198-
_advance_and_interpolate,
199-
vector_field=vector_field,
200-
adaptive_solver=adaptive_solver,
201-
)
202-
203-
state = adaptive_solver.init(t, initial_condition, dt=dt0, num_steps=0.0)
204-
_, solution = control_flow.scan(advance_func, init=state, xs=save_at, reverse=False)
205-
return solution
206-
207-
208-
def _advance_and_interpolate(state, t_next, *, vector_field, adaptive_solver):
209-
# Advance until accepted.t >= t_next.
210-
# Note: This could already be the case and we may not loop (just interpolate)
211-
def cond_fun(s):
212-
# Terminate the loop if
213-
# the difference from s.t to t_next is smaller than a constant factor
214-
# (which is a "small" multiple of the current machine precision)
215-
# or if s.t > t_next holds.
216-
return s.step_from.t + 10 * np.finfo_eps(float) < t_next
175+
def advance(state, t_next):
176+
# Advance until accepted.t >= t_next.
177+
# Note: This could already be the case and we may not loop (just interpolate)
178+
def cond_fun(s):
179+
# Terminate the loop if
180+
# the difference from s.t to t_next is smaller than a constant factor
181+
# (which is a "small" multiple of the current machine precision)
182+
# or if s.t > t_next holds.
183+
return s.step_from.t + adaptive_solver.eps < t_next
184+
185+
def body_fun(s):
186+
return adaptive_solver.rejection_loop(
187+
s, vector_field=vector_field, t1=t_next
188+
)
217189

218-
def body_fun(s):
219-
return adaptive_solver.rejection_loop(s, vector_field=vector_field, t1=t_next)
190+
state = control_flow.while_loop(cond_fun, body_fun, init=state)
220191

221-
state = control_flow.while_loop(cond_fun, body_fun, init=state)
192+
# Either interpolate (t > t_next) or "finalise" (t == t_next)
193+
is_after_t1 = state.step_from.t > t_next + adaptive_solver.eps
194+
state, solution = control_flow.cond(
195+
is_after_t1,
196+
adaptive_solver.extract_after_t1,
197+
adaptive_solver.extract_at_t1,
198+
state,
199+
t_next,
200+
)
201+
return state, solution
222202

223-
# Either interpolate (t > t_next) or "finalise" (t == t_next)
224-
state, solution = control_flow.cond(
225-
state.step_from.t > t_next + 10 * np.finfo_eps(float),
226-
adaptive_solver.extract_after_t1_via_interpolation,
227-
lambda s, _t: adaptive_solver.extract_at_t1(s),
228-
state,
229-
t_next,
230-
)
231-
return state, solution
203+
state = adaptive_solver.init(t, initial_condition, dt=dt0, num_steps=0.0)
204+
_, solution = control_flow.scan(advance, init=state, xs=save_at, reverse=False)
205+
return solution
232206

233207

234208
def solve_adaptive_save_every_step(
@@ -264,7 +238,7 @@ def solve_adaptive_save_every_step(
264238
t = np.concatenate((np.atleast_1d(t0), t))
265239

266240
# I think the user expects marginals, so we compute them here
267-
posterior_t0, *_ = initial_condition
241+
posterior_t0 = initial_condition.posterior
268242
posterior, output_scale = solution_every_step
269243
_tmp = _userfriendly_output(posterior=posterior, posterior_t0=posterior_t0, ssm=ssm)
270244
marginals, posterior = _tmp
@@ -292,15 +266,16 @@ def _solution_generator(
292266
while state.step_from.t < t1:
293267
state = adaptive_solver.rejection_loop(state, vector_field=vector_field, t1=t1)
294268

295-
if state.step_from.t < t1:
296-
solution = adaptive_solver.extract_before_t1(state)
269+
if state.step_from.t + adaptive_solver.eps < t1:
270+
_, solution = adaptive_solver.extract_before_t1(state, t=t1)
297271
yield solution
298272

299273
# Either interpolate (t > t_next) or "finalise" (t == t_next)
300-
if state.step_from.t > t1:
301-
_, solution = adaptive_solver.extract_after_t1_via_interpolation(state, t=t1)
274+
is_after_t1 = state.step_from.t > t1 + adaptive_solver.eps
275+
if is_after_t1:
276+
_, solution = adaptive_solver.extract_after_t1(state, t=t1)
302277
else:
303-
_, solution = adaptive_solver.extract_at_t1(state)
278+
_, solution = adaptive_solver.extract_at_t1(state, t=t1)
304279

305280
yield solution
306281

@@ -321,7 +296,7 @@ def body_fn(s, dt):
321296
_t, (posterior, output_scale) = solver.extract(result_state)
322297

323298
# I think the user expects marginals, so we compute them here
324-
posterior_t0, *_ = initial_condition
299+
posterior_t0 = initial_condition.posterior
325300
_tmp = _userfriendly_output(posterior=posterior, posterior_t0=posterior_t0, ssm=ssm)
326301
marginals, posterior = _tmp
327302

0 commit comments

Comments
 (0)