Skip to content

Commit 8cf4855

Browse files
authored
Move clipping from Controller to AdaptiveSolver (#825)
1 parent b22d330 commit 8cf4855

6 files changed

Lines changed: 94 additions & 91 deletions

File tree

docs/benchmarks/hires/run_hires.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ def param_to_solution(tol):
9393
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
9494
strategy = ivpsolvers.strategy_filter(ssm=ssm)
9595
solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
96-
control = ivpsolvers.control_proportional_integral(clip=True)
96+
control = ivpsolvers.control_proportional_integral()
9797
adaptive_solver = ivpsolvers.adaptive(
98-
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm
98+
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True
9999
)
100100

101101
# Initial state

docs/benchmarks/vanderpol/run_vanderpol.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def param_to_solution(tol):
8888
solver = ivpsolvers.solver_dynamic(
8989
strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm
9090
)
91-
control = ivpsolvers.control_proportional_integral(clip=True)
91+
control = ivpsolvers.control_proportional_integral()
9292
adaptive_solver = ivpsolvers.adaptive(
93-
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm
93+
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True
9494
)
9595

9696
# Initial state

probdiffeq/ivpsolvers.py

Lines changed: 83 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -927,17 +927,34 @@ def extract(state, /):
927927
return _Calibration(init=init, update=update, extract=extract)
928928

929929

930-
def adaptive(slvr, /, *, ssm, atol=1e-4, rtol=1e-2, control=None, norm_ord=None):
930+
def adaptive(
931+
slvr,
932+
/,
933+
*,
934+
ssm,
935+
atol=1e-4,
936+
rtol=1e-2,
937+
control=None,
938+
norm_ord=None,
939+
clip_dt: bool = False,
940+
):
931941
"""Make an IVP solver adaptive."""
932942
if control is None:
933943
control = control_proportional_integral()
934944

935945
return _AdaSolver(
936-
slvr, ssm=ssm, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord
946+
slvr,
947+
ssm=ssm,
948+
atol=atol,
949+
rtol=rtol,
950+
control=control,
951+
norm_ord=norm_ord,
952+
clip_dt=clip_dt,
937953
)
938954

939955

940956
class _AdaState(containers.NamedTuple):
957+
dt: float
941958
step_from: Any
942959
interp_from: Any
943960
control: Any
@@ -948,14 +965,24 @@ class _AdaSolver:
948965
"""Adaptive IVP solvers."""
949966

950967
def __init__(
951-
self, slvr: _ProbabilisticSolver, /, *, atol, rtol, control, norm_ord, ssm
968+
self,
969+
slvr: _ProbabilisticSolver,
970+
/,
971+
*,
972+
atol,
973+
rtol,
974+
control,
975+
norm_ord,
976+
ssm,
977+
clip_dt: bool,
952978
):
953979
self.solver = slvr
954980
self.atol = atol
955981
self.rtol = rtol
956982
self.control = control
957983
self.norm_ord = norm_ord
958984
self.ssm = ssm
985+
self.clip_dt = clip_dt
959986

960987
def __repr__(self):
961988
return (
@@ -973,7 +1000,7 @@ def init(self, t, initial_condition, dt, num_steps) -> _AdaState:
9731000
"""Initialise the IVP solver state."""
9741001
state_solver = self.solver.init(t, initial_condition)
9751002
state_control = self.control.init(dt)
976-
return _AdaState(state_solver, state_solver, state_control, num_steps)
1003+
return _AdaState(dt, state_solver, state_solver, state_control, num_steps)
9771004

9781005
@functools.jit
9791006
def rejection_loop(self, state0: _AdaState, *, vector_field, t1) -> _AdaState:
@@ -984,6 +1011,7 @@ class _RejectionState(containers.NamedTuple):
9841011
This is one part of an IVP solver step.)
9851012
"""
9861013

1014+
dt: float
9871015
error_norm_proposed: float
9881016
control: Any
9891017
proposed: Any
@@ -996,6 +1024,7 @@ def _inf_like(tree):
9961024
smaller_than_1 = 1.0 / 1.1 # the cond() must return True
9971025
return _RejectionState(
9981026
error_norm_proposed=smaller_than_1,
1027+
dt=s0.dt,
9991028
control=s0.control,
10001029
proposed=_inf_like(s0.step_from),
10011030
step_from=s0.step_from,
@@ -1011,15 +1040,16 @@ def body_fn(state: _RejectionState) -> _RejectionState:
10111040
Perform a step with an IVP solver and
10121041
propose a future time-step based on tolerances and error estimates.
10131042
"""
1043+
dt = state.dt
1044+
10141045
# Some controllers like to clip the terminal value instead of interpolating.
10151046
# This must happen _before_ the step.
1016-
state_control = self.control.clip(state.control, t=state.step_from.t, t1=t1)
1047+
if self.clip_dt:
1048+
dt = np.minimum(dt, t1 - state.step_from.t)
10171049

10181050
# Perform the actual step.
10191051
error_estimate, state_proposed = self.solver.step(
1020-
state=state.step_from,
1021-
vector_field=vector_field,
1022-
dt=self.control.extract(state_control),
1052+
state=state.step_from, vector_field=vector_field, dt=dt
10231053
)
10241054
# Normalise the error
10251055
u_proposed = self.ssm.stats.qoi(state_proposed.hidden)[0]
@@ -1028,8 +1058,11 @@ def body_fn(state: _RejectionState) -> _RejectionState:
10281058
error_power = _error_scale_and_normalize(error_estimate, u=u)
10291059

10301060
# Propose a new step
1031-
state_control = self.control.apply(state_control, error_power=error_power)
1061+
dt, state_control = self.control.apply(
1062+
dt, state.control, error_power=error_power
1063+
)
10321064
return _RejectionState(
1065+
dt=dt,
10331066
error_norm_proposed=error_power, # new
10341067
proposed=state_proposed, # new
10351068
control=state_control, # new
@@ -1044,17 +1077,16 @@ def _error_scale_and_normalize(error_estimate, *, u):
10441077
return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate)
10451078

10461079
def extract(s: _RejectionState) -> _AdaState:
1047-
num_steps = state0.stats + 1
1048-
return _AdaState(s.proposed, s.step_from, s.control, num_steps)
1080+
num_steps = state0.stats + 1.0 # TODO: track step attempts as well
1081+
return _AdaState(s.dt, s.proposed, s.step_from, s.control, num_steps)
10491082

10501083
init_val = init(state0)
10511084
state_new = control_flow.while_loop(cond_fn, body_fn, init_val)
10521085
return extract(state_new)
10531086

10541087
def extract_before_t1(self, state: _AdaState):
10551088
solution_solver = self.solver.extract(state.step_from)
1056-
solution_control = self.control.extract(state.control)
1057-
return solution_solver, solution_control, state.stats
1089+
return solution_solver, (state.dt, state.control), state.stats
10581090

10591091
def extract_at_t1(self, state: _AdaState):
10601092
# todo: make the "at t1" decision inside interpolate(),
@@ -1063,37 +1095,47 @@ def extract_at_t1(self, state: _AdaState):
10631095
interp_from=state.interp_from, interp_to=state.step_from
10641096
)
10651097
state = _AdaState(
1066-
interp.step_from, interp.interp_from, state.control, state.stats
1098+
state.dt, interp.step_from, interp.interp_from, state.control, state.stats
10671099
)
10681100

