Skip to content
11 changes: 6 additions & 5 deletions docs/examples_advanced/equinox_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@

from probdiffeq import ivpsolve, probdiffeq, taylor

# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)


def solution_routine(while_loop):
"""Construct a parameter-to-solution function and an initial value."""

@jax.jit
def vf(y, *, t): # noqa: ARG001
def vf(y, /, *, t): # noqa: ARG001
"""Evaluate the vector field."""
return 0.5 * y * (1 - y)

Expand All @@ -40,12 +43,10 @@ def vf(y, *, t): # noqa: ARG001

tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
ts0 = probdiffeq.constraint_ode_ts0(ode_order=1, ssm=ssm)
ts0 = probdiffeq.constraint_ode_ts0(vf, ssm=ssm)

strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
solver = probdiffeq.solver(
vf, strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
)
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
solve_adaptive = ivpsolve.solve_adaptive_terminal_values(
solver=solver, errorest=errorest, while_loop=while_loop
Expand Down
17 changes: 10 additions & 7 deletions docs/examples_advanced/neural_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@

from probdiffeq import ivpsolve, probdiffeq

# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)


def main(num_data=100, epochs=500, print_every=50, hidden=(20,), lr=0.2):
"""Train a neural ODE using diffusion tempering."""
Expand Down Expand Up @@ -96,7 +99,7 @@ def vf_neural_ode(*, hidden: tuple, t0: float, t1: float):
u0 = jnp.asarray([0.0])

@jax.jit
def vf(y, *, t, p):
def vf(y, /, *, t, p):
"""Evaluate the neural ODE vector field."""
y_and_t = jnp.concatenate([y, t[None]])
return mlp(p, y_and_t)
Expand Down Expand Up @@ -167,14 +170,14 @@ def loss(
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = probdiffeq.constraint_ode_ts0(ssm=ssm)

def vf_p(y, /, *, t):
return vf(y, t=t, p=p)

ts0 = probdiffeq.constraint_ode_ts0(vf_p, ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedinterval(ssm=ssm)
solver_ts0 = probdiffeq.solver(
lambda *a, **kw: vf(*a, **kw, p=p),
strategy=strategy,
prior=ibm,
constraint=ts0,
ssm=ssm,
strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
)

# Solve
Expand Down
6 changes: 3 additions & 3 deletions docs/examples_advanced/parameter_estimation_blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@


@jax.jit
def vf(y, *, t): # noqa: ARG001
def vf(y, /, *, t): # noqa: ARG001
"""Evaluate the Lotka-Volterra vector field."""
return f(y, *f_args)

Expand All @@ -180,9 +180,9 @@ def plot_solution(t, u, *, ax, marker=".", **plotting_kwargs):
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
tcoeffs, output_scale=10.0, ssm_fact="isotropic"
)
ts0 = probdiffeq.constraint_ode_ts0(ssm=ssm)
ts0 = probdiffeq.constraint_ode_ts0(vf, ssm=ssm)
strategy = probdiffeq.strategy_filter(ssm=ssm)
solver = probdiffeq.solver(vf, strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)


@jax.jit
Expand Down
18 changes: 10 additions & 8 deletions docs/examples_advanced/parameter_estimation_optax.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
if not backend.has_been_selected:
backend.select("jax") # ivp examples in jax

# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)


# -


Expand Down Expand Up @@ -64,15 +68,13 @@ def solve(p):
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
tcoeffs, output_scale=10.0, ssm_fact="isotropic"
)
ts0 = probdiffeq.constraint_ode_ts0(ssm=ssm)

def vf_p(y, /, *, t):
return vf(y, t=t, p=p)

ts0 = probdiffeq.constraint_ode_ts0(vf_p, ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedinterval(ssm=ssm)
solver = probdiffeq.solver(
jax.jit(lambda y, t: vf(y, t, p=p)),
strategy=strategy,
prior=ibm,
constraint=ts0,
ssm=ssm,
)
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm)
solve = ivpsolve.solve_fixed_grid(solver=solver)
return solve(init, grid=ts)

Expand Down
14 changes: 7 additions & 7 deletions docs/examples_advanced/solve_pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,32 @@
import jax.numpy as jnp
import matplotlib.pyplot as plt

from probdiffeq import ivpsolve, probdiffeq, taylor
from probdiffeq import ivpsolve, probdiffeq

jax.config.update("jax_enable_x64", True)
# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)


def main():
"""Simulate a PDE."""
key = jax.random.PRNGKey(1)
f, (u0,), (t0, t1) = fhn_2d(key, num=40, t1=10.0)

@jax.jit
def vf(y, *, t): # noqa: ARG001
def vf(y, /, *, t): # noqa: ARG001
"""Evaluate the dynamics of the PDE."""
return f(y)

print("Problem dimension:", u0.size)

# Set up a state-space model
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
tcoeffs = [u0, vf(u0, t=t0)]
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag")

# Build a solver
ts = probdiffeq.constraint_ode_ts1(ssm=ssm)
ts = probdiffeq.constraint_ode_ts1(vf, ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
solver = probdiffeq.solver_dynamic(
vf, ssm=ssm, strategy=strategy, prior=ibm, constraint=ts
ssm=ssm, strategy=strategy, prior=ibm, constraint=ts
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)

Expand Down
19 changes: 13 additions & 6 deletions docs/examples_basic/conditioning_on_zero_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

"""Demonstrate how probabilistic solvers work via conditioning on constraints."""

import functools

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
Expand All @@ -32,6 +34,11 @@
if not backend.has_been_selected:
backend.select("jax") # ivp examples in jax


# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)


# -

# Create an ODE problem.
Expand All @@ -40,8 +47,9 @@


@jax.jit
def vector_field(y, t): # noqa: ARG001
def vector_field(y, /, *, t):
"""Evaluate the logistic ODE vector field."""
del t
return 10.0 * y * (2.0 - y)


Expand Down Expand Up @@ -78,11 +86,9 @@ def vector_field(y, t): # noqa: ARG001
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="dense"
)
ts1 = probdiffeq.constraint_ode_ts1(ssm=ssm)
ts1 = probdiffeq.constraint_ode_ts1(vector_field, ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
solver = probdiffeq.solver(
vector_field, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
)
solver = probdiffeq.solver(strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)

dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))
Expand Down Expand Up @@ -126,7 +132,8 @@ def vector_field(y, t): # noqa: ARG001

