@@ -927,17 +927,34 @@ def extract(state, /):
927927 return _Calibration (init = init , update = update , extract = extract )
928928
929929
930- def adaptive (slvr , / , * , ssm , atol = 1e-4 , rtol = 1e-2 , control = None , norm_ord = None ):
930+ def adaptive (
931+ slvr ,
932+ / ,
933+ * ,
934+ ssm ,
935+ atol = 1e-4 ,
936+ rtol = 1e-2 ,
937+ control = None ,
938+ norm_ord = None ,
939+ clip_dt : bool = False ,
940+ ):
931941 """Make an IVP solver adaptive."""
932942 if control is None :
933943 control = control_proportional_integral ()
934944
935945 return _AdaSolver (
936- slvr , ssm = ssm , atol = atol , rtol = rtol , control = control , norm_ord = norm_ord
946+ slvr ,
947+ ssm = ssm ,
948+ atol = atol ,
949+ rtol = rtol ,
950+ control = control ,
951+ norm_ord = norm_ord ,
952+ clip_dt = clip_dt ,
937953 )
938954
939955
940956class _AdaState (containers .NamedTuple ):
957+ dt : float
941958 step_from : Any
942959 interp_from : Any
943960 control : Any
@@ -948,14 +965,24 @@ class _AdaSolver:
948965 """Adaptive IVP solvers."""
949966
950967 def __init__ (
951- self , slvr : _ProbabilisticSolver , / , * , atol , rtol , control , norm_ord , ssm
968+ self ,
969+ slvr : _ProbabilisticSolver ,
970+ / ,
971+ * ,
972+ atol ,
973+ rtol ,
974+ control ,
975+ norm_ord ,
976+ ssm ,
977+ clip_dt : bool ,
952978 ):
953979 self .solver = slvr
954980 self .atol = atol
955981 self .rtol = rtol
956982 self .control = control
957983 self .norm_ord = norm_ord
958984 self .ssm = ssm
985+ self .clip_dt = clip_dt
959986
960987 def __repr__ (self ):
961988 return (
@@ -973,7 +1000,7 @@ def init(self, t, initial_condition, dt, num_steps) -> _AdaState:
9731000 """Initialise the IVP solver state."""
9741001 state_solver = self .solver .init (t , initial_condition )
9751002 state_control = self .control .init (dt )
976- return _AdaState (state_solver , state_solver , state_control , num_steps )
1003+ return _AdaState (dt , state_solver , state_solver , state_control , num_steps )
9771004
9781005 @functools .jit
9791006 def rejection_loop (self , state0 : _AdaState , * , vector_field , t1 ) -> _AdaState :
@@ -984,6 +1011,7 @@ class _RejectionState(containers.NamedTuple):
9841011 This is one part of an IVP solver step.)
9851012 """
9861013
1014+ dt : float
9871015 error_norm_proposed : float
9881016 control : Any
9891017 proposed : Any
@@ -996,6 +1024,7 @@ def _inf_like(tree):
9961024 smaller_than_1 = 1.0 / 1.1 # the cond() must return True
9971025 return _RejectionState (
9981026 error_norm_proposed = smaller_than_1 ,
1027+ dt = s0 .dt ,
9991028 control = s0 .control ,
10001029 proposed = _inf_like (s0 .step_from ),
10011030 step_from = s0 .step_from ,
@@ -1011,15 +1040,16 @@ def body_fn(state: _RejectionState) -> _RejectionState:
10111040 Perform a step with an IVP solver and
10121041 propose a future time-step based on tolerances and error estimates.
10131042 """
1043+ dt = state .dt
1044+
10141045 # Some controllers like to clip the terminal value instead of interpolating.
10151046 # This must happen _before_ the step.
1016- state_control = self .control .clip (state .control , t = state .step_from .t , t1 = t1 )
1047+ if self .clip_dt :
1048+ dt = np .minimum (dt , t1 - state .step_from .t )
10171049
10181050 # Perform the actual step.
10191051 error_estimate , state_proposed = self .solver .step (
1020- state = state .step_from ,
1021- vector_field = vector_field ,
1022- dt = self .control .extract (state_control ),
1052+ state = state .step_from , vector_field = vector_field , dt = dt
10231053 )
10241054 # Normalise the error
10251055 u_proposed = self .ssm .stats .qoi (state_proposed .hidden )[0 ]
@@ -1028,8 +1058,11 @@ def body_fn(state: _RejectionState) -> _RejectionState:
10281058 error_power = _error_scale_and_normalize (error_estimate , u = u )
10291059
10301060 # Propose a new step
1031- state_control = self .control .apply (state_control , error_power = error_power )
1061+ dt , state_control = self .control .apply (
1062+ dt , state .control , error_power = error_power
1063+ )
10321064 return _RejectionState (
1065+ dt = dt ,
10331066 error_norm_proposed = error_power , # new
10341067 proposed = state_proposed , # new
10351068 control = state_control , # new
@@ -1044,17 +1077,16 @@ def _error_scale_and_normalize(error_estimate, *, u):
10441077 return error_norm_rel ** (- 1.0 / self .solver .error_contraction_rate )
10451078
10461079 def extract (s : _RejectionState ) -> _AdaState :
1047- num_steps = state0 .stats + 1
1048- return _AdaState (s .proposed , s .step_from , s .control , num_steps )
1080+ num_steps = state0 .stats + 1.0 # TODO: track step attempts as well
1081+ return _AdaState (s .dt , s . proposed , s .step_from , s .control , num_steps )
10491082
10501083 init_val = init (state0 )
10511084 state_new = control_flow .while_loop (cond_fn , body_fn , init_val )
10521085 return extract (state_new )
10531086
10541087 def extract_before_t1 (self , state : _AdaState ):
10551088 solution_solver = self .solver .extract (state .step_from )
1056- solution_control = self .control .extract (state .control )
1057- return solution_solver , solution_control , state .stats
1089+ return solution_solver , (state .dt , state .control ), state .stats
10581090
10591091 def extract_at_t1 (self , state : _AdaState ):
10601092 # todo: make the "at t1" decision inside interpolate(),
@@ -1063,37 +1095,47 @@ def extract_at_t1(self, state: _AdaState):
10631095 interp_from = state .interp_from , interp_to = state .step_from
10641096 )
10651097 state = _AdaState (
1066- interp .step_from , interp .interp_from , state .control , state .stats
1098+ state . dt , interp .step_from , interp .interp_from , state .control , state .stats
10671099 )
10681100
10691101 solution_solver = self .solver .extract (interp .interpolated )
1070- solution_control = self .control .extract (state .control )
1071- return state , (solution_solver , solution_control , state .stats )
1102+ return state , (solution_solver , (state .dt , state .control ), state .stats )
10721103
10731104 def extract_after_t1_via_interpolation (self , state : _AdaState , t ):
10741105 interp = self .solver .interpolate (
10751106 t , interp_from = state .interp_from , interp_to = state .step_from
10761107 )
10771108 state = _AdaState (
1078- interp .step_from , interp .interp_from , state .control , state .stats
1109+ state . dt , interp .step_from , interp .interp_from , state .control , state .stats
10791110 )
10801111
10811112 solution_solver = self .solver .extract (interp .interpolated )
1082- solution_control = self .control .extract (state .control )
1083- return state , (solution_solver , solution_control , state .stats )
1113+ return state , (solution_solver , (state .dt , state .control ), state .stats )
10841114
10851115 @staticmethod
10861116 def register_pytree_node ():
10871117 def _asolver_flatten (asolver ):
10881118 children = (asolver .atol , asolver .rtol )
1089- aux = (asolver .solver , asolver .control , asolver .norm_ord , asolver .ssm )
1119+ aux = (
1120+ asolver .solver ,
1121+ asolver .control ,
1122+ asolver .norm_ord ,
1123+ asolver .ssm ,
1124+ asolver .clip_dt ,
1125+ )
10901126 return children , aux
10911127
10921128 def _asolver_unflatten (aux , children ):
10931129 atol , rtol = children
1094- (slvr , control , norm_ord , ssm ) = aux
1130+ (slvr , control , norm_ord , ssm , clip_dt ) = aux
10951131 return _AdaSolver (
1096- slvr , atol = atol , rtol = rtol , control = control , norm_ord = norm_ord , ssm = ssm
1132+ slvr ,
1133+ atol = atol ,
1134+ rtol = rtol ,
1135+ control = control ,
1136+ norm_ord = norm_ord ,
1137+ ssm = ssm ,
1138+ clip_dt = clip_dt ,
10971139 )
10981140
10991141 tree_util .register_pytree_node (
@@ -1103,46 +1145,35 @@ def _asolver_unflatten(aux, children):
11031145
11041146_AdaSolver .register_pytree_node ()
11051147
1148+ T = TypeVar ("T" )
1149+
11061150
11071151@containers .dataclass
1108- class _Controller :
1152+ class _Controller ( Generic [ T ]) :
11091153 """Control algorithm."""
11101154
1111- init : Callable [[float ], Any ]
1155+ init : Callable [[float ], T ]
11121156 """Initialise the controller state."""
11131157
1114- clip : Callable [[Any , float , float ], Any ]
1115- """(Optionally) clip the current step to not exceed t1."""
1116-
1117- apply : Callable [[Any , NamedArg (float , "error_power" )], Any ]
1158+ apply : Callable [[float , T , NamedArg (float , "error_power" )], tuple [float , T ]]
11181159 r"""Propose a time-step $\Delta t$."""
11191160
1120- extract : Callable [[Any ], float ]
1121- """Extract the time-step from the controller state."""
1122-
11231161
11241162def control_proportional_integral (
11251163 * ,
1126- clip : bool = False ,
11271164 safety = 0.95 ,
11281165 factor_min = 0.2 ,
11291166 factor_max = 10.0 ,
11301167 power_integral_unscaled = 0.3 ,
11311168 power_proportional_unscaled = 0.4 ,
1132- ) -> _Controller :
1169+ ) -> _Controller [ float ] :
11331170 """Construct a proportional-integral-controller with time-clipping."""
11341171
1135- class PIState (containers .NamedTuple ):
1136- dt : float
1137- error_power_previously_accepted : float
1138-
1139- def init (dt : float , / ) -> PIState :
1140- return PIState (dt , 1.0 )
1172+ def init (_dt : float , / ) -> float :
1173+ return 1.0
11411174
1142- def apply (state : PIState , / , * , error_power ) -> PIState :
1175+ def apply (dt : float , error_power_prev : float , / , * , error_power ):
11431176 # error_power = error_norm ** (-1.0 / error_contraction_rate)
1144- dt_proposed , error_power_prev = state
1145-
11461177 a1 = error_power ** power_integral_unscaled
11471178 a2 = (error_power / error_power_prev ) ** power_proportional_unscaled
11481179 scale_factor_unclipped = safety * a1 * a2
@@ -1153,50 +1184,26 @@ def apply(state: PIState, /, *, error_power) -> PIState:
11531184 # >= 1.0 because error_power is 1/scaled_error_norm
11541185 error_power_prev = np .where (error_power >= 1.0 , error_power , error_power_prev )
11551186
1156- dt_proposed = scale_factor * dt_proposed
1157- return PIState (dt_proposed , error_power_prev )
1158-
1159- def extract (state : PIState , / ) -> float :
1160- dt_proposed , _error_norm_previously_accepted = state
1161- return dt_proposed
1162-
1163- if clip :
1164-
1165- def clip_fun (state : PIState , / , t , t1 ) -> PIState :
1166- dt_proposed , error_norm_previously_accepted = state
1167- dt = dt_proposed
1168- dt_clipped = np .minimum (dt , t1 - t )
1169- return PIState (dt_clipped , error_norm_previously_accepted )
1170-
1171- return _Controller (init = init , apply = apply , extract = extract , clip = clip_fun )
1187+ dt_proposed = scale_factor * dt
1188+ return dt_proposed , error_power_prev
11721189
1173- return _Controller (init = init , apply = apply , extract = extract , clip = lambda v , ** _kw : v )
1190+ return _Controller (init = init , apply = apply )
11741191
11751192
11761193def control_integral (
1177- * , clip = False , safety = 0.95 , factor_min = 0.2 , factor_max = 10.0
1178- ) -> _Controller :
1194+ * , safety = 0.95 , factor_min = 0.2 , factor_max = 10.0
1195+ ) -> _Controller [ None ] :
11791196 """Construct an integral-controller."""
11801197
1181- def init (dt , / ):
1182- return dt
1198+ def init (_dt , / ) -> None :
1199+ return None
11831200
1184- def apply (dt , / , * , error_power ):
1201+ def apply (dt , _state , / , * , error_power ):
11851202 # error_power = error_norm ** (-1.0 / error_contraction_rate)
11861203 scale_factor_unclipped = safety * error_power
11871204
11881205 scale_factor_clipped_min = np .minimum (scale_factor_unclipped , factor_max )
11891206 scale_factor = np .maximum (factor_min , scale_factor_clipped_min )
1190- return scale_factor * dt
1191-
1192- def extract (dt , / ):
1193- return dt
1194-
1195- if clip :
1196-
1197- def clip_fun (dt , / , t , t1 ):
1198- return np .minimum (dt , t1 - t )
1199-
1200- return _Controller (init = init , apply = apply , extract = extract , clip = clip_fun )
1207+ return scale_factor * dt , None
12011208
1202- return _Controller (init = init , apply = apply , extract = extract , clip = lambda v , ** _kw : v )
1209+ return _Controller (init = init , apply = apply )
0 commit comments