Skip to content

Commit 4f88725

Browse files
committed
Rename variables to be more similar to previous versions
1 parent 6143a43 commit 4f88725

3 files changed

Lines changed: 20 additions & 31 deletions

File tree

docs/examples_basic/conditioning_on_zero_residual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def vector_field(y, t): # noqa: ARG001
5555
NUM_DERIVATIVES = 2
5656
tcoeffs_like = [u0] * (NUM_DERIVATIVES + 1)
5757
ts = jnp.linspace(t0, t1, num=500, endpoint=True)
58-
init_raw, transitions, ssm = ivpsolvers.prior_wiener_integrated_discretised(
58+
init_raw, transitions, ssm = ivpsolvers.prior_wiener_integrated_discrete(
5959
ts, tcoeffs_like=tcoeffs_like, output_scale=100.0, ssm_fact="dense"
6060
)
6161

probdiffeq/ivpsolve.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,13 @@ def _sol_unflatten(aux, children):
8585

8686

8787
def solve_adaptive_terminal_values(
88-
vector_field, initial_condition, t0, t1, adaptive_solver, dt0, *, ssm
88+
vector_field, ssm_init, t0, t1, adaptive_solver, dt0, *, ssm
8989
) -> IVPSolution:
9090
"""Simulate the terminal values of an initial value problem."""
9191
save_at = np.asarray([t0, t1])
9292
solution = solve_adaptive_save_at(
9393
vector_field,
94-
initial_condition,
94+
ssm_init,
9595
save_at=save_at,
9696
adaptive_solver=adaptive_solver,
9797
dt0=dt0,
@@ -102,7 +102,7 @@ def solve_adaptive_terminal_values(
102102

103103

104104
def solve_adaptive_save_at(
105-
vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm, warn=True
105+
vector_field, ssm_init, save_at, adaptive_solver, dt0, *, ssm, warn=True
106106
) -> IVPSolution:
107107
r"""Solve an initial value problem and return the solution at a pre-determined grid.
108108
@@ -140,17 +140,16 @@ def solve_adaptive_save_at(
140140
(_t, solution_save_at), _, num_steps = _solve_adaptive_save_at(
141141
tree_util.Partial(vector_field),
142142
save_at[0],
143-
initial_condition,
143+
ssm_init,
144144
save_at=save_at[1:],
145145
adaptive_solver=adaptive_solver,
146146
dt0=dt0,
147147
)
148148

149149
# I think the user expects the initial condition to be part of the state
150150
# (as well as marginals), so we compute those things here
151-
init_t0 = initial_condition
152151
posterior_save_at, output_scale = solution_save_at
153-
_tmp = _userfriendly_output(posterior=posterior_save_at, init_t0=init_t0, ssm=ssm)
152+
_tmp = _userfriendly_output(posterior=posterior_save_at, ssm_init=ssm_init, ssm=ssm)
154153
marginals, posterior = _tmp
155154
u = ssm.stats.qoi_from_sample(marginals.mean)
156155
std = ssm.stats.standard_deviation(marginals)
@@ -168,7 +167,7 @@ def solve_adaptive_save_at(
168167

169168

170169
def _solve_adaptive_save_at(
171-
vector_field, t, initial_condition, *, save_at, adaptive_solver, dt0
170+
vector_field, t, ssm_init, *, save_at, adaptive_solver, dt0
172171
):
173172
def advance(state, t_next):
174173
# Advance until accepted.t >= t_next.
@@ -198,13 +197,13 @@ def body_fun(s):
198197
)
199198
return state, solution
200199

201-
state = adaptive_solver.init(t, initial_condition, dt=dt0, num_steps=0.0)
200+
state = adaptive_solver.init(t, ssm_init, dt=dt0, num_steps=0.0)
202201
_, solution = control_flow.scan(advance, init=state, xs=save_at, reverse=False)
203202
return solution
204203

205204

206205
def solve_adaptive_save_every_step(
207-
vector_field, initial_condition, t0, t1, adaptive_solver, dt0, *, ssm
206+
vector_field, ssm_init, t0, t1, adaptive_solver, dt0, *, ssm
208207
) -> IVPSolution:
209208
"""Solve an initial value problem and save every step.
210209
@@ -223,7 +222,7 @@ def solve_adaptive_save_every_step(
223222
generator = _solution_generator(
224223
tree_util.Partial(vector_field),
225224
t0,
226-
initial_condition,
225+
ssm_init,
227226
t1=t1,
228227
adaptive_solver=adaptive_solver,
229228
dt0=dt0,
@@ -236,9 +235,8 @@ def solve_adaptive_save_every_step(
236235
t = np.concatenate((np.atleast_1d(t0), t))
237236

238237
# I think the user expects marginals, so we compute them here
239-
init_t0 = initial_condition
240238
posterior, output_scale = solution_every_step
241-
_tmp = _userfriendly_output(posterior=posterior, init_t0=init_t0, ssm=ssm)
239+
_tmp = _userfriendly_output(posterior=posterior, ssm_init=ssm_init, ssm=ssm)
242240
marginals, posterior = _tmp
243241

244242
u = ssm.stats.qoi_from_sample(marginals.mean)
@@ -256,10 +254,8 @@ def solve_adaptive_save_every_step(
256254
)
257255

258256

259-
def _solution_generator(
260-
vector_field, t, initial_condition, *, dt0, t1, adaptive_solver
261-
):
262-
state = adaptive_solver.init(t, initial_condition, dt=dt0, num_steps=0)
257+
def _solution_generator(vector_field, t, ssm_init, *, dt0, t1, adaptive_solver):
258+
state = adaptive_solver.init(t, ssm_init, dt=dt0, num_steps=0)
263259

264260
while state.step_from.t < t1:
265261
state = adaptive_solver.rejection_loop(state, vector_field=vector_field, t1=t1)
@@ -278,9 +274,7 @@ def _solution_generator(
278274
yield solution
279275

280276

281-
def solve_fixed_grid(
282-
vector_field, initial_condition, grid, solver, *, ssm
283-
) -> IVPSolution:
277+
def solve_fixed_grid(vector_field, ssm_init, grid, solver, *, ssm) -> IVPSolution:
284278
"""Solve an initial value problem on a fixed, pre-determined grid."""
285279
# Compute the solution
286280

@@ -289,13 +283,12 @@ def body_fn(s, dt):
289283
return s_new, s_new
290284

291285
t0 = grid[0]
292-
state0 = solver.init(t0, initial_condition)
286+
state0 = solver.init(t0, ssm_init)
293287
_, result_state = control_flow.scan(body_fn, init=state0, xs=np.diff(grid))
294288
_t, (posterior, output_scale) = solver.extract(result_state)
295289

296290
# I think the user expects marginals, so we compute them here
297-
init_t0 = initial_condition
298-
_tmp = _userfriendly_output(posterior=posterior, init_t0=init_t0, ssm=ssm)
291+
_tmp = _userfriendly_output(posterior=posterior, ssm_init=ssm_init, ssm=ssm)
299292
marginals, posterior = _tmp
300293

301294
u = ssm.stats.qoi_from_sample(marginals.mean)
@@ -313,7 +306,7 @@ def body_fn(s, dt):
313306
)
314307

315308

316-
def _userfriendly_output(*, posterior, init_t0, ssm):
309+
def _userfriendly_output(*, posterior, ssm_init, ssm):
317310
if isinstance(posterior, stats.MarkovSeq):
318311
# Compute marginals
319312
posterior_no_filter_marginals = stats.markov_select_terminal(posterior)
@@ -326,10 +319,10 @@ def _userfriendly_output(*, posterior, init_t0, ssm):
326319
marginals = tree_array_util.tree_append(marginals, marginal_t1)
327320

328321
# Prepend the marginal at t1 to the inits
329-
init = tree_array_util.tree_prepend(init_t0, posterior.init)
322+
init = tree_array_util.tree_prepend(ssm_init, posterior.init)
330323
posterior = stats.MarkovSeq(init=init, conditional=posterior.conditional)
331324
else:
332-
posterior = tree_array_util.tree_prepend(init_t0, posterior)
325+
posterior = tree_array_util.tree_prepend(ssm_init, posterior)
333326
marginals = posterior
334327
return marginals, posterior
335328

probdiffeq/ivpsolvers.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,8 @@ def prior_wiener_integrated(tcoeffs, *, ssm_fact: str, output_scale=None):
3636
init = ssm.normal.from_tcoeffs(tcoeffs)
3737
return init, discretize, ssm
3838

39-
# output_scale_calib = np.ones_like(ssm.prototypes.output_scale())
40-
# prior = _MarkovProcess(tcoeffs, output_scale_calib, discretize=discretize)
41-
# return prior, ssm
4239

43-
44-
def prior_wiener_integrated_discretised(
40+
def prior_wiener_integrated_discrete(
4541
ts, *, tcoeffs_like, ssm_fact: str, output_scale=None
4642
):
4743
"""Compute a time-discretized, multiply-integrated Wiener process."""

0 commit comments

Comments
 (0)