10691101
solution_solver = self.solver.extract(interp.interpolated)
1070-
solution_control = self.control.extract(state.control)
1071-
return state, (solution_solver, solution_control, state.stats)
1102+
return state, (solution_solver, (state.dt, state.control), state.stats)
10721103

10731104
def extract_after_t1_via_interpolation(self, state: _AdaState, t):
10741105
interp = self.solver.interpolate(
10751106
t, interp_from=state.interp_from, interp_to=state.step_from
10761107
)
10771108
state = _AdaState(
1078-
interp.step_from, interp.interp_from, state.control, state.stats
1109+
state.dt, interp.step_from, interp.interp_from, state.control, state.stats
10791110
)
10801111

10811112
solution_solver = self.solver.extract(interp.interpolated)
1082-
solution_control = self.control.extract(state.control)
1083-
return state, (solution_solver, solution_control, state.stats)
1113+
return state, (solution_solver, (state.dt, state.control), state.stats)
10841114

10851115
@staticmethod
10861116
def register_pytree_node():
10871117
def _asolver_flatten(asolver):
10881118
children = (asolver.atol, asolver.rtol)
1089-
aux = (asolver.solver, asolver.control, asolver.norm_ord, asolver.ssm)
1119+
aux = (
1120+
asolver.solver,
1121+
asolver.control,
1122+
asolver.norm_ord,
1123+
asolver.ssm,
1124+
asolver.clip_dt,
1125+
)
10901126
return children, aux
10911127

