Skip to content

Commit 6e7acf0

Browse files
authored
Improve the vector-field API: Move vector-fields back into constraints, read off ode_order from the Callable, and more (#851)
* Begin automating the ode_order detection * Update the ivpsolve-tests * Fix tests * Update some notebooks * Update most notebooks * Update more notebooks (and fail all notebooks on NaN detection) * Complete updating the benchmarks * Reenable positional-or-keyword arguments in vector fields * Improve docs * Remove NaN detection from blackjax example because somewhere, blackjax uses NaNs
1 parent a41affc commit 6e7acf0

38 files changed

Lines changed: 472 additions & 365 deletions

docs/examples_advanced/equinox_while_loop.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@
2626

2727
from probdiffeq import ivpsolve, probdiffeq, taylor
2828

29+
# Fail this notebook on NaN detection (to catch those in the CI)
30+
jax.config.update("jax_debug_nans", True)
31+
2932

3033
def solution_routine(while_loop):
3134
"""Construct a parameter-to-solution function and an initial value."""
3235

3336
@jax.jit
34-
def vf(y, *, t): # noqa: ARG001
37+
def vf(y, /, *, t): # noqa: ARG001
3538
"""Evaluate the vector field."""
3639
return 0.5 * y * (1 - y)
3740

@@ -40,12 +43,10 @@ def vf(y, *, t): # noqa: ARG001
4043

4144
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
4245
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
43-
ts0 = probdiffeq.constraint_ode_ts0(ode_order=1, ssm=ssm)
46+
ts0 = probdiffeq.constraint_ode_ts0(vf, ssm=ssm)
4447

4548
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
46-
solver = probdiffeq.solver(
47-
vf, strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
48-
)
49+
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
4950
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
5051
solve_adaptive = ivpsolve.solve_adaptive_terminal_values(
5152
solver=solver, errorest=errorest, while_loop=while_loop

docs/examples_advanced/neural_ode.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525

2626
from probdiffeq import ivpsolve, probdiffeq
2727

28+
# Fail this notebook on NaN detection (to catch those in the CI)
29+
jax.config.update("jax_debug_nans", True)
30+
2831

2932
def main(num_data=100, epochs=500, print_every=50, hidden=(20,), lr=0.2):
3033
"""Train a neural ODE using diffusion tempering."""
@@ -96,7 +99,7 @@ def vf_neural_ode(*, hidden: tuple, t0: float, t1: float):
9699
u0 = jnp.asarray([0.0])
97100

98101
@jax.jit
99-
def vf(y, *, t, p):
102+
def vf(y, /, *, t, p):
100103
"""Evaluate the neural ODE vector field."""
101104
y_and_t = jnp.concatenate([y, t[None]])
102105
return mlp(p, y_and_t)
@@ -167,14 +170,14 @@ def loss(
167170
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
168171
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
169172
)
170-
ts0 = probdiffeq.constraint_ode_ts0(ssm=ssm)
173+
174+
def vf_p(y, /, *, t):
175+
return vf(y, t=t, p=p)
176+
177+
ts0 = probdiffeq.constraint_ode_ts0(vf_p, ssm=ssm)
171178
strategy = probdiffeq.strategy_smoother_fixedinterval(ssm=ssm)
172179
solver_ts0 = probdiffeq.solver(
173-
lambda *a, **kw: vf(*a, **kw, p=p),
174-
strategy=strategy,
175-
prior=ibm,
176-
constraint=ts0,
177-
ssm=ssm,
180+
strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
178181
)
179182

180183
# Solve

docs/examples_advanced/parameter_estimation_blackjax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@
153153

154154

155155
@jax.jit
156-
def vf(y, *, t): # noqa: ARG001
156+
def vf(y, /, *, t): # noqa: ARG001
157157
"""Evaluate the Lotka-Volterra vector field."""
158158
return f(y, *f_args)
159159

@@ -180,9 +180,9 @@ def plot_solution(t, u, *, ax, marker=".", **plotting_kwargs):
180180
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
181181
tcoeffs, output_scale=10.0, ssm_fact="isotropic"
182182
)
183-
ts0 = probdiffeq.constraint_ode_ts0(ssm=ssm)
183+
ts0 = probdiffeq.constraint_ode_ts0(vf, ssm=ssm)
184184
strategy = probdiffeq.strategy_filter(ssm=ssm)
185-
solver = probdiffeq.solver(vf, strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
185+
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
186186

187187

188188
@jax.jit

docs/examples_advanced/parameter_estimation_optax.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
if not backend.has_been_selected:
3737
backend.select("jax") # ivp examples in jax
3838

39+
# Fail this notebook on NaN detection (to catch those in the CI)
40+
jax.config.update("jax_debug_nans", True)
41+
42+
3943
# -
4044

4145

@@ -64,15 +68,13 @@ def solve(p):
6468
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
6569
tcoeffs, output_scale=10.0, ssm_fact="isotropic"
6670
)
67-
ts0 = probdiffeq.constraint_ode_ts0(ssm=ssm)
71+
72+
def vf_p(y, /, *, t):
73+
return vf(y, t=t, p=p)
74+
75+
ts0 = probdiffeq.constraint_ode_ts0(vf_p, ssm=ssm)
6876
strategy = probdiffeq.strategy_smoother_fixedinterval(ssm=ssm)
69-
solver = probdiffeq.solver(
70-
jax.jit(lambda y, t: vf(y, t, p=p)),
71-
strategy=strategy,
72-
prior=ibm,
73-
constraint=ts0,
74-
ssm=ssm,
75-
)
77+
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
7678
solve = ivpsolve.solve_fixed_grid(solver=solver)
7779
return solve(init, grid=ts)
7880

docs/examples_advanced/solve_pde.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,32 @@
2525
import jax.numpy as jnp
2626
import matplotlib.pyplot as plt
2727

28-
from probdiffeq import ivpsolve, probdiffeq, taylor
28+
from probdiffeq import ivpsolve, probdiffeq
2929

30-
jax.config.update("jax_enable_x64", True)
30+
# Fail this notebook on NaN detection (to catch those in the CI)
31+
jax.config.update("jax_debug_nans", True)
3132

3233

3334
def main():
3435
"""Simulate a PDE."""
3536
key = jax.random.PRNGKey(1)
3637
f, (u0,), (t0, t1) = fhn_2d(key, num=40, t1=10.0)
3738

38-
@jax.jit
39-
def vf(y, *, t): # noqa: ARG001
39+
def vf(y, /, *, t): # noqa: ARG001
4040
"""Evaluate the dynamics of the PDE."""
4141
return f(y)
4242

4343
print("Problem dimension:", u0.size)
4444

4545
# Set up a state-space model
46-
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
46+
tcoeffs = [u0, vf(u0, t=t0)]
4747
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag")
4848

4949
# Build a solver
50-
ts = probdiffeq.constraint_ode_ts1(ssm=ssm)
50+
ts = probdiffeq.constraint_ode_ts1(vf, ssm=ssm)
5151
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
5252
solver = probdiffeq.solver_dynamic(
53-
vf, ssm=ssm, strategy=strategy, prior=ibm, constraint=ts
53+
ssm=ssm, strategy=strategy, prior=ibm, constraint=ts
5454
)
5555
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
5656

docs/examples_basic/conditioning_on_zero_residual.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
"""Demonstrate how probabilistic solvers work via conditioning on constraints."""
2424

25+
import functools
26+
2527
import jax
2628
import jax.numpy as jnp
2729
import matplotlib.pyplot as plt
@@ -32,6 +34,11 @@
3234
if not backend.has_been_selected:
3335
backend.select("jax") # ivp examples in jax
3436

37+
38+
# Fail this notebook on NaN detection (to catch those in the CI)
39+
jax.config.update("jax_debug_nans", True)
40+
41+
3542
# -
3643

3744
# Create an ODE problem.
@@ -40,8 +47,9 @@
4047

4148

4249
@jax.jit
43-
def vector_field(y, t): # noqa: ARG001
50+
def vector_field(y, /, *, t):
4451
"""Evaluate the logistic ODE vector field."""
52+
del t
4553
return 10.0 * y * (2.0 - y)
4654

4755

@@ -78,11 +86,9 @@ def vector_field(y, t): # noqa: ARG001
7886
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
7987
tcoeffs, output_scale=1.0, ssm_fact="dense"
8088
)
81-
ts1 = probdiffeq.constraint_ode_ts1(ssm=ssm)
89+
ts1 = probdiffeq.constraint_ode_ts1(vector_field, ssm=ssm)
8290
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
83-
solver = probdiffeq.solver(
84-
vector_field, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
85-
)
91+
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm)
8692
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
8793

