diff --git a/docs/benchmarks/hires/run_hires.py b/docs/benchmarks/hires/run_hires.py index 16cfda7f2..48fa74bef 100644 --- a/docs/benchmarks/hires/run_hires.py +++ b/docs/benchmarks/hires/run_hires.py @@ -93,9 +93,9 @@ def param_to_solution(tol): ts1 = ivpsolvers.correction_ts1(ssm=ssm) strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm) - control = ivpsolvers.control_proportional_integral(clip=True) + control = ivpsolvers.control_proportional_integral() adaptive_solver = ivpsolvers.adaptive( - solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm + solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True ) # Initial state diff --git a/docs/benchmarks/vanderpol/run_vanderpol.py b/docs/benchmarks/vanderpol/run_vanderpol.py index 1ba8181fe..0fd3e4a6b 100644 --- a/docs/benchmarks/vanderpol/run_vanderpol.py +++ b/docs/benchmarks/vanderpol/run_vanderpol.py @@ -88,9 +88,9 @@ def param_to_solution(tol): solver = ivpsolvers.solver_dynamic( strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm ) - control = ivpsolvers.control_proportional_integral(clip=True) + control = ivpsolvers.control_proportional_integral() adaptive_solver = ivpsolvers.adaptive( - solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm + solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True ) # Initial state diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index b4b597f79..7fee24012 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -921,17 +921,34 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -def adaptive(slvr, /, *, ssm, atol=1e-4, rtol=1e-2, control=None, norm_ord=None): +def adaptive( + slvr, + /, + *, + ssm, + atol=1e-4, + rtol=1e-2, + control=None, + norm_ord=None, + clip_dt: bool = False, +): """Make an IVP solver adaptive.""" if control is None: control = control_proportional_integral() return _AdaSolver( - slvr, ssm=ssm, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord + slvr, + ssm=ssm, + atol=atol, + rtol=rtol, + control=control, + norm_ord=norm_ord, + clip_dt=clip_dt, ) class _AdaState(containers.NamedTuple): + dt: float step_from: Any interp_from: Any control: Any @@ -942,7 +959,16 @@ class _AdaSolver: """Adaptive IVP solvers.""" def __init__( - self, slvr: _ProbabilisticSolver, /, *, atol, rtol, control, norm_ord, ssm + self, + slvr: _ProbabilisticSolver, + /, + *, + atol, + rtol, + control, + norm_ord, + ssm, + clip_dt: bool, ): self.solver = slvr self.atol = atol @@ -950,6 +976,7 @@ def __init__( self.control = control self.norm_ord = norm_ord self.ssm = ssm + self.clip_dt = clip_dt def __repr__(self): return ( @@ -967,7 +994,7 @@ def init(self, t, initial_condition, dt, num_steps) -> _AdaState: """Initialise the IVP solver state.""" state_solver = self.solver.init(t, initial_condition) state_control = self.control.init(dt) - return _AdaState(state_solver, state_solver, state_control, num_steps) + return _AdaState(dt, state_solver, state_solver, state_control, num_steps) @functools.jit def rejection_loop(self, state0: _AdaState, *, vector_field, t1) -> _AdaState: @@ -978,6 +1005,7 @@ class _RejectionState(containers.NamedTuple): This is one part of an IVP solver step.) """ + dt: float error_norm_proposed: float control: Any proposed: Any @@ -990,6 +1018,7 @@ def _inf_like(tree): smaller_than_1 = 1.0 / 1.1 # the cond() must return True return _RejectionState( error_norm_proposed=smaller_than_1, + dt=s0.dt, control=s0.control, proposed=_inf_like(s0.step_from), step_from=s0.step_from, @@ -1005,15 +1034,16 @@ def body_fn(state: _RejectionState) -> _RejectionState: Perform a step with an IVP solver and propose a future time-step based on tolerances and error estimates. """ + dt = state.dt + # Some controllers like to clip the terminal value instead of interpolating. # This must happen _before_ the step. - state_control = self.control.clip(state.control, t=state.step_from.t, t1=t1) + if self.clip_dt: + dt = np.minimum(dt, t1 - state.step_from.t) # Perform the actual step. error_estimate, state_proposed = self.solver.step( - state=state.step_from, - vector_field=vector_field, - dt=self.control.extract(state_control), + state=state.step_from, vector_field=vector_field, dt=dt ) # Normalise the error u_proposed = self.ssm.stats.qoi(state_proposed.hidden)[0] @@ -1022,8 +1052,11 @@ def body_fn(state: _RejectionState) -> _RejectionState: error_power = _error_scale_and_normalize(error_estimate, u=u) # Propose a new step - state_control = self.control.apply(state_control, error_power=error_power) + dt, state_control = self.control.apply( + dt, state.control, error_power=error_power + ) return _RejectionState( + dt=dt, error_norm_proposed=error_power, # new proposed=state_proposed, # new control=state_control, # new @@ -1038,8 +1071,8 @@ def _error_scale_and_normalize(error_estimate, *, u): return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate) def extract(s: _RejectionState) -> _AdaState: - num_steps = state0.stats + 1 - return _AdaState(s.proposed, s.step_from, s.control, num_steps) + num_steps = state0.stats + 1.0 # TODO: track step attempts as well + return _AdaState(s.dt, s.proposed, s.step_from, s.control, num_steps) init_val = init(state0) state_new = control_flow.while_loop(cond_fn, body_fn, init_val) @@ -1047,8 +1080,7 @@ def extract(s: _RejectionState) -> _AdaState: def extract_before_t1(self, state: _AdaState): solution_solver = self.solver.extract(state.step_from) - solution_control = self.control.extract(state.control) - return solution_solver, solution_control, state.stats + return solution_solver, (state.dt, state.control), state.stats def extract_at_t1(self, state: _AdaState): # todo: make the "at t1" decision inside interpolate(), @@ -1057,37 +1089,47 @@ def extract_at_t1(self, state: _AdaState): interp_from=state.interp_from, interp_to=state.step_from ) state = _AdaState( - interp.step_from, interp.interp_from, state.control, state.stats + state.dt, interp.step_from, interp.interp_from, state.control, state.stats ) solution_solver = self.solver.extract(interp.interpolated) - solution_control = self.control.extract(state.control) - return state, (solution_solver, solution_control, state.stats) + return state, (solution_solver, (state.dt, state.control), state.stats) def extract_after_t1_via_interpolation(self, state: _AdaState, t): interp = self.solver.interpolate( t, interp_from=state.interp_from, interp_to=state.step_from ) state = _AdaState( - interp.step_from, interp.interp_from, state.control, state.stats + state.dt, interp.step_from, interp.interp_from, state.control, state.stats ) solution_solver = self.solver.extract(interp.interpolated) - solution_control = self.control.extract(state.control) - return state, (solution_solver, solution_control, state.stats) + return state, (solution_solver, (state.dt, state.control), state.stats) @staticmethod def register_pytree_node(): def _asolver_flatten(asolver): children = (asolver.atol, asolver.rtol) - aux = (asolver.solver, asolver.control, asolver.norm_ord, asolver.ssm) + aux = ( + asolver.solver, + asolver.control, + asolver.norm_ord, + asolver.ssm, + asolver.clip_dt, + ) return children, aux def _asolver_unflatten(aux, children): atol, rtol = children - (slvr, control, norm_ord, ssm) = aux + (slvr, control, norm_ord, ssm, clip_dt) = aux return _AdaSolver( - slvr, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord, ssm=ssm + slvr, + atol=atol, + rtol=rtol, + control=control, + norm_ord=norm_ord, + ssm=ssm, + clip_dt=clip_dt, ) tree_util.register_pytree_node( @@ -1097,46 +1139,35 @@ def _asolver_unflatten(aux, children): _AdaSolver.register_pytree_node() +T = TypeVar("T") + @containers.dataclass -class _Controller: +class _Controller(Generic[T]): """Control algorithm.""" - init: Callable[[float], Any] + init: Callable[[float], T] """Initialise the controller state.""" - clip: Callable[[Any, float, float], Any] - """(Optionally) clip the current step to not exceed t1.""" - - apply: Callable[[Any, NamedArg(float, "error_power")], Any] + apply: Callable[[float, T, NamedArg(float, "error_power")], tuple[float, T]] r"""Propose a time-step $\Delta t$.""" - extract: Callable[[Any], float] - """Extract the time-step from the controller state.""" - def control_proportional_integral( *, - clip: bool = False, safety=0.95, factor_min=0.2, factor_max=10.0, power_integral_unscaled=0.3, power_proportional_unscaled=0.4, -) -> _Controller: +) -> _Controller[float]: """Construct a proportional-integral-controller with time-clipping.""" - class PIState(containers.NamedTuple): - dt: float - error_power_previously_accepted: float - - def init(dt: float, /) -> PIState: - return PIState(dt, 1.0) + def init(_dt: float, /) -> float: + return 1.0 - def apply(state: PIState, /, *, error_power) -> PIState: + def apply(dt: float, error_power_prev: float, /, *, error_power): # error_power = error_norm ** (-1.0 / error_contraction_rate) - dt_proposed, error_power_prev = state - a1 = error_power**power_integral_unscaled a2 = (error_power / error_power_prev) ** power_proportional_unscaled scale_factor_unclipped = safety * a1 * a2 @@ -1147,50 +1178,26 @@ def apply(state: PIState, /, *, error_power) -> PIState: # >= 1.0 because error_power is 1/scaled_error_norm error_power_prev = np.where(error_power >= 1.0, error_power, error_power_prev) - dt_proposed = scale_factor * dt_proposed - return PIState(dt_proposed, error_power_prev) - - def extract(state: PIState, /) -> float: - dt_proposed, _error_norm_previously_accepted = state - return dt_proposed - - if clip: - - def clip_fun(state: PIState, /, t, t1) -> PIState: - dt_proposed, error_norm_previously_accepted = state - dt = dt_proposed - dt_clipped = np.minimum(dt, t1 - t) - return PIState(dt_clipped, error_norm_previously_accepted) - - return _Controller(init=init, apply=apply, extract=extract, clip=clip_fun) + dt_proposed = scale_factor * dt + return dt_proposed, error_power_prev - return _Controller(init=init, apply=apply, extract=extract, clip=lambda v, **_kw: v) + return _Controller(init=init, apply=apply) def control_integral( - *, clip=False, safety=0.95, factor_min=0.2, factor_max=10.0 -) -> _Controller: + *, safety=0.95, factor_min=0.2, factor_max=10.0 +) -> _Controller[None]: """Construct an integral-controller.""" - def init(dt, /): - return dt + def init(_dt, /) -> None: + return None - def apply(dt, /, *, error_power): + def apply(dt, _state, /, *, error_power): # error_power = error_norm ** (-1.0 / error_contraction_rate) scale_factor_unclipped = safety * error_power scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max) scale_factor = np.maximum(factor_min, scale_factor_clipped_min) - return scale_factor * dt - - def extract(dt, /): - return dt - - if clip: - - def clip_fun(dt, /, t, t1): - return np.minimum(dt, t1 - t) - - return _Controller(init=init, apply=apply, extract=extract, clip=clip_fun) + return scale_factor * dt, None - return _Controller(init=init, apply=apply, extract=extract, clip=lambda v, **_kw: v) + return _Controller(init=init, apply=apply) diff --git a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py index 8a1655156..6ff982c07 100644 --- a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py @@ -22,10 +22,7 @@ class Taylor(containers.NamedTuple): strategy = ivpsolvers.strategy_filter(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) - control = ivpsolvers.control_integral(clip=True) # Any clipped controller will do. - asolver = ivpsolvers.adaptive( - solver, ssm=ssm, atol=1e-2, rtol=1e-2, control=control - ) + asolver = ivpsolvers.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2, clip_dt=True) init = solver.initial_condition() args = (vf, init) diff --git a/tests/test_ivpsolve/test_save_every_step.py b/tests/test_ivpsolve/test_save_every_step.py index 427f7b4c0..c20eb25ff 100644 --- a/tests/test_ivpsolve/test_save_every_step.py +++ b/tests/test_ivpsolve/test_save_every_step.py @@ -35,9 +35,8 @@ def python_loop_solution(ivp, *, fact, strategy_fun): # clip=False because we need to test adaptive-step-interpolation # for smoothers - control = ivpsolvers.control_proportional_integral(clip=False) adaptive_solver = ivpsolvers.adaptive( - solver, atol=1e-2, rtol=1e-2, control=control, ssm=ssm + solver, atol=1e-2, rtol=1e-2, ssm=ssm, clip_dt=False ) dt0 = ivpsolve.dt0_adaptive( diff --git a/tests/test_ivpsolvers/test_controllers.py b/tests/test_ivpsolvers/test_controllers.py index 06aeb19f9..5bce5ee3b 100644 --- a/tests/test_ivpsolvers/test_controllers.py +++ b/tests/test_ivpsolvers/test_controllers.py @@ -15,12 +15,12 @@ def test_equivalence_pi_vs_i(dt, error_power, num_applies): ctrl_i = ivpsolvers.control_integral() x_pi = ctrl_pi.init(dt) + dt_pi = dt for _ in range(num_applies): - x_pi = ctrl_pi.apply(x_pi, error_power=error_power) - x_pi = ctrl_pi.extract(x_pi) + dt_pi, x_pi = ctrl_pi.apply(dt_pi, x_pi, error_power=error_power) x_i = ctrl_i.init(dt) + dt_i = dt for _ in range(num_applies): - x_i = ctrl_i.apply(x_i, error_power=error_power) - x_i = ctrl_i.extract(x_i) - assert np.allclose(x_i, x_pi) + dt_i, x_i = ctrl_i.apply(dt_i, x_i, error_power=error_power) + assert np.allclose(dt_i, dt_pi)