10921128
def _asolver_unflatten(aux, children):
10931129
atol, rtol = children
1094-
(slvr, control, norm_ord, ssm) = aux
1130+
(slvr, control, norm_ord, ssm, clip_dt) = aux
10951131
return _AdaSolver(
1096-
slvr, atol=atol, rtol=rtol, control=control, norm_ord=norm_ord, ssm=ssm
1132+
slvr,
1133+
atol=atol,
1134+
rtol=rtol,
1135+
control=control,
1136+
norm_ord=norm_ord,
1137+
ssm=ssm,
1138+
clip_dt=clip_dt,
10971139
)
10981140

10991141
tree_util.register_pytree_node(
@@ -1103,46 +1145,35 @@ def _asolver_unflatten(aux, children):
11031145

11041146
_AdaSolver.register_pytree_node()
11051147

1148+
T = TypeVar("T")
1149+
11061150

11071151
@containers.dataclass
1108-
class _Controller:
1152+
class _Controller(Generic[T]):
11091153
"""Control algorithm."""
11101154

1111-
init: Callable[[float], Any]
1155+
init: Callable[[float], T]
11121156
"""Initialise the controller state."""
11131157

1114-
clip: Callable[[Any, float, float], Any]
1115-
"""(Optionally) clip the current step to not exceed t1."""
1116-
1117-
apply: Callable[[Any, NamedArg(float, "error_power")], Any]
1158+
apply: Callable[[float, T, NamedArg(float, "error_power")], tuple[float, T]]
11181159
r"""Propose a time-step $\Delta t$."""
11191160

1120-
extract: Callable[[Any], float]
1121-
"""Extract the time-step from the controller state."""
1122-
11231161

11241162
def control_proportional_integral(
11251163
*,
1126-
clip: bool = False,
11271164
safety=0.95,
11281165
factor_min=0.2,
11291166
factor_max=10.0,
11301167
power_integral_unscaled=0.3,
11311168
power_proportional_unscaled=0.4,
1132-
) -> _Controller:
1169+
) -> _Controller[float]:
11331170
"""Construct a proportional-integral-controller with time-clipping."""
11341171

1135-
class PIState(containers.NamedTuple):
1136-
dt: float
1137-
error_power_previously_accepted: float
1138-
1139-
def init(dt: float, /) -> PIState:
1140-
return PIState(dt, 1.0)
1172+
def init(_dt: float, /) -> float:
1173+
return 1.0
11411174

1142-
def apply(state: PIState, /, *, error_power) -> PIState:
1175+
def apply(dt: float, error_power_prev: float, /, *, error_power):
11431176
# error_power = error_norm ** (-1.0 / error_contraction_rate)
1144-
dt_proposed, error_power_prev = state
1145-
11461177
a1 = error_power**power_integral_unscaled
11471178
a2 = (error_power / error_power_prev) ** power_proportional_unscaled
11481179
scale_factor_unclipped = safety * a1 * a2
@@ -1153,50 +1184,26 @@ def apply(state: PIState, /, *, error_power) -> PIState:
11531184
# >= 1.0 because error_power is 1/scaled_error_norm
11541185
error_power_prev = np.where(error_power >= 1.0, error_power, error_power_prev)
11551186

1156-
dt_proposed = scale_factor * dt_proposed
1157-
return PIState(dt_proposed, error_power_prev)
1158-
1159-
def extract(state: PIState, /) -> float:
1160-
dt_proposed, _error_norm_previously_accepted = state
1161-
return dt_proposed
1162-
1163-
if clip:
1164-
1165-
def clip_fun(state: PIState, /, t, t1) -> PIState:
1166-
dt_proposed, error_norm_previously_accepted = state
1167-
dt = dt_proposed
1168-
dt_clipped = np.minimum(dt, t1 - t)
1169-
return PIState(dt_clipped, error_norm_previously_accepted)
1170-
1171-
return _Controller(init=init, apply=apply, extract=extract, clip=clip_fun)
1187+
dt_proposed = scale_factor * dt
1188+
return dt_proposed, error_power_prev
11721189

