Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/benchmarks/hires/run_hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def param_to_solution(tol):
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
control = ivpsolvers.control_proportional_integral(clip=True)
control = ivpsolvers.control_proportional_integral()
adaptive_solver = ivpsolvers.adaptive(
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True
)

# Initial state
Expand Down
4 changes: 2 additions & 2 deletions docs/benchmarks/vanderpol/run_vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def param_to_solution(tol):
solver = ivpsolvers.solver_dynamic(
strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm
)
control = ivpsolvers.control_proportional_integral(clip=True)
control = ivpsolvers.control_proportional_integral()
adaptive_solver = ivpsolvers.adaptive(
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True
)

# Initial state
Expand Down
159 changes: 83 additions & 76 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,17 +921,34 @@ def extract(state, /):
return _Calibration(init=init, update=update, extract=extract)


def adaptive(slvr, /, *, ssm, atol=1e-4, rtol=1e-2, control=None, norm_ord=None):
def adaptive(
slvr,
/,
*,
ssm,
atol=1e-4,
rtol=1e-2,
control=None,
norm_ord=None,
clip_dt: bool = False,
):
"""Make an IVP solver adaptive."""
if control is None:
control = control_proportional_integral()

return _AdaSolver(
slvr, ssm=ssm, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord
slvr,
ssm=ssm,
atol=atol,
rtol=rtol,
control=control,
norm_ord=norm_ord,
clip_dt=clip_dt,
)


class _AdaState(containers.NamedTuple):
dt: float
step_from: Any
interp_from: Any
control: Any
Expand All @@ -942,14 +959,24 @@ class _AdaSolver:
"""Adaptive IVP solvers."""

def __init__(
self, slvr: _ProbabilisticSolver, /, *, atol, rtol, control, norm_ord, ssm
self,
slvr: _ProbabilisticSolver,
/,
*,
atol,
rtol,
control,
norm_ord,
ssm,
clip_dt: bool,
):
self.solver = slvr
self.atol = atol
self.rtol = rtol
self.control = control
self.norm_ord = norm_ord
self.ssm = ssm
self.clip_dt = clip_dt

def __repr__(self):
return (
Expand All @@ -967,7 +994,7 @@ def init(self, t, initial_condition, dt, num_steps) -> _AdaState:
"""Initialise the IVP solver state."""
state_solver = self.solver.init(t, initial_condition)
state_control = self.control.init(dt)
return _AdaState(state_solver, state_solver, state_control, num_steps)
return _AdaState(dt, state_solver, state_solver, state_control, num_steps)

@functools.jit
def rejection_loop(self, state0: _AdaState, *, vector_field, t1) -> _AdaState:
Expand All @@ -978,6 +1005,7 @@ class _RejectionState(containers.NamedTuple):
This is one part of an IVP solver step.)
"""

dt: float
error_norm_proposed: float
control: Any
proposed: Any
Expand All @@ -990,6 +1018,7 @@ def _inf_like(tree):
smaller_than_1 = 1.0 / 1.1 # the cond() must return True
return _RejectionState(
error_norm_proposed=smaller_than_1,
dt=s0.dt,
control=s0.control,
proposed=_inf_like(s0.step_from),
step_from=s0.step_from,
Expand All @@ -1005,15 +1034,16 @@ def body_fn(state: _RejectionState) -> _RejectionState:
Perform a step with an IVP solver and
propose a future time-step based on tolerances and error estimates.
"""
dt = state.dt

# Some controllers like to clip the terminal value instead of interpolating.
# This must happen _before_ the step.
state_control = self.control.clip(state.control, t=state.step_from.t, t1=t1)
if self.clip_dt:
dt = np.minimum(dt, t1 - state.step_from.t)

# Perform the actual step.
error_estimate, state_proposed = self.solver.step(
state=state.step_from,
vector_field=vector_field,
dt=self.control.extract(state_control),
state=state.step_from, vector_field=vector_field, dt=dt
)
# Normalise the error
u_proposed = self.ssm.stats.qoi(state_proposed.hidden)[0]
Expand All @@ -1022,8 +1052,11 @@ def body_fn(state: _RejectionState) -> _RejectionState:
error_power = _error_scale_and_normalize(error_estimate, u=u)

