Skip to content

Commit ed94830

Browse files
committed
Rename backend.control_flow to backend.flow to shorten
1 parent d3d5efa commit ed94830

7 files changed

Lines changed: 22 additions & 43 deletions

File tree

probdiffeq/ivpsolve.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,7 @@
77
See the tutorials for example use cases.
88
"""
99

10-
from probdiffeq.backend import (
11-
containers,
12-
control_flow,
13-
func,
14-
linalg,
15-
np,
16-
tree,
17-
warnings,
18-
)
10+
from probdiffeq.backend import containers, flow, func, linalg, np, tree, warnings
1911
from probdiffeq.backend.typing import Any, Array, Callable, Generic, Protocol, TypeVar
2012

2113
T_contra = TypeVar("T_contra", contravariant=True)
@@ -147,7 +139,7 @@ def solve_adaptive_terminal_values(
147139
errorest,
148140
control: Control | None = None,
149141
clip_dt: bool = False,
150-
while_loop: Callable = control_flow.while_loop,
142+
while_loop: Callable = flow.while_loop,
151143
) -> Callable[..., Solution]:
152144
"""Simulate the terminal values of an initial value problem."""
153145
# Turn off warnings because any solver goes for terminal values
@@ -178,7 +170,7 @@ def solve_adaptive_save_at(
178170
errorest,
179171
control: Control | None = None,
180172
clip_dt: bool = False,
181-
while_loop: Callable = control_flow.while_loop,
173+
while_loop: Callable = flow.while_loop,
182174
warn=True,
183175
) -> Callable[..., Solution]:
184176
r"""Solve an initial value problem and return the solution at a pre-determined grid.
@@ -261,7 +253,7 @@ def body_fun(state: AdvanceState) -> AdvanceState:
261253
# Advance to one checkpoint after the other
262254
init = (solution0, state)
263255
xs = save_at[1:]
264-
(_solution, _state), solution = control_flow.scan(
256+
(_solution, _state), solution = flow.scan(
265257
advance, init=init, xs=xs, reverse=False
266258
)
267259

@@ -281,7 +273,7 @@ def body_fn(s, dt):
281273

282274
t0 = grid[0]
283275
state0 = solver.init(t=t0, u=u)
284-
_, result = control_flow.scan(body_fn, init=state0, xs=np.diff(grid))
276+
_, result = flow.scan(body_fn, init=state0, xs=np.diff(grid))
285277

286278
return solver.userfriendly_output(solution0=state0, solution=result)
287279

@@ -426,14 +418,14 @@ def loop(
426418
# If t1 is in the future, enter the rejection loop (otherwise do nothing)
427419
is_before_t1 = state0.step_from.t + eps < t1
428420
args = (state0, t1, atol, rtol, damp)
429-
state = control_flow.cond(is_before_t1, self.step, lambda s: s[0], args)
421+
state = flow.cond(is_before_t1, self.step, lambda s: s[0], args)
430422

431423
# Interpolate
432424
is_before_t1 = state.step_from.t + eps < t1
433425
is_after_t1 = state.step_from.t > t1 + eps
434426
branch_idx = np.where(is_before_t1, 0, np.where(is_after_t1, 1, 2))
435427
options = (self.interp_skip, self.interp_beyond_t1, self.interp_at_t1)
436-
return control_flow.switch(branch_idx, options, (state, t1))
428+
return flow.switch(branch_idx, options, (state, t1))
437429

438430
def step(self, s_and_t1_and_tols_and_damp):
439431
"""Do a rejection-loop step.

probdiffeq/probdiffeq.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,7 @@
33
See the tutorials for example use cases.
44
"""
55

6-
from probdiffeq.backend import (
7-
containers,
8-
control_flow,
9-
func,
10-
linalg,
11-
np,
12-
random,
13-
special,
14-
tree,
15-
)
6+
from probdiffeq.backend import containers, flow, func, linalg, np, random, special, tree
167
from probdiffeq.backend.typing import (
178
Any,
189
Array,
@@ -324,7 +315,7 @@ def step(x, cond):
324315
return extrapolated, extrapolated
325316

326317
init, xs = markov_seq.marginal, markov_seq.conditional
327-
_, marginals = control_flow.scan(step, init=init, xs=xs, reverse=reverse)
318+
_, marginals = flow.scan(step, init=init, xs=xs, reverse=reverse)
328319

329320
if reverse:
330321
# Append the terminal marginal to the computed ones
@@ -386,9 +377,7 @@ def body_fun(samp_prev, conditionals_and_base_samples):
386377

387378
# Loop over backward models and the remaining base samples
388379
xs = (markov_seq.conditional, base_sample_body)
389-
_, samples = control_flow.scan(
390-
body_fun, init=init_sample, xs=xs, reverse=reverse
391-
)
380+
_, samples = flow.scan(body_fun, init=init_sample, xs=xs, reverse=reverse)
392381

393382
if reverse:
394383
samples = np.concatenate([samples, init_sample[None, ...]])

probdiffeq/taylor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
See the tutorials for example use cases.
88
"""
99

10-
from probdiffeq.backend import control_flow, func, np, ode, tree
10+
from probdiffeq.backend import flow, func, np, ode, tree
1111
from probdiffeq.backend.typing import Array, ArrayLike, Callable, Sequence
1212
from probdiffeq.util import filter_util
1313

@@ -126,9 +126,7 @@ def body(tcoeffs, _):
126126
return taylor_coeffs
127127

128128
# Compute all coefficients with scan().
129-
taylor_coeffs, _ = control_flow.scan(
130-
body, init=taylor_coeffs, xs=None, length=num - 1
131-
)
129+
taylor_coeffs, _ = flow.scan(body, init=taylor_coeffs, xs=None, length=num - 1)
132130
return taylor_coeffs
133131

134132

@@ -329,7 +327,7 @@ def body_fun(cs_padded, i_and_fx_i):
329327
cs_padded = np.stack(cs_padded)
330328

331329
xs = [np.arange(0, len(fx[deg : 2 * deg])), fx[deg : 2 * deg]]
332-
cs_padded, _ = control_flow.scan(body_fun, xs=xs, init=cs_padded)
330+
cs_padded, _ = flow.scan(body_fun, xs=xs, init=cs_padded)
333331

334332
taylor_coefficients.extend(cs_padded)
335333
return taylor_coefficients

probdiffeq/util/cholesky_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
2121
"""
2222

23-
from probdiffeq.backend import control_flow, linalg, np, tree
23+
from probdiffeq.backend import flow, linalg, np, tree
2424

2525

2626
def revert_conditional_noisefree(R_X_F, R_X):
@@ -187,7 +187,7 @@ def f_body(idx, f):
187187
)
188188
return f.at[idx].set(val)
189189

190-
f = control_flow.fori_loop(1, n, f_body, f)
190+
f = flow.fori_loop(1, n, f_body, f)
191191
f = 1.0 / f
192192

193193
U = np.eye(n)
@@ -204,10 +204,10 @@ def inner_body(k, g):
204204
newval = (g[i + 1] / denom) * factor
205205
return g.at[i].set(newval)
206206

207-
g = control_flow.fori_loop(0, j_idx, inner_body, g)
207+
g = flow.fori_loop(0, j_idx, inner_body, g)
208208
return U.at[:, j_idx].set(g)
209209

210-
U = control_flow.fori_loop(1, n, body_j, U)
210+
U = flow.fori_loop(1, n, body_j, U)
211211

212212
# scale columns: U = U .* (dr * f_row)
213213
U = U * (dr[:, None] * f[None, :])

probdiffeq/util/filter_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Mostly **discrete** filtering and smoothing.
44
"""
55

6-
from probdiffeq.backend import containers, control_flow, tree
6+
from probdiffeq.backend import containers, flow, tree
77
from probdiffeq.backend.typing import Any
88

99

@@ -19,7 +19,7 @@ def estimate_fwd(data, /, init, prior_transitions, observation_model, *, estimat
1919
idx_or_slice = slice(1, len(data), 1)
2020
information = _select((data, observation_model), idx_or_slice=idx_or_slice)
2121
xs = (prior_transitions, *information)
22-
return control_flow.scan(step, init=init, xs=xs, reverse=False)
22+
return flow.scan(step, init=init, xs=xs, reverse=False)
2323

2424

2525
def estimate_rev(data, /, init, prior_transitions, observation_model, *, estimator):
@@ -33,7 +33,7 @@ def estimate_rev(data, /, init, prior_transitions, observation_model, *, estimat
3333
# Scan over the remaining data points
3434
information = _select((data, observation_model), idx_or_slice=slice(0, -1, 1))
3535
xs = (prior_transitions, *information)
36-
return control_flow.scan(step, init=init, xs=xs, reverse=True)
36+
return flow.scan(step, init=init, xs=xs, reverse=True)
3737

3838

3939
def fixedpointsmoother_precon(*, ssm):

probdiffeq/util/test_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Test utilities."""
22

33
import probdiffeq.ivpsolve
4-
from probdiffeq.backend import control_flow, func, tree, warnings
4+
from probdiffeq.backend import flow, func, tree, warnings
55
from probdiffeq.backend.typing import TypeVar
66

77
T = TypeVar("T")
@@ -35,7 +35,7 @@ def solve_adaptive_save_every_step(solver, errorest, control=None, clip_dt=False
3535
# We do not expose this option to the user
3636
# because we do not want to suggest that this function
3737
# uses meaningful looping to begin with.
38-
while_loop=control_flow.while_loop,
38+
while_loop=flow.while_loop,
3939
)
4040

4141
def solve(

0 commit comments

Comments
 (0)