1173-
return _Controller(init=init, apply=apply, extract=extract, clip=lambda v, **_kw: v)
1190+
return _Controller(init=init, apply=apply)
11741191

11751192

11761193
def control_integral(
1177-
*, clip=False, safety=0.95, factor_min=0.2, factor_max=10.0
1178-
) -> _Controller:
1194+
*, safety=0.95, factor_min=0.2, factor_max=10.0
1195+
) -> _Controller[None]:
11791196
"""Construct an integral-controller."""
11801197

1181-
def init(dt, /):
1182-
return dt
1198+
def init(_dt, /) -> None:
1199+
return None
11831200

1184-
def apply(dt, /, *, error_power):
1201+
def apply(dt, _state, /, *, error_power):
11851202
# error_power = error_norm ** (-1.0 / error_contraction_rate)
11861203
scale_factor_unclipped = safety * error_power
11871204

11881205
scale_factor_clipped_min = np.minimum(scale_factor_unclipped, factor_max)
11891206
scale_factor = np.maximum(factor_min, scale_factor_clipped_min)
1190-
return scale_factor * dt
1191-
1192-
def extract(dt, /):
1193-
return dt
1194-
1195-
if clip:
1196-
1197-
def clip_fun(dt, /, t, t1):
1198-
return np.minimum(dt, t1 - t)
1199-
1200-
return _Controller(init=init, apply=apply, extract=extract, clip=clip_fun)
1207+
return scale_factor * dt, None
12011208

1202-
return _Controller(init=init, apply=apply, extract=extract, clip=lambda v, **_kw: v)
1209+
return _Controller(init=init, apply=apply)

tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@ class Taylor(containers.NamedTuple):
2222
strategy = ivpsolvers.strategy_filter(ssm=ssm)
2323
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
2424

25-
control = ivpsolvers.control_integral(clip=True) # Any clipped controller will do.
26-
asolver = ivpsolvers.adaptive(
27-
solver, ssm=ssm, atol=1e-2, rtol=1e-2, control=control
28-
)
25+
asolver = ivpsolvers.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2, clip_dt=True)
2926

3027
init = solver.initial_condition()
3128
args = (vf, init)

tests/test_ivpsolve/test_save_every_step.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@ def python_loop_solution(ivp, *, fact, strategy_fun):
3535

3636
# clip=False because we need to test adaptive-step-interpolation
3737
# for smoothers
38-
control = ivpsolvers.control_proportional_integral(clip=False)
3938
adaptive_solver = ivpsolvers.adaptive(
40-
solver, atol=1e-2, rtol=1e-2, control=control, ssm=ssm
39+
solver, atol=1e-2, rtol=1e-2, ssm=ssm, clip_dt=False
4140
)
4241

4342
dt0 = ivpsolve.dt0_adaptive(

tests/test_ivpsolvers/test_controllers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ def test_equivalence_pi_vs_i(dt, error_power, num_applies):
1515
ctrl_i = ivpsolvers.control_integral()
1616

1717
x_pi = ctrl_pi.init(dt)
18+
dt_pi = dt
1819
for _ in range(num_applies):
19-
x_pi = ctrl_pi.apply(x_pi, error_power=error_power)
20-
x_pi = ctrl_pi.extract(x_pi)
20+
dt_pi, x_pi = ctrl_pi.apply(dt_pi, x_pi, error_power=error_power)
2121

2222
x_i = ctrl_i.init(dt)
23+
dt_i = dt
2324
for _ in range(num_applies):
24-
x_i = ctrl_i.apply(x_i, error_power=error_power)
25-
x_i = ctrl_i.extract(x_i)
26-
assert np.allclose(x_i, x_pi)
25+
dt_i, x_i = ctrl_i.apply(dt_i, x_i, error_power=error_power)
26+
assert np.allclose(dt_i, dt_pi)

0 commit comments

Comments
 (0)