44from probdiffeq .backend import (
55 containers ,
66 control_flow ,
7- functools ,
87 linalg ,
98 tree_array_util ,
109 tree_util ,
@@ -89,42 +88,21 @@ def solve_adaptive_terminal_values(
8988 vector_field , initial_condition , t0 , t1 , adaptive_solver , dt0 , * , ssm
9089) -> IVPSolution :
9190 """Simulate the terminal values of an initial value problem."""
92- save_at = np .asarray ([t1 ])
93- (_t , solution_save_at ), _ , num_steps = _solve_adaptive_save_at (
94- tree_util .Partial (vector_field ),
95- t0 ,
91+ save_at = np .asarray ([t0 , t1 ])
92+ solution = solve_adaptive_save_at (
93+ vector_field ,
9694 initial_condition ,
9795 save_at = save_at ,
9896 adaptive_solver = adaptive_solver ,
9997 dt0 = dt0 ,
100- )
101- # "squeeze"-type functionality (there is only a single state!)
102- squeeze_fun = functools .partial (np .squeeze_along_axis , axis = 0 )
103- solution_save_at = tree_util .tree_map (squeeze_fun , solution_save_at )
104- num_steps = tree_util .tree_map (squeeze_fun , num_steps )
105-
106- # I think the user expects marginals, so we compute them here
107- # todo: do this in IVPSolution.* methods?
108- posterior , output_scale = solution_save_at
109- marginals = posterior .init if isinstance (posterior , stats .MarkovSeq ) else posterior
110-
111- u = ssm .stats .qoi_from_sample (marginals .mean )
112- std = ssm .stats .standard_deviation (marginals )
113- u_std = ssm .stats .qoi_from_sample (std )
114- return IVPSolution (
115- t = t1 ,
116- u = u ,
117- u_std = u_std ,
11898 ssm = ssm ,
119- marginals = marginals ,
120- posterior = posterior ,
121- output_scale = output_scale ,
122- num_steps = num_steps ,
99+ warn = False , # Turn off warnings because any solver goes for terminal values
123100 )
101+ return tree_util .tree_map (lambda s : s [- 1 ], solution )
124102
125103
126104def solve_adaptive_save_at (
127- vector_field , initial_condition , save_at , adaptive_solver , dt0 , * , ssm
105+ vector_field , initial_condition , save_at , adaptive_solver , dt0 , * , ssm , warn = True
128106) -> IVPSolution :
129107 r"""Solve an initial value problem and return the solution at a pre-determined grid.
130108
@@ -152,7 +130,7 @@ def solve_adaptive_save_at(
152130 }
153131 ```
154132 """
155- if not adaptive_solver .solver .is_suitable_for_save_at :
133+ if not adaptive_solver .solver .is_suitable_for_save_at and warn :
156134 msg = (
157135 f"Strategy { adaptive_solver .solver } should not "
158136 f"be used in solve_adaptive_save_at. "
@@ -170,7 +148,7 @@ def solve_adaptive_save_at(
170148
171149 # I think the user expects the initial condition to be part of the state
172150 # (as well as marginals), so we compute those things here
173- posterior_t0 , * _ = initial_condition
151+ posterior_t0 = initial_condition . posterior
174152 posterior_save_at , output_scale = solution_save_at
175153 _tmp = _userfriendly_output (
176154 posterior = posterior_save_at , posterior_t0 = posterior_t0 , ssm = ssm
@@ -194,41 +172,37 @@ def solve_adaptive_save_at(
194172def _solve_adaptive_save_at (
195173 vector_field , t , initial_condition , * , save_at , adaptive_solver , dt0
196174):
197- advance_func = functools .partial (
198- _advance_and_interpolate ,
199- vector_field = vector_field ,
200- adaptive_solver = adaptive_solver ,
201- )
202-
203- state = adaptive_solver .init (t , initial_condition , dt = dt0 , num_steps = 0.0 )
204- _ , solution = control_flow .scan (advance_func , init = state , xs = save_at , reverse = False )
205- return solution
206-
207-
208- def _advance_and_interpolate (state , t_next , * , vector_field , adaptive_solver ):
209- # Advance until accepted.t >= t_next.
210- # Note: This could already be the case and we may not loop (just interpolate)
211- def cond_fun (s ):
212- # Terminate the loop if
213- # the difference from s.t to t_next is smaller than a constant factor
214- # (which is a "small" multiple of the current machine precision)
215- # or if s.t > t_next holds.
216- return s .step_from .t + 10 * np .finfo_eps (float ) < t_next
175+ def advance (state , t_next ):
176+ # Advance until accepted.t >= t_next.
177+ # Note: This could already be the case and we may not loop (just interpolate)
178+ def cond_fun (s ):
179+ # Terminate the loop if
180+ # the difference from s.t to t_next is smaller than a constant factor
181+ # (which is a "small" multiple of the current machine precision)
182+ # or if s.t > t_next holds.
183+ return s .step_from .t + adaptive_solver .eps < t_next
184+
185+ def body_fun (s ):
186+ return adaptive_solver .rejection_loop (
187+ s , vector_field = vector_field , t1 = t_next
188+ )
217189
218- def body_fun (s ):
219- return adaptive_solver .rejection_loop (s , vector_field = vector_field , t1 = t_next )
190+ state = control_flow .while_loop (cond_fun , body_fun , init = state )
220191
221- state = control_flow .while_loop (cond_fun , body_fun , init = state )
192+ # Either interpolate (t > t_next) or "finalise" (t == t_next)
193+ is_after_t1 = state .step_from .t > t_next + adaptive_solver .eps
194+ state , solution = control_flow .cond (
195+ is_after_t1 ,
196+ adaptive_solver .extract_after_t1 ,
197+ adaptive_solver .extract_at_t1 ,
198+ state ,
199+ t_next ,
200+ )
201+ return state , solution
222202
223- # Either interpolate (t > t_next) or "finalise" (t == t_next)
224- state , solution = control_flow .cond (
225- state .step_from .t > t_next + 10 * np .finfo_eps (float ),
226- adaptive_solver .extract_after_t1_via_interpolation ,
227- lambda s , _t : adaptive_solver .extract_at_t1 (s ),
228- state ,
229- t_next ,
230- )
231- return state , solution
203+ state = adaptive_solver .init (t , initial_condition , dt = dt0 , num_steps = 0.0 )
204+ _ , solution = control_flow .scan (advance , init = state , xs = save_at , reverse = False )
205+ return solution
232206
233207
234208def solve_adaptive_save_every_step (
@@ -264,7 +238,7 @@ def solve_adaptive_save_every_step(
264238 t = np .concatenate ((np .atleast_1d (t0 ), t ))
265239
266240 # I think the user expects marginals, so we compute them here
267- posterior_t0 , * _ = initial_condition
241+ posterior_t0 = initial_condition . posterior
268242 posterior , output_scale = solution_every_step
269243 _tmp = _userfriendly_output (posterior = posterior , posterior_t0 = posterior_t0 , ssm = ssm )
270244 marginals , posterior = _tmp
@@ -292,15 +266,16 @@ def _solution_generator(
292266 while state .step_from .t < t1 :
293267 state = adaptive_solver .rejection_loop (state , vector_field = vector_field , t1 = t1 )
294268
295- if state .step_from .t < t1 :
296- solution = adaptive_solver .extract_before_t1 (state )
269+ if state .step_from .t + adaptive_solver . eps < t1 :
270+ _ , solution = adaptive_solver .extract_before_t1 (state , t = t1 )
297271 yield solution
298272
299273 # Either interpolate (t > t_next) or "finalise" (t == t_next)
300- if state .step_from .t > t1 :
301- _ , solution = adaptive_solver .extract_after_t1_via_interpolation (state , t = t1 )
274+ is_after_t1 = state .step_from .t > t1 + adaptive_solver .eps
275+ if is_after_t1 :
276+ _ , solution = adaptive_solver .extract_after_t1 (state , t = t1 )
302277 else :
303- _ , solution = adaptive_solver .extract_at_t1 (state )
278+ _ , solution = adaptive_solver .extract_at_t1 (state , t = t1 )
304279
305280 yield solution
306281
@@ -321,7 +296,7 @@ def body_fn(s, dt):
321296 _t , (posterior , output_scale ) = solver .extract (result_state )
322297
323298 # I think the user expects marginals, so we compute them here
324- posterior_t0 , * _ = initial_condition
299+ posterior_t0 = initial_condition . posterior
325300 _tmp = _userfriendly_output (posterior = posterior , posterior_t0 = posterior_t0 , ssm = ssm )
326301 marginals , posterior = _tmp
327302
0 commit comments