def residual(x, t):
"""Evaluate the ODE residual."""
return x[1] - jax.vmap(jax.vmap(vector_field), in_axes=(0, None))(x[0], t)
vf_wrapped = functools.partial(vector_field, t=t)
return x[1] - jax.vmap(jax.vmap(vf_wrapped))(x[0])


residual_prior = residual(samples_prior, ts)
Expand Down
18 changes: 11 additions & 7 deletions docs/examples_basic/custom_information_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@

from probdiffeq import ivpsolve, probdiffeq

# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)


# -


Expand All @@ -43,7 +47,7 @@


@jax.jit
def vf_1st(y, *, t):
def vf_1st(y, /, *, t):
"""Evaluate the harmonic oscillator dynamics."""
u, du = jnp.split(y, 2)
return jnp.concatenate([du, vf_2nd(u, du, t=t)])
Expand Down Expand Up @@ -90,10 +94,10 @@ def hamiltonian_2nd(u, du):
tcoeffs = [u0_1st, zeros, zeros]
tcoeffs_std = [zeros, ones, ones]
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, tcoeffs_std=tcoeffs_std)
ts1 = probdiffeq.constraint_ode_ts1(ssm=ssm)
ts1 = probdiffeq.constraint_ode_ts1(vf_1st, ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
solver_1st = probdiffeq.solver_mle(
vf_1st, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver_1st, errorest=errorest)
Expand All @@ -115,9 +119,9 @@ def hamiltonian_2nd(u, du):
# +


def root(vf, u, du, ddu):
def root(u, du, ddu, /, *, t):
"""Evaluate a custom root for the harmonic oscillator."""
deriv = ddu - vf(u, du)
deriv = ddu - vf_2nd(u, du, t=t)
hamil = hamiltonian_2nd(u, du) - H0
return [deriv, hamil] # any PyTree goes

Expand Down Expand Up @@ -145,10 +149,10 @@ def root(vf, u, du, ddu):

# +

ts1 = probdiffeq.constraint_root_ts1(root, ssm=ssm, ode_order=2)
ts1 = probdiffeq.constraint_root_ts1(root, ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
solver_2nd = probdiffeq.solver_mle(
vf_2nd, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
)

# -
Expand Down
14 changes: 9 additions & 5 deletions docs/examples_basic/dynamic_output_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
if not backend.has_been_selected:
backend.select("jax") # ivp examples in jax

# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)


# -

Expand All @@ -53,9 +56,10 @@


@jax.jit
def vf(*ys, t): # noqa: ARG001
def vf(y, /, *, t):
"""Evaluate the affine vector field."""
return f(*ys, *f_args)
del t
return f(y, *f_args)


# -
Expand All @@ -68,12 +72,12 @@ def vf(*ys, t): # noqa: ARG001
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="dense"
)
ts1 = probdiffeq.constraint_ode_ts1(ssm=ssm)
ts1 = probdiffeq.constraint_ode_ts1(vf, ssm=ssm)
strategy = probdiffeq.strategy_filter(ssm=ssm)
dynamic = probdiffeq.solver_dynamic(
vf, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm
)
mle = probdiffeq.solver_mle(vf, strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm)
mle = probdiffeq.solver_mle(strategy=strategy, prior=ibm, constraint=ts1, ssm=ssm)

# -

Expand Down
9 changes: 6 additions & 3 deletions docs/examples_basic/posterior_uncertainties.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@

from probdiffeq import ivpsolve, probdiffeq, taylor

# Fail this notebook on NaN detection (to catch those in the CI)
jax.config.update("jax_debug_nans", True)


@jax.jit
def vf(y, *, t): # noqa: ARG001
def vf(y, /, *, t): # noqa: ARG001
"""Evaluate the Lotka-Volterra vector field."""
y0, y1 = y[0], y[1]

Expand All @@ -58,10 +61,10 @@ def vf(y, *, t): # noqa: ARG001

tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3)
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag")
ts = probdiffeq.constraint_ode_ts1(ssm=ssm)
ts = probdiffeq.constraint_ode_ts1(vf, ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)

solver = probdiffeq.solver_mle(vf, strategy=strategy, prior=ibm, constraint=ts, ssm=ssm)
solver = probdiffeq.solver_mle(strategy=strategy, prior=ibm, constraint=ts, ssm=ssm)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)

Expand Down
Loading