@@ -85,13 +85,13 @@ def _sol_unflatten(aux, children):
8585
8686
8787def solve_adaptive_terminal_values (
88- vector_field , initial_condition , t0 , t1 , adaptive_solver , dt0 , * , ssm
88+ vector_field , ssm_init , t0 , t1 , adaptive_solver , dt0 , * , ssm
8989) -> IVPSolution :
9090 """Simulate the terminal values of an initial value problem."""
9191 save_at = np .asarray ([t0 , t1 ])
9292 solution = solve_adaptive_save_at (
9393 vector_field ,
94- initial_condition ,
94+ ssm_init ,
9595 save_at = save_at ,
9696 adaptive_solver = adaptive_solver ,
9797 dt0 = dt0 ,
@@ -102,7 +102,7 @@ def solve_adaptive_terminal_values(
102102
103103
104104def solve_adaptive_save_at (
105- vector_field , initial_condition , save_at , adaptive_solver , dt0 , * , ssm , warn = True
105+ vector_field , ssm_init , save_at , adaptive_solver , dt0 , * , ssm , warn = True
106106) -> IVPSolution :
107107 r"""Solve an initial value problem and return the solution at a pre-determined grid.
108108
@@ -140,17 +140,16 @@ def solve_adaptive_save_at(
140140 (_t , solution_save_at ), _ , num_steps = _solve_adaptive_save_at (
141141 tree_util .Partial (vector_field ),
142142 save_at [0 ],
143- initial_condition ,
143+ ssm_init ,
144144 save_at = save_at [1 :],
145145 adaptive_solver = adaptive_solver ,
146146 dt0 = dt0 ,
147147 )
148148
149149 # I think the user expects the initial condition to be part of the state
150150 # (as well as marginals), so we compute those things here
151- init_t0 = initial_condition
152151 posterior_save_at , output_scale = solution_save_at
153- _tmp = _userfriendly_output (posterior = posterior_save_at , init_t0 = init_t0 , ssm = ssm )
152+ _tmp = _userfriendly_output (posterior = posterior_save_at , ssm_init = ssm_init , ssm = ssm )
154153 marginals , posterior = _tmp
155154 u = ssm .stats .qoi_from_sample (marginals .mean )
156155 std = ssm .stats .standard_deviation (marginals )
@@ -168,7 +167,7 @@ def solve_adaptive_save_at(
168167
169168
170169def _solve_adaptive_save_at (
171- vector_field , t , initial_condition , * , save_at , adaptive_solver , dt0
170+ vector_field , t , ssm_init , * , save_at , adaptive_solver , dt0
172171):
173172 def advance (state , t_next ):
174173 # Advance until accepted.t >= t_next.
@@ -198,13 +197,13 @@ def body_fun(s):
198197 )
199198 return state , solution
200199
201- state = adaptive_solver .init (t , initial_condition , dt = dt0 , num_steps = 0.0 )
200+ state = adaptive_solver .init (t , ssm_init , dt = dt0 , num_steps = 0.0 )
202201 _ , solution = control_flow .scan (advance , init = state , xs = save_at , reverse = False )
203202 return solution
204203
205204
206205def solve_adaptive_save_every_step (
207- vector_field , initial_condition , t0 , t1 , adaptive_solver , dt0 , * , ssm
206+ vector_field , ssm_init , t0 , t1 , adaptive_solver , dt0 , * , ssm
208207) -> IVPSolution :
209208 """Solve an initial value problem and save every step.
210209
@@ -223,7 +222,7 @@ def solve_adaptive_save_every_step(
223222 generator = _solution_generator (
224223 tree_util .Partial (vector_field ),
225224 t0 ,
226- initial_condition ,
225+ ssm_init ,
227226 t1 = t1 ,
228227 adaptive_solver = adaptive_solver ,
229228 dt0 = dt0 ,
@@ -236,9 +235,8 @@ def solve_adaptive_save_every_step(
236235 t = np .concatenate ((np .atleast_1d (t0 ), t ))
237236
238237 # I think the user expects marginals, so we compute them here
239- init_t0 = initial_condition
240238 posterior , output_scale = solution_every_step
241- _tmp = _userfriendly_output (posterior = posterior , init_t0 = init_t0 , ssm = ssm )
239+ _tmp = _userfriendly_output (posterior = posterior , ssm_init = ssm_init , ssm = ssm )
242240 marginals , posterior = _tmp
243241
244242 u = ssm .stats .qoi_from_sample (marginals .mean )
@@ -256,10 +254,8 @@ def solve_adaptive_save_every_step(
256254 )
257255
258256
259- def _solution_generator (
260- vector_field , t , initial_condition , * , dt0 , t1 , adaptive_solver
261- ):
262- state = adaptive_solver .init (t , initial_condition , dt = dt0 , num_steps = 0 )
257+ def _solution_generator (vector_field , t , ssm_init , * , dt0 , t1 , adaptive_solver ):
258+ state = adaptive_solver .init (t , ssm_init , dt = dt0 , num_steps = 0 )
263259
264260 while state .step_from .t < t1 :
265261 state = adaptive_solver .rejection_loop (state , vector_field = vector_field , t1 = t1 )
@@ -278,9 +274,7 @@ def _solution_generator(
278274 yield solution
279275
280276
281- def solve_fixed_grid (
282- vector_field , initial_condition , grid , solver , * , ssm
283- ) -> IVPSolution :
277+ def solve_fixed_grid (vector_field , ssm_init , grid , solver , * , ssm ) -> IVPSolution :
284278 """Solve an initial value problem on a fixed, pre-determined grid."""
285279 # Compute the solution
286280
@@ -289,13 +283,12 @@ def body_fn(s, dt):
289283 return s_new , s_new
290284
291285 t0 = grid [0 ]
292- state0 = solver .init (t0 , initial_condition )
286+ state0 = solver .init (t0 , ssm_init )
293287 _ , result_state = control_flow .scan (body_fn , init = state0 , xs = np .diff (grid ))
294288 _t , (posterior , output_scale ) = solver .extract (result_state )
295289
296290 # I think the user expects marginals, so we compute them here
297- init_t0 = initial_condition
298- _tmp = _userfriendly_output (posterior = posterior , init_t0 = init_t0 , ssm = ssm )
291+ _tmp = _userfriendly_output (posterior = posterior , ssm_init = ssm_init , ssm = ssm )
299292 marginals , posterior = _tmp
300293
301294 u = ssm .stats .qoi_from_sample (marginals .mean )
@@ -313,7 +306,7 @@ def body_fn(s, dt):
313306 )
314307
315308
316- def _userfriendly_output (* , posterior , init_t0 , ssm ):
309+ def _userfriendly_output (* , posterior , ssm_init , ssm ):
317310 if isinstance (posterior , stats .MarkovSeq ):
318311 # Compute marginals
319312 posterior_no_filter_marginals = stats .markov_select_terminal (posterior )
@@ -326,10 +319,10 @@ def _userfriendly_output(*, posterior, init_t0, ssm):
326319 marginals = tree_array_util .tree_append (marginals , marginal_t1 )
327320
328321 # Prepend the marginal at t1 to the inits
329- init = tree_array_util .tree_prepend (init_t0 , posterior .init )
322+ init = tree_array_util .tree_prepend (ssm_init , posterior .init )
330323 posterior = stats .MarkovSeq (init = init , conditional = posterior .conditional )
331324 else :
332- posterior = tree_array_util .tree_prepend (init_t0 , posterior )
325+ posterior = tree_array_util .tree_prepend (ssm_init , posterior )
333326 marginals = posterior
334327 return marginals , posterior
335328
0 commit comments