Skip to content
113 changes: 44 additions & 69 deletions probdiffeq/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from probdiffeq.backend import (
containers,
control_flow,
functools,
linalg,
tree_array_util,
tree_util,
Expand Down Expand Up @@ -89,42 +88,21 @@ def solve_adaptive_terminal_values(
vector_field, initial_condition, t0, t1, adaptive_solver, dt0, *, ssm
) -> IVPSolution:
"""Simulate the terminal values of an initial value problem."""
save_at = np.asarray([t1])
(_t, solution_save_at), _, num_steps = _solve_adaptive_save_at(
tree_util.Partial(vector_field),
t0,
save_at = np.asarray([t0, t1])
solution = solve_adaptive_save_at(
vector_field,
initial_condition,
save_at=save_at,
adaptive_solver=adaptive_solver,
dt0=dt0,
)
# "squeeze"-type functionality (there is only a single state!)
squeeze_fun = functools.partial(np.squeeze_along_axis, axis=0)
solution_save_at = tree_util.tree_map(squeeze_fun, solution_save_at)
num_steps = tree_util.tree_map(squeeze_fun, num_steps)

# I think the user expects marginals, so we compute them here
# todo: do this in IVPSolution.* methods?
posterior, output_scale = solution_save_at
marginals = posterior.init if isinstance(posterior, stats.MarkovSeq) else posterior

u = ssm.stats.qoi_from_sample(marginals.mean)
std = ssm.stats.standard_deviation(marginals)
u_std = ssm.stats.qoi_from_sample(std)
return IVPSolution(
t=t1,
u=u,
u_std=u_std,
ssm=ssm,
marginals=marginals,
posterior=posterior,
output_scale=output_scale,
num_steps=num_steps,
warn=False, # Turn off warnings because any solver goes for terminal values
)
return tree_util.tree_map(lambda s: s[-1], solution)


def solve_adaptive_save_at(
vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm
vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm, warn=True
) -> IVPSolution:
r"""Solve an initial value problem and return the solution at a pre-determined grid.

Expand Down Expand Up @@ -152,7 +130,7 @@ def solve_adaptive_save_at(
}
```
"""
if not adaptive_solver.solver.is_suitable_for_save_at:
if not adaptive_solver.solver.is_suitable_for_save_at and warn:
msg = (
f"Strategy {adaptive_solver.solver} should not "
f"be used in solve_adaptive_save_at. "
Expand All @@ -170,7 +148,7 @@ def solve_adaptive_save_at(

# I think the user expects the initial condition to be part of the state
# (as well as marginals), so we compute those things here
posterior_t0, *_ = initial_condition
posterior_t0 = initial_condition.posterior
posterior_save_at, output_scale = solution_save_at
_tmp = _userfriendly_output(
posterior=posterior_save_at, posterior_t0=posterior_t0, ssm=ssm
Expand All @@ -194,41 +172,37 @@ def solve_adaptive_save_at(
def _solve_adaptive_save_at(
vector_field, t, initial_condition, *, save_at, adaptive_solver, dt0
):
advance_func = functools.partial(
_advance_and_interpolate,
vector_field=vector_field,
adaptive_solver=adaptive_solver,
)

state = adaptive_solver.init(t, initial_condition, dt=dt0, num_steps=0.0)
_, solution = control_flow.scan(advance_func, init=state, xs=save_at, reverse=False)
return solution


def _advance_and_interpolate(state, t_next, *, vector_field, adaptive_solver):
# Advance until accepted.t >= t_next.
# Note: This could already be the case and we may not loop (just interpolate)
def cond_fun(s):
# Terminate the loop if
# the difference from s.t to t_next is smaller than a constant factor
# (which is a "small" multiple of the current machine precision)
# or if s.t > t_next holds.
return s.step_from.t + 10 * np.finfo_eps(float) < t_next
def advance(state, t_next):
# Advance until accepted.t >= t_next.
# Note: This could already be the case and we may not loop (just interpolate)
def cond_fun(s):
# Terminate the loop if
# the difference from s.t to t_next is smaller than a constant factor
# (which is a "small" multiple of the current machine precision)
# or if s.t > t_next holds.
return s.step_from.t + adaptive_solver.eps < t_next

def body_fun(s):
return adaptive_solver.rejection_loop(
s, vector_field=vector_field, t1=t_next
)

def body_fun(s):
return adaptive_solver.rejection_loop(s, vector_field=vector_field, t1=t_next)
state = control_flow.while_loop(cond_fun, body_fun, init=state)

state = control_flow.while_loop(cond_fun, body_fun, init=state)
# Either interpolate (t > t_next) or "finalise" (t == t_next)
is_after_t1 = state.step_from.t > t_next + adaptive_solver.eps
state, solution = control_flow.cond(
is_after_t1,
adaptive_solver.extract_after_t1,
adaptive_solver.extract_at_t1,
state,
t_next,
)
return state, solution

# Either interpolate (t > t_next) or "finalise" (t == t_next)
state, solution = control_flow.cond(
state.step_from.t > t_next + 10 * np.finfo_eps(float),
adaptive_solver.extract_after_t1_via_interpolation,
lambda s, _t: adaptive_solver.extract_at_t1(s),
state,
t_next,
)
return state, solution
state = adaptive_solver.init(t, initial_condition, dt=dt0, num_steps=0.0)
_, solution = control_flow.scan(advance, init=state, xs=save_at, reverse=False)
return solution


def solve_adaptive_save_every_step(
Expand Down Expand Up @@ -264,7 +238,7 @@ def solve_adaptive_save_every_step(
t = np.concatenate((np.atleast_1d(t0), t))

# I think the user expects marginals, so we compute them here
posterior_t0, *_ = initial_condition
posterior_t0 = initial_condition.posterior
posterior, output_scale = solution_every_step
_tmp = _userfriendly_output(posterior=posterior, posterior_t0=posterior_t0, ssm=ssm)
marginals, posterior = _tmp
Expand Down Expand Up @@ -292,15 +266,16 @@ def _solution_generator(
while state.step_from.t < t1:
state = adaptive_solver.rejection_loop(state, vector_field=vector_field, t1=t1)

if state.step_from.t < t1:
solution = adaptive_solver.extract_before_t1(state)
if state.step_from.t + adaptive_solver.eps < t1:
_, solution = adaptive_solver.extract_before_t1(state, t=t1)
yield solution

# Either interpolate (t > t_next) or "finalise" (t == t_next)
if state.step_from.t > t1:
_, solution = adaptive_solver.extract_after_t1_via_interpolation(state, t=t1)
is_after_t1 = state.step_from.t > t1 + adaptive_solver.eps
if is_after_t1:
_, solution = adaptive_solver.extract_after_t1(state, t=t1)
else:
_, solution = adaptive_solver.extract_at_t1(state)
_, solution = adaptive_solver.extract_at_t1(state, t=t1)

yield solution

Expand All @@ -321,7 +296,7 @@ def body_fn(s, dt):
_t, (posterior, output_scale) = solver.extract(result_state)

# I think the user expects marginals, so we compute them here
posterior_t0, *_ = initial_condition
posterior_t0 = initial_condition.posterior
_tmp = _userfriendly_output(posterior=posterior, posterior_t0=posterior_t0, ssm=ssm)
marginals, posterior = _tmp

Expand Down
Loading