77See the tutorials for example use cases.
88"""
99
10- from probdiffeq .backend import (
11- containers ,
12- control_flow ,
13- func ,
14- linalg ,
15- np ,
16- tree ,
17- warnings ,
18- )
10+ from probdiffeq .backend import containers , flow , func , linalg , np , tree , warnings
1911from probdiffeq .backend .typing import Any , Array , Callable , Generic , Protocol , TypeVar
2012
2113T_contra = TypeVar ("T_contra" , contravariant = True )
@@ -147,7 +139,7 @@ def solve_adaptive_terminal_values(
147139 errorest ,
148140 control : Control | None = None ,
149141 clip_dt : bool = False ,
150- while_loop : Callable = control_flow .while_loop ,
142+ while_loop : Callable = flow .while_loop ,
151143) -> Callable [..., Solution ]:
152144 """Simulate the terminal values of an initial value problem."""
153145 # Turn off warnings because any solver goes for terminal values
@@ -178,7 +170,7 @@ def solve_adaptive_save_at(
178170 errorest ,
179171 control : Control | None = None ,
180172 clip_dt : bool = False ,
181- while_loop : Callable = control_flow .while_loop ,
173+ while_loop : Callable = flow .while_loop ,
182174 warn = True ,
183175) -> Callable [..., Solution ]:
184176 r"""Solve an initial value problem and return the solution at a pre-determined grid.
@@ -261,7 +253,7 @@ def body_fun(state: AdvanceState) -> AdvanceState:
261253 # Advance to one checkpoint after the other
262254 init = (solution0 , state )
263255 xs = save_at [1 :]
264- (_solution , _state ), solution = control_flow .scan (
256+ (_solution , _state ), solution = flow .scan (
265257 advance , init = init , xs = xs , reverse = False
266258 )
267259
@@ -281,7 +273,7 @@ def body_fn(s, dt):
281273
282274 t0 = grid [0 ]
283275 state0 = solver .init (t = t0 , u = u )
284- _ , result = control_flow .scan (body_fn , init = state0 , xs = np .diff (grid ))
276+ _ , result = flow .scan (body_fn , init = state0 , xs = np .diff (grid ))
285277
286278 return solver .userfriendly_output (solution0 = state0 , solution = result )
287279
@@ -426,14 +418,14 @@ def loop(
426418 # If t1 is in the future, enter the rejection loop (otherwise do nothing)
427419 is_before_t1 = state0 .step_from .t + eps < t1
428420 args = (state0 , t1 , atol , rtol , damp )
429- state = control_flow .cond (is_before_t1 , self .step , lambda s : s [0 ], args )
421+ state = flow .cond (is_before_t1 , self .step , lambda s : s [0 ], args )
430422
431423 # Interpolate
432424 is_before_t1 = state .step_from .t + eps < t1
433425 is_after_t1 = state .step_from .t > t1 + eps
434426 branch_idx = np .where (is_before_t1 , 0 , np .where (is_after_t1 , 1 , 2 ))
435427 options = (self .interp_skip , self .interp_beyond_t1 , self .interp_at_t1 )
436- return control_flow .switch (branch_idx , options , (state , t1 ))
428+ return flow .switch (branch_idx , options , (state , t1 ))
437429
438430 def step (self , s_and_t1_and_tols_and_damp ):
439431 """Do a rejection-loop step.
0 commit comments