# Propose a new step
state_control = self.control.apply(state_control, error_power=error_power)
dt, state_control = self.control.apply(
dt, state.control, error_power=error_power
)
return _RejectionState(
dt=dt,
error_norm_proposed=error_power, # new
proposed=state_proposed, # new
control=state_control, # new
Expand All @@ -1038,17 +1071,16 @@ def _error_scale_and_normalize(error_estimate, *, u):
return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate)

def extract(s: _RejectionState) -> _AdaState:
num_steps = state0.stats + 1
return _AdaState(s.proposed, s.step_from, s.control, num_steps)
num_steps = state0.stats + 1.0 # TODO: track step attempts as well
return _AdaState(s.dt, s.proposed, s.step_from, s.control, num_steps)

init_val = init(state0)
state_new = control_flow.while_loop(cond_fn, body_fn, init_val)
return extract(state_new)

def extract_before_t1(self, state: _AdaState):
solution_solver = self.solver.extract(state.step_from)
solution_control = self.control.extract(state.control)
return solution_solver, solution_control, state.stats
return solution_solver, (state.dt, state.control), state.stats

def extract_at_t1(self, state: _AdaState):
# todo: make the "at t1" decision inside interpolate(),
Expand All @@ -1057,37 +1089,47 @@ def extract_at_t1(self, state: _AdaState):
interp_from=state.interp_from, interp_to=state.step_from
)
state = _AdaState(
interp.step_from, interp.interp_from, state.control, state.stats
state.dt, interp.step_from, interp.interp_from, state.control, state.stats
)

solution_solver = self.solver.extract(interp.interpolated)
solution_control = self.control.extract(state.control)
return state, (solution_solver, solution_control, state.stats)
return state, (solution_solver, (state.dt, state.control), state.stats)

def extract_after_t1_via_interpolation(self, state: _AdaState, t):
interp = self.solver.interpolate(
t, interp_from=state.interp_from, interp_to=state.step_from
)
state = _AdaState(
interp.step_from, interp.interp_from, state.control, state.stats
state.dt, interp.step_from, interp.interp_from, state.control, state.stats
)

solution_solver = self.solver.extract(interp.interpolated)
solution_control = self.control.extract(state.control)
return state, (solution_solver, solution_control, state.stats)
return state, (solution_solver, (state.dt, state.control), state.stats)

@staticmethod
def register_pytree_node():
def _asolver_flatten(asolver):
children = (asolver.atol, asolver.rtol)
aux = (asolver.solver, asolver.control, asolver.norm_ord, asolver.ssm)
aux = (
asolver.solver,
asolver.control,
asolver.norm_ord,
asolver.ssm,
asolver.clip_dt,
)
return children, aux

def _asolver_unflatten(aux, children):
atol, rtol = children
(slvr, control, norm_ord, ssm) = aux
(slvr, control, norm_ord, ssm, clip_dt) = aux
return _AdaSolver(
slvr, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord, ssm=ssm
slvr,
atol=atol,
rtol=rtol,
control=control,
norm_ord=norm_ord,
ssm=ssm,
clip_dt=clip_dt,
)

tree_util.register_pytree_node(
Expand All @@ -1097,46 +1139,35 @@ def _asolver_unflatten(aux, children):

_AdaSolver.register_pytree_node()

T = TypeVar("T")


@containers.dataclass
class _Controller:
class _Controller(Generic[T]):
"""Control algorithm."""

init: Callable[[float], Any]
init: Callable[[float], T]
"""Initialise the controller state."""

clip: Callable[[Any, float, float], Any]
"""(Optionally) clip the current step to not exceed t1."""

apply: Callable[[Any, NamedArg(float, "error_power")], Any]
apply: Callable[[float, T, NamedArg(float, "error_power")], tuple[float, T]]
r"""Propose a time-step $\Delta t$."""

extract: Callable[[Any], float]
"""Extract the time-step from the controller state."""


def control_proportional_integral(
*,
clip: bool = False,
safety=0.95,
factor_min=0.2,
factor_max=10.0,
power_integral_unscaled=0.3,
power_proportional_unscaled=0.4,
) -> _Controller:
) -> _Controller[float]:
"""Construct a proportional-integral-controller with time-clipping."""

class PIState(containers.NamedTuple):
dt: float
error_power_previously_accepted: float

def init(dt: float, /) -> PIState:
return PIState(dt, 1.0)
def init(_dt: float, /) -> float:
return 1.0

def apply(state: PIState, /, *, error_power) -> PIState:
def apply(dt: float, error_power_prev: float, /, *, error_power):
# error_power = error_norm ** (-1.0 / error_contraction_rate)
dt_proposed, error_power_prev = state

a1 = error_power**power_integral_unscaled
a2 = (error_power / error_power_prev) ** power_proportional_unscaled
scale_factor_unclipped = safety * a1 * a2
Expand All @@ -1147,50 +1178,26 @@ def apply(state: PIState, /, *, error_power) -> PIState:
# >= 1.0 because error_power is 1/scaled_error_norm
error_power_prev = np.where(error_power >= 1.0, error_power, error_power_prev)

dt_proposed = scale_factor * dt_proposed
return PIState(dt_proposed, error_power_prev)

def extract(state: PIState, /) -> float:
dt_proposed, _error_norm_previously_accepted = state
return dt_proposed

if clip:

def clip_fun(state: PIState, /, t, t1) -> PIState:
dt_proposed, error_norm_previously_accepted = state
dt = dt_proposed
dt_clipped = np.minimum(dt, t1 - t)
return PIState(dt_clipped, error_norm_previously_accepted)

return _Controller(init=init, apply=apply, extract=extract, clip=clip_fun)
dt_proposed = scale_factor * dt
return dt_proposed, error_power_prev

return _Controller(init=init, apply=apply, extract=extract, clip=lambda v, **_kw: v)
return _Controller(init=init, apply=apply)


def control_integral(
*, clip=False, safety=0.95, factor_min=0.2, factor_max=10.0
) -> _Controller:
*, safety=0.95, factor_min=0.2, factor_max=10.0
) -> _Controller[None]:
"""Construct an integral-controller."""

def init(dt, /):
return dt
def init(_dt, /) -> None:
return None

def apply(dt, /, *, error_power):
def apply(dt, _state, /, *, error_power):
# error_power = error_norm ** (-1.0 / error_contraction_rate)
scale_factor_unclipped = safety * error_power

scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max)
scale_factor = np.maximum(factor_min, scale_factor_clipped_min)
return scale_factor * dt

def extract(dt, /):
return dt

if clip:

def clip_fun(dt, /, t, t1):
return np.minimum(dt, t1 - t)

return _Controller(init=init, apply=apply, extract=extract, clip=clip_fun)
return scale_factor * dt, None

return _Controller(init=init, apply=apply, extract=extract, clip=lambda v, **_kw: v)
return _Controller(init=init, apply=apply)
5 changes: 1 addition & 4 deletions tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ class Taylor(containers.NamedTuple):
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)

control = ivpsolvers.control_integral(clip=True) # Any clipped controller will do.
asolver = ivpsolvers.adaptive(
solver, ssm=ssm, atol=1e-2, rtol=1e-2, control=control
)
asolver = ivpsolvers.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2, clip_dt=True)

init = solver.initial_condition()
args = (vf, init)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_ivpsolve/test_save_every_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ def python_loop_solution(ivp, *, fact, strategy_fun):

# clip=False because we need to test adaptive-step-interpolation
# for smoothers
control = ivpsolvers.control_proportional_integral(clip=False)
adaptive_solver = ivpsolvers.adaptive(
solver, atol=1e-2, rtol=1e-2, control=control, ssm=ssm
solver, atol=1e-2, rtol=1e-2, ssm=ssm, clip_dt=False
)

dt0 = ivpsolve.dt0_adaptive(
Expand Down
10 changes: 5 additions & 5 deletions tests/test_ivpsolvers/test_controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def test_equivalence_pi_vs_i(dt, error_power, num_applies):
ctrl_i = ivpsolvers.control_integral()

x_pi = ctrl_pi.init(dt)
dt_pi = dt
for _ in range(num_applies):
x_pi = ctrl_pi.apply(x_pi, error_power=error_power)
x_pi = ctrl_pi.extract(x_pi)
dt_pi, x_pi = ctrl_pi.apply(dt_pi, x_pi, error_power=error_power)

x_i = ctrl_i.init(dt)
dt_i = dt
for _ in range(num_applies):
x_i = ctrl_i.apply(x_i, error_power=error_power)
x_i = ctrl_i.extract(x_i)
assert np.allclose(x_i, x_pi)
dt_i, x_i = ctrl_i.apply(dt_i, x_i, error_power=error_power)
assert np.allclose(dt_i, dt_pi)