Merged

22 commits
8d8b81f
Rename LinearizationBackend into LinearizationFactoryBackend because …
pnkraemer Feb 16, 2026
b8157a0
Add 'Ode'-prefix to current linearisation backends to enable introduc…
pnkraemer Feb 16, 2026
eac1525
Clarify which linearization can handle high-order ODEs
pnkraemer Feb 16, 2026
ef188cc
Draft a RootTs1
pnkraemer Feb 16, 2026
b10f2e5
Start drafting a custom information operator tutorial
pnkraemer Feb 16, 2026
eeec4f4
Make solvers handle custom information operators
pnkraemer Feb 16, 2026
e986f01
Upgrade the dynamic solver
pnkraemer Feb 16, 2026
6e550c8
Set 'dense' default for the SSM
pnkraemer Feb 16, 2026
c0cde61
Leave some todos
pnkraemer Feb 16, 2026
e266755
Use adaptive steps in the custom information operator tutorial
pnkraemer Feb 17, 2026
c0dbb0b
Add documentation for the custom information operator
pnkraemer Feb 17, 2026
6b5eadc
Add documentation
pnkraemer Feb 17, 2026
b705e1e
Improve the py:light format
pnkraemer Feb 17, 2026
afb5a55
Include the custom information operator tutorial in the docs
pnkraemer Feb 17, 2026
05c89ea
Improve the python-markdown separation in the tutorials
pnkraemer Feb 17, 2026
4af9591
Improve docs
pnkraemer Feb 17, 2026
fab3343
Decrease the Tcoeff-std-increase because tests failed
pnkraemer Feb 17, 2026
fafc459
Make the STD increase opt-out
pnkraemer Feb 17, 2026
2265fa1
Update the benchmark
pnkraemer Feb 17, 2026
0dd9f08
Merge branch 'main' into custom-constraint
pnkraemer Feb 17, 2026
ca9199a
Increase initial damping by epsilon to stabilise initial update
pnkraemer Feb 17, 2026
83e42c8
Undo implicit conversion
pnkraemer Feb 17, 2026
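Note on notation: the example scripts below are stored in jupytext's py:light format, where a line `# +` opens an explicit code cell and `# -` closes it; several hunks in this diff only add or move these markers. A minimal, hypothetical illustration of the convention:

# +
x = 1  # ordinary code inside an explicit py:light cell
# -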
2 changes: 1 addition & 1 deletion docs/dev_docs/creating_example_notebook.md
@@ -34,7 +34,7 @@ Probdiffeq hosts numerous tutorials and benchmarks that demonstrate the library.
Ensure the corresponding script is excluded under `mkdocs.yml -> exclude:`; if needed, add it there.

4. **Makefile:**
Add the new example or benchmark to the appropriate Makefile target (e.g., `examples-and-benchmarks`).
Check whether the new example or benchmark needs to be added to the appropriate Makefile target (e.g., `examples-and-benchmarks`). New files are usually detected automatically, but verify this nevertheless.

5. **Pyproject.toml:**
If your example requires external dependencies, list them under the `doc` optional dependencies in `pyproject.toml`.
21 changes: 18 additions & 3 deletions docs/examples_advanced/equinox_while_loop.py
@@ -26,8 +26,6 @@

from probdiffeq import ivpsolve, probdiffeq, taylor

# -


def solution_routine(while_loop):
"""Construct a parameter-to-solution function and an initial value."""
@@ -64,17 +62,32 @@ def simulate(init_val):
return simulate, init


# This is the default behaviour
# -


# This is the default behaviour.


# +


solve, x = solution_routine(jax.lax.while_loop)

try:
solution, gradient = jax.jit(jax.value_and_grad(solve))(x)
except ValueError as err:
print(f"Caught error:\n\t {err}")


# -


# This while-loop makes the solver differentiable


# +


def while_loop_func(*a, **kw):
"""Evaluate a bounded while loop."""
return equinox.internal.while_loop(*a, **kw, kind="bounded", max_steps=100)
@@ -87,3 +100,5 @@ def while_loop_func(*a, **kw):

print(solution)
print(gradient)

# -
25 changes: 25 additions & 0 deletions docs/examples_advanced/neural_ode.py
@@ -85,6 +85,11 @@ def main(num_data=100, epochs=500, print_every=50, hidden=(20,), lr=0.2):
plt.show()


# -

# +


def vf_neural_ode(*, hidden: tuple, t0: float, t1: float):
"""Build a neural ODE."""
f_args, mlp = model_mlp(hidden=hidden, shape_in=(2,), shape_out=(1,))
@@ -99,6 +104,11 @@ def vf(y, *, t, p):
return vf, (u0,), (t0, t1), f_args


# -

# +


def model_mlp(
*, hidden: tuple, shape_in: tuple = (), shape_out: tuple = (), activation=jnp.tanh
):
@@ -133,6 +143,11 @@ def fwd(w, x):
return unravel(p_init), fwd


# -

# +


def loss_log_marginal_likelihood(vf, *, t0):
"""Build a loss function from an ODE problem."""

@@ -177,6 +192,11 @@ def loss(
return loss


# -

# +


def train_step_optax(optimizer, loss):
"""Implement a training step using Optax."""

@@ -194,5 +214,10 @@ def update(params, opt_state, **loss_kwargs):
return update


# -

# +


if __name__ == "__main__":
main()
77 changes: 74 additions & 3 deletions docs/examples_advanced/parameter_estimation_blackjax.py
@@ -162,7 +162,11 @@ def vf(y, *, t): # noqa: ARG001
theta_guess = u0 # initial guess


# -

# +


def plot_solution(t, u, *, ax, marker=".", **plotting_kwargs):
"""Plot the IVP solution."""
for d in [0, 1]:
@@ -205,9 +209,15 @@ def solve_adaptive(theta, *, save_at):
save_at = jnp.linspace(t0, t1, num=250, endpoint=True)
solve_save_at = functools.partial(solve_adaptive, save_at=save_at)

# +

# -

# Visualise the initial guess and the data


# +


fig, ax = plt.subplots(figsize=(5, 3))

data_kwargs = {"alpha": 0.5, "color": "gray"}
@@ -220,6 +230,7 @@ def solve_adaptive(theta, *, save_at):
sol = solve_save_at(theta_guess)
ax = plot_solution(sol.t, sol.u.mean[0], ax=ax, **guess_kwargs)
plt.show()

# -

# ## Log-posterior densities via ProbDiffEq
@@ -244,17 +255,30 @@ def logposterior_fn(theta, *, data, ts, obs_stdev=0.1):
return logpdf_data + logpdf_prior


# -


# Fixed steps for reverse-mode differentiability:


# +


ts = jnp.linspace(t0, t1, endpoint=True, num=100)
data = solve_fixed(theta_true, ts=ts).u.mean[0][-1]

log_M = functools.partial(logposterior_fn, data=data, ts=ts)


# -

# +


print(jnp.exp(log_M(theta_true)), ">=", jnp.exp(log_M(theta_guess)), "?")

# -
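
# Because `solve_fixed` takes fixed steps, `log_M` should also be
# reverse-mode differentiable. A minimal sanity check (a hypothetical
# sketch; the name `grad_log_M` is illustrative and not part of the script):

# +

grad_log_M = jax.jit(jax.grad(log_M))  # hypothetical helper, not in the script
print(grad_log_M(theta_guess))

# -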


# ## Sampling with BlackJAX
#
@@ -263,6 +287,9 @@ def logposterior_fn(theta, *, data, ts, obs_stdev=0.1):
# Set up a sampler.


# +


@functools.partial(jax.jit, static_argnames=["kernel", "num_samples"])
def inference_loop(rng_key, kernel, initial_state, num_samples):
"""Run BlackJAX' inference loop."""
@@ -277,13 +304,24 @@ def one_step(state, rng_key):
return states


# -


# Initialise the sampler, warm it up, and run the inference loop.


# +


initial_position = theta_guess
rng_key = jax.random.PRNGKey(0)

# -

# Warm up.

# +
# WARMUP

warmup = blackjax.window_adaptation(blackjax.nuts, log_M, progress_bar=True)

warmup_results, _ = warmup.run(rng_key, initial_position, num_steps=200)
@@ -296,21 +334,40 @@ def one_step(state, rng_key):
)
# -

# INFERENCE LOOP
# Inference loop


# +


rng_key, _ = jax.random.split(rng_key, 2)
states = inference_loop(
rng_key, kernel=nuts_kernel, initial_state=initial_state, num_samples=150
)

# -


# ## Visualisation
#
# Now that we have samples of $\theta$, let's plot the corresponding solutions:


# +


solution_samples = jax.vmap(solve_save_at)(states.position)

# -

# +

# Visualise the posterior samples, the data, and the initial guess


# +


fig, ax = plt.subplots()

sample_kwargs = {"color": "C0"}
@@ -330,20 +387,29 @@ def one_step(state, rng_key):
sol.t, sol.u.mean[0], ax=ax, linestyle="dashed", alpha=0.75, **guess_kwargs
)
plt.show()

# -

# The samples cover a perhaps surprisingly large range of
# potential initial conditions, but lead to the "correct" data.
#
# In parameter space, this is what it looks like:


# +


plt.title("Posterior samples (parameter space)")
plt.plot(states.position[:, 0], states.position[:, 1], "o", alpha=0.5, markersize=4)
plt.plot(theta_true[0], theta_true[1], "P", label="Truth", markersize=8)
plt.plot(theta_guess[0], theta_guess[1], "P", label="Initial guess", markersize=8)
plt.legend()
plt.show()


# -


# Let's add the value of $M$ to the plot to see whether
# the sampler covers the entire region of interest.

@@ -360,7 +426,11 @@ def one_step(state, rng_key):
log_M_vmapped = jax.vmap(log_M_vmapped_x, in_axes=-1, out_axes=-1)
Zs = log_M_vmapped(Thetas)


# -

# +

fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(8, 3))

ax_samples, ax_heatmap = ax
@@ -377,6 +447,7 @@ def one_step(state, rng_key):
im = ax_heatmap.contourf(Xs, Ys, jnp.exp(Zs), cmap="cividis", alpha=0.8)
plt.colorbar(im)
plt.show()

# -

# Looks great!
17 changes: 15 additions & 2 deletions docs/examples_advanced/parameter_estimation_optax.py
@@ -33,7 +33,6 @@

from probdiffeq import ivpsolve, probdiffeq

# +
if not backend.has_been_selected:
backend.select("jax") # ivp examples in jax

@@ -87,16 +86,21 @@ def solve(p):
data = solution_true.u.mean[0]
plt.plot(ts, data, "P-")
plt.show()

# -

# We make an initial guess, but it does not lead to a good data fit:

# +

solution_guess = solve(parameter_guess)
plt.plot(ts, data, color="k", linestyle="solid", linewidth=6, alpha=0.125)
plt.plot(ts, solution_guess.u.mean[0])
plt.show()


# -

# Use the probdiffeq functionality to build a parameter-to-data fit function.
#
# This incorporates the likelihood of the data under the distribution induced
@@ -123,9 +127,10 @@ def parameter_to_data_fit(parameters_, /, standard_deviation=1e-1):
# We can differentiate the function forward- and reverse-mode
# (the latter is possible because we use fixed steps)

# +
parameter_to_data_fit(parameter_guess)
sensitivities(parameter_guess)

# -
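
# For instance, both differentiation modes should be available and agree
# (a hypothetical sketch, assuming `parameter_to_data_fit` returns a scalar
# loss and `jax` is imported at the top of the script):

# +

grad_fwd = jax.jacfwd(parameter_to_data_fit)(parameter_guess)  # forward-mode
grad_rev = jax.grad(parameter_to_data_fit)(parameter_guess)  # reverse-mode
print(grad_fwd, grad_rev)

# -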

# Now, enter optax: build an optimizer,
# and optimise the parameter-to-model-fit function.
@@ -151,7 +156,11 @@ def update(params, opt_state):
optim = optax.adam(learning_rate=1e-2)
update_fn = build_update_fn(optimizer=optim, loss_fn=parameter_to_data_fit)

# -

# +


p = parameter_guess
state = optim.init(p)

@@ -165,7 +174,11 @@ def update(params, opt_state):

# The solution looks much better:

# +

solution_better = solve(p)
plt.plot(ts, data, color="k", linestyle="solid", linewidth=6, alpha=0.125)
plt.plot(ts, solution_better.u.mean[0])
plt.show()

# -