8894
dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))
@@ -126,7 +132,8 @@ def vector_field(y, t): # noqa: ARG001
126132

127133
def residual(x, t):
128134
"""Evaluate the ODE residual."""
129-
return x[1] - jax.vmap(jax.vmap(vector_field), in_axes=(0, None))(x[0], t)
135+
vf_wrapped = functools.partial(vector_field, t=t)
136+
return x[1] - jax.vmap(jax.vmap(vf_wrapped))(x[0])
130137

131138

132139
residual_prior = residual(samples_prior, ts)

docs/examples_basic/custom_information_operator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333

3434
from probdiffeq import ivpsolve, probdiffeq
3535

36+
# Fail this notebook on NaN detection (to catch those in the CI)
37+
jax.config.update("jax_debug_nans", True)
38+
39+
3640
# -
3741

3842

@@ -43,7 +47,7 @@
4347

4448

4549
@jax.jit
46-
def vf_1st(y, *, t):
50+
def vf_1st(y, /, *, t):
4751
"""Evaluate the harmonic oscillator dynamics."""
4852
u, du = jnp.split(y, 2)
4953
return jnp.concatenate([du, vf_2nd(u, du, t=t)])
@@ -90,10 +94,10 @@ def hamiltonian_2nd(u, du):
9094
tcoeffs = [u0_1st, zeros, zeros]
9195
tcoeffs_std = [zeros, ones, ones]
9296
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, tcoeffs_std=tcoeffs_std)
93-
ts1 = probdiffeq.constraint_ode_ts1(ssm=ssm)
97+
ts1 = probdiffeq.constraint_ode_ts1(vf_1st, ssm=ssm)
9498
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
9599
solver_1st = probdiffeq.solver_mle(
96-
vf_1st, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
100+
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
97101
)
98102
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
99103
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, errorest=errorest)
@@ -115,9 +119,9 @@ def hamiltonian_2nd(u, du):
115119
# +
116120

117121

118-
def root(vf, u, du, ddu):
122+
def root(u, du, ddu, /, *, t):
119123
"""Evaluate a custom root for the harmonic oscillator."""
120-
deriv = ddu - vf(u, du)
124+
deriv = ddu - vf_2nd(u, du, t=t)
121125
hamil = hamiltonian_2nd(u, du) - H0
122126
return [deriv, hamil] # any PyTree goes
123127

@@ -145,10 +149,10 @@ def root(vf, u, du, ddu):
145149

146150
# +
147151

148-
ts1 = probdiffeq.constraint_root_ts1(root, ssm=ssm, ode_order=2)
152+
ts1 = probdiffeq.constraint_root_ts1(root, ssm=ssm)
149153
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
150154
solver_2nd = probdiffeq.solver_mle(
151-
vf_2nd, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
155+
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
152156
)
153157

154158
# -

docs/examples_basic/dynamic_output_scales.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
if not backend.has_been_selected:
4545
backend.select("jax") # ivp examples in jax
4646

47+
# Fail this notebook on NaN detection (to catch those in the CI)
48+
jax.config.update("jax_debug_nans", True)
49+
4750

4851
# -
4952

@@ -53,9 +56,10 @@
5356

5457

5558
@jax.jit
56-
def vf(*ys, t): # noqa: ARG001
59+
def vf(y, /, *, t):
5760
"""Evaluate the affine vector field."""
58-
return f(*ys, *f_args)
61+
del t
62+
return f(y, *f_args)
5963

6064

6165
# -
@@ -68,12 +72,12 @@ def vf(*ys, t): # noqa: ARG001
6872
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
6973
tcoeffs, output_scale=1.0, ssm_fact="dense"
7074
)
71-
ts1 = probdiffeq.constraint_ode_ts1(ssm=ssm)
75+
ts1 = probdiffeq.constraint_ode_ts1(vf, ssm=ssm)
7276
strategy = probdiffeq.strategy_filter(ssm=ssm)
7377
dynamic = probdiffeq.solver_dynamic(
74-
vf, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
78+
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
7579
)
76-
mle = probdiffeq.solver_mle(vf, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm)
80+
mle = probdiffeq.solver_mle(strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm)
7781

7882
# -
7983

docs/examples_basic/posterior_uncertainties.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@
2929

3030
from probdiffeq import ivpsolve, probdiffeq, taylor
3131

32+
# Fail this notebook on NaN detection (to catch those in the CI)
33+
jax.config.update("jax_debug_nans", True)
34+
3235

3336
@jax.jit
34-
def vf(y, *, t): # noqa: ARG001
37+
def vf(y, /, *, t): # noqa: ARG001
3538
"""Evaluate the Lotka-Volterra vector field."""
3639
y0, y1 = y[0], y[1]
3740

@@ -58,10 +61,10 @@ def vf(y, *, t): # noqa: ARG001
5861

5962
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3)
6063
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag")
61-
ts = probdiffeq.constraint_ode_ts1(ssm=ssm)
64+
ts = probdiffeq.constraint_ode_ts1(vf, ssm=ssm)
6265
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
6366

64-
solver = probdiffeq.solver_mle(vf, strategy=strategy, prior=ibm, constraint=ts, ssm=ssm)
67+
solver = probdiffeq.solver_mle(strategy=strategy, prior=ibm, constraint=ts, ssm=ssm)
6568
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
6669
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
6770

0 commit comments

Comments
 (0)