Merged
Commits
61 commits
ba2523d
Draft a new rejection loop that has less state
pnkraemer Feb 11, 2026
4629af8
Tighten solve_and_save_at code
pnkraemer Feb 11, 2026
386c9df
Improve the clarity of the rejection loop
pnkraemer Feb 11, 2026
48732c7
Draft a separate error estimator
pnkraemer Feb 11, 2026
d9b6c3c
Move stats to ivpsolvers.py
pnkraemer Feb 11, 2026
47daaf6
Update the dynamic solver
pnkraemer Feb 11, 2026
7492bf3
Improve typing in ivpsolve.py
pnkraemer Feb 12, 2026
15ac05e
test_ivpsolve tests pass cleanly now
pnkraemer Feb 12, 2026
08c4d8b
Tests seem to pass
pnkraemer Feb 12, 2026
72d8f7b
Rename test_ivpsolvers to test_probdiffeq in preparation of renaming …
pnkraemer Feb 12, 2026
3bc1de7
Delete empty probdiffeq.stats module
pnkraemer Feb 12, 2026
145671d
Rename probdiffeq.ivpsolvers to probdiffeq.probdiffeq because this is…
pnkraemer Feb 12, 2026
b222e06
Fix some linter issues
pnkraemer Feb 12, 2026
549d272
Move solve_and_save_every_step to test utilities and simplify the ivp…
pnkraemer Feb 12, 2026
a4a98d2
Make solve_fixed_grid return a Callable now
pnkraemer Feb 12, 2026
6cc23fe
Adjust many of the solver tests to the new API
pnkraemer Feb 12, 2026
534c981
Fix tests
pnkraemer Feb 12, 2026
738dc9e
Make bounded while loops easy to use
pnkraemer Feb 12, 2026
c36cb2d
Fix all tests
pnkraemer Feb 12, 2026
931d29b
Save...
pnkraemer Feb 12, 2026
0f69c6f
Adjust the diffusion tempering parameters because the numerics seem t…
pnkraemer Feb 12, 2026
51f9242
Fix two notebooks
pnkraemer Feb 12, 2026
f53677d
Update the PDE example
pnkraemer Feb 12, 2026
c55d752
Append initial condition inside markov_marginals
pnkraemer Feb 12, 2026
2486fae
Fix remaining notebooks
pnkraemer Feb 12, 2026
cd1edac
Update benchmarks
pnkraemer Feb 12, 2026
c1853ef
Cache function evaluations for error estimation
pnkraemer Feb 12, 2026
687511d
Replace corrections with lightweight wrappers around the linearisatio…
pnkraemer Feb 13, 2026
b827394
Update more linearisations
pnkraemer Feb 13, 2026
7d83f12
Block diagonal linearisations
pnkraemer Feb 13, 2026
6f42b95
Finish updating the statistical linear regression
pnkraemer Feb 13, 2026
a093066
Finish the constraint-API refactor and update a large chunk of the tests
pnkraemer Feb 13, 2026
d82c181
Improve documentation
pnkraemer Feb 13, 2026
f1f8199
Fix linter errors
pnkraemer Feb 13, 2026
ea8d101
Update tests
pnkraemer Feb 13, 2026
87463c0
Remove damp parameter from IWP initialisation
pnkraemer Feb 13, 2026
2d70ac0
Update benchmark code
pnkraemer Feb 13, 2026
89e68e3
Write docstrings for probdiffeq.py
pnkraemer Feb 13, 2026
87fb705
Improve formatting in docstrings
pnkraemer Feb 13, 2026
4b8b7df
Include a VectorField protocol to communicate which vector fields are…
pnkraemer Feb 13, 2026
c75a945
Fix the VF type
pnkraemer Feb 13, 2026
e068ea6
Include more info in the readme
pnkraemer Feb 13, 2026
0250ce9
Document the TimeStepState
pnkraemer Feb 13, 2026
5a734e5
Remove dead code and improve typing
pnkraemer Feb 14, 2026
94bc978
Rename backend.numpy into backend.np to streamling imports
pnkraemer Feb 14, 2026
9cceb43
Remove unused backend.config
pnkraemer Feb 14, 2026
12f74ad
Rename backend.functools to backend.func because it is more compact
pnkraemer Feb 14, 2026
e0cc070
Rename backend.tree_util to backend.tree to reflect JAX's API (and be…
pnkraemer Feb 14, 2026
0e327dd
Move backend.tree_array_util content into backend.tree to reduce the …
pnkraemer Feb 14, 2026
35c238f
Improve controller typing
pnkraemer Feb 14, 2026
1f4d063
Improve docstrings
pnkraemer Feb 14, 2026
6be9096
Rename solution.posterior into solution.full
pnkraemer Feb 14, 2026
326401e
Undo some solution typing changes
pnkraemer Feb 14, 2026
7796a8c
Make the README batch show the CI of the main branch, not the most re…
pnkraemer Feb 14, 2026
d3d5efa
Fix docs
pnkraemer Feb 14, 2026
ed94830
Rename backend.control_flow to backend.flow to shorten
pnkraemer Feb 14, 2026
f5eb0eb
Rename backend.containers into backend.structs to make code more compact
pnkraemer Feb 14, 2026
47425e9
Fix last failing test
pnkraemer Feb 14, 2026
86ab9ca
Update meta-documents (eg troubleshooting)
pnkraemer Feb 16, 2026
e96a8d7
Upgrade pre-commit hook and fix resulting complaints
pnkraemer Feb 16, 2026
56a37e3
Remove unused dependency
pnkraemer Feb 16, 2026
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -8,20 +8,20 @@ repos:
- id: end-of-file-fixer
- id: check-merge-conflict
- repo: https://github.com/lyz-code/yamlfix/
rev: 1.18.0
rev: 1.19.1
hooks:
- id: yamlfix
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.12.12
rev: v0.15.1
hooks:
# Run the linter.
- id: ruff-check
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
rev: v1.19.1
hooks:
- id: mypy
args: [--ignore-missing-imports, --check-untyped-defs]
53 changes: 50 additions & 3 deletions README.md
@@ -1,6 +1,6 @@
# probdiffeq

[![CI](https://github.com/pnkraemer/probdiffeq/workflows/ci/badge.svg)](https://github.com/pnkraemer/probdiffeq/actions)
[![CI](https://github.com/pnkraemer/probdiffeq/workflows/ci/badge.svg?branch=main)](https://github.com/pnkraemer/probdiffeq/actions)
[![PyPI version](https://img.shields.io/pypi/v/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq)
[![License](https://img.shields.io/pypi/l/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq)
[![Python versions](https://img.shields.io/pypi/pyversions/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq)
@@ -100,9 +100,55 @@ Link to the paper: [PDF](https://arxiv.org/abs/2410.10530).
Link to the experiments:
[Code for experiments](https://github.com/pnkraemer/code-adaptive-prob-ode-solvers).

📌 Algorithms in Probdiffeq are based on multiple research papers. If you’re unsure which to cite, feel free to reach out.

---
Algorithms in **Probdiffeq** are based on multiple research papers. If you’re unsure which to cite, feel free to reach out.

A (subjective, probdiffeq-centric) list of relevant work includes:


- Numerically robust implementations of probabilistic solvers:

> Nicholas Krämer & Philipp Hennig (2024). Stable implementation of probabilistic ODE solvers. Journal of Machine Learning Research, 25(111), 1–29.

The suggestions made in this work are critical to Probdiffeq (and to other libraries), yet they are rarely discussed and are by now almost taken for granted.


- State-space model factorisations:

> Nicholas Krämer, Nathanael Bosch, Jonathan Schmidt & Philipp Hennig (2022). Probabilistic ODE solutions in millions of dimensions. In ICML 2022, 11634–11649. PMLR.

Every time Probdiffeq uses state-space model factorisations, it follows the recommendations in this work.

- Adaptive step-size selection:

> Michael Schober, Simo Särkkä & Philipp Hennig (2019). A probabilistic model for the numerical solution of initial value problems. Statistics and Computing, 29(1), 99–122.

> Nathanael Bosch, Philipp Hennig & Filip Tronarp (2021). Calibrated adaptive probabilistic ODE solvers. In AISTATS 2021, 3466–3474. PMLR.

> Nicholas Krämer (2025). Adaptive Probabilistic ODE Solvers Without Adaptive Memory Requirements. In Kanagawa, M., Cockayne, J., Gessner, A., & Hennig, P. (Eds.), Proceedings of the First International Conference on Probabilistic Numerics, 12–24. PMLR.

- Constraints, linearisation, and information operators:

> Bosch, Nathanael, Filip Tronarp, and Philipp Hennig. "Pick-and-mix information operators for probabilistic ODE solvers." International Conference on Artificial Intelligence and Statistics. PMLR, 2022.

> Tronarp, Filip, et al. "Probabilistic solutions to ordinary differential equations as nonlinear Bayesian filtering: a new perspective." Statistics and Computing 29.6 (2019): 1297-1315.

See also the Linearisation-chapter in:

> Krämer, Nicholas. Implementing probabilistic numerical solvers for differential equations. Dissertation, Universität Tübingen, 2024.

which describes some methods not mentioned anywhere else.

- Parameter estimation:

> Kersting, H., Krämer, N., Schiegg, M., Daniel, C., Tiemann, M., & Hennig, P. (2020). Differentiable likelihoods for fast inversion of 'likelihood-free' dynamical systems. In International Conference on Machine Learning (pp. 5198-5208). PMLR.

> Tronarp, Filip, Nathanael Bosch, and Philipp Hennig. "Fenrir: Physics-enhanced regression for initial value problems." International Conference on Machine Learning. PMLR, 2022.

> Beck, J., Bosch, N., Deistler, M., Kadhim, K. L., Macke, J. H., Hennig, P., & Berens, P. (2024, July). Diffusion Tempering Improves Parameter Estimation with Probabilistic Integrators for Ordinary Differential Equations. In International Conference on Machine Learning (pp. 3305-3326). PMLR.


Anything missing? Reach out!

## Versioning

@@ -111,6 +157,7 @@ Probdiffeq follows **0.MINOR.PATCH** until its first stable release:
- **MINOR** → breaking changes

See [semantic versioning](https://semver.org/).
Notably, Probdiffeq's API is not guaranteed to be stable, but we do our best to follow the versioning scheme so that downstream projects remain reproducible.
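The pre-1.0 versioning rule above can be made concrete with a small sketch. The helper below is hypothetical and not part of Probdiffeq's API; it only encodes the stated convention that, under **0.MINOR.PATCH**, a MINOR bump signals a breaking change while a PATCH bump does not:

```python
def is_breaking_upgrade(installed: str, candidate: str) -> bool:
    """Return True if upgrading may break the API under 0.MINOR.PATCH.

    Hypothetical helper for illustration: under the pre-1.0 scheme
    described above, a change in the MINOR component is breaking,
    whereas a PATCH-only change is not.
    """
    major_i, minor_i, _patch_i = (int(p) for p in installed.split("."))
    major_c, minor_c, _patch_c = (int(p) for p in candidate.split("."))
    if major_i != 0 or major_c != 0:
        raise ValueError("This sketch only covers the pre-1.0 (0.x.y) scheme")
    return minor_c != minor_i
```

Downstream projects that pin the MINOR component (e.g. `probdiffeq==0.4.*`) therefore stay reproducible across patch releases.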

---

1 change: 1 addition & 0 deletions docs/_stylesheets/extra.css
@@ -1,3 +1,4 @@

:root {
--gray-dark: #222222;
--gray: #333333;
1 change: 0 additions & 1 deletion docs/api_docs/ivpsolvers.md

This file was deleted.

1 change: 1 addition & 0 deletions docs/api_docs/probdiffeq.md
@@ -0,0 +1 @@
::: probdiffeq.probdiffeq
1 change: 0 additions & 1 deletion docs/api_docs/stats.md

This file was deleted.

37 changes: 19 additions & 18 deletions docs/choosing_a_solver.md
@@ -1,13 +1,13 @@
# Choosing a solver

Good solvers are problem-dependent. Nevertheless, some guidelines exist:
Good solvers are problem-dependent. However, some guidelines exist:

## State-space model factorisation

* If your problem is scalar-valued (`shape=()`), use a `scalar` implementation. Of course, you are always welcome to transform your problem into one with shape `(1,)` and use a vector-valued solver (not all features are implemented for scalar models).
* If your problem is vector-valued, be aware that different implementation choices imply different modelling choices.
* If your problem is scalar-valued (`shape=()`), use a dense factorisation. All factorisations have the same complexity for scalar models, but dense factorisations offer the most comprehensive solver suite.

If you don't care about modelling choices:
* If your problem is vector-valued, be aware that different implementation choices imply different modelling choices.
However, if you don't care too much about modelling choices:

* If your problem is high-dimensional, use a `blockdiag` or `isotropic` implementation.
* If your problem is medium-dimensional, use any implementation.
@@ -21,29 +21,30 @@ If you don't care about modelling choices:
If your problem is stiff, use a `dense` implementation in combination with a
correction scheme that employs first-order linearisation;
for instance, `ts1` or `slr1`.
Zeroth-order approximation and too-aggressive state-space model factorisation
will likely fail.
Zeroth-order approximation and isotropic/blockdiag factorisations often fail for stiff problems.

If your problem is stiff and high-dimensional: try first-order linearisation with a block-diagonal factorisation.
If that does not work: let me know what you come up with...
If that does not work: good luck; probabilistic solvers for problems that are stiff
*and* high-dimensional remain an open problem at the time of writing.

## Filters vs smoothers

Almost always, use a `ivpsolvers.strategy_filter` strategy for `simulate_terminal_values`,
a `ivpsolvers.strategy_smoother` strategy for `solve_adaptive_save_every_step`,
and a `ivpsolvers.strategy_fixedpoint` strategy for `solve_adaptive_save_at`.
Use either a filter (if you must) or a smoother (recommended) for `solve_fixed_step`.
Other combinations are possible, but rather rare
(and require some understanding of the underlying statistical concepts).
As a rule of thumb, use a `ivpsolvers.strategy_filter` strategy for `simulate_terminal_values`,
a `ivpsolvers.strategy_smoother_fixedpoint` strategy for `solve_adaptive_save_at`,
and a `ivpsolvers.strategy_smoother_fixedinterval` strategy for `solve_fixed_step`.
Other combinations are possible, but rare.
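The rule of thumb above amounts to a small lookup table. The sketch below is hypothetical (the routine and strategy names are taken from this page's diff and may differ in future releases); it merely records the recommended pairing:

```python
# Hypothetical helper that encodes the rule of thumb above:
# each solution routine maps to the recommended strategy constructor.
RECOMMENDED_STRATEGY = {
    "simulate_terminal_values": "strategy_filter",
    "solve_adaptive_save_at": "strategy_smoother_fixedpoint",
    "solve_fixed_step": "strategy_smoother_fixedinterval",
}


def recommend(routine: str) -> str:
    """Return the recommended strategy name for a solution routine."""
    try:
        return RECOMMENDED_STRATEGY[routine]
    except KeyError:
        raise ValueError(f"No default recommendation for {routine!r}") from None
```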


## Calibration
Use a `solvers.solver_dynamic` solver if you expect that the output scale of your IVP solution varies greatly.
Otherwise, use an `solvers.solver_mle` solver.
Try a `solvers.solver` for parameter-estimation.
Use a `solvers.solver_dynamic` solver if you expect that the output scale of your differential equation
solution varies greatly (e.g., for first-order linear ODEs; see the tutorials).
Otherwise, use a `solvers.solver_mle` solver for plain simulation problems,
and a `solvers.solver` for parameter-estimation.

## Miscellaneous
If you use a `ts0`, choose an `isotropic` factorisation instead of a `dense` factorisation.
They do the same, but the `isotropic` factorisation is cheaper.
They are mathematically equivalent, but the `isotropic` factorisation is faster.


These guidelines are a work in progress and may change soon. If you have any input, let me know!
## Future guidelines
These guidelines are a work in progress and may change at any point. If you have any input, reach out.
3 changes: 2 additions & 1 deletion docs/dev_docs/creating_example_notebook.md
@@ -10,7 +10,8 @@ Probdiffeq hosts numerous tutorials and benchmarks that demonstrate the library.
- `docs/examples_advanced/example-name.ipynb`
Choose a meaningful name (e.g., `work-precision-hires`, `demonstrate-calibration`). The notebook should run the full example/benchmark and produce its plots. Ensure execution time stays well below one minute to keep CI manageable.

If your example requires external dependencies (e.g., sampling or optimization libraries), place it in `examples_advanced`.
If your example requires external dependencies (e.g., sampling or optimization libraries), place it in `examples_advanced`. If it is a benchmark, place it in `examples_benchmarks`. Otherwise, place it in
`examples_basic`.

2. **Sync to py:light:**
Install documentation dependencies and pre-commit hooks if you haven't already:
18 changes: 0 additions & 18 deletions docs/dev_docs/public_api.md

This file was deleted.

73 changes: 34 additions & 39 deletions docs/examples_advanced/equinox_while_loop.py
@@ -24,32 +24,15 @@
import jax
import jax.numpy as jnp

from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.backend import control_flow
from probdiffeq import ivpsolve, probdiffeq, taylor

# -

# Overwrite the while-loop (via a context manager):


# +
def while_loop_func(*a, **kw):
"""Evaluate a bounded while loop."""
return equinox.internal.while_loop(*a, **kw, kind="bounded", max_steps=100)


context_compute_gradient = control_flow.context_overwrite_while_loop(while_loop_func)
# -

# The rest is the similar to the "easy example" in the quickstart,
# except for simulating adaptively and
# computing the value and the gradient
# (which is impossible without the specialised while-loop implementation).


def solution_routine():
def solution_routine(while_loop):
"""Construct a parameter-to-solution function and an initial value."""

@jax.jit
def vf(y, *, t): # noqa: ARG001
"""Evaluate the vector field."""
return 0.5 * y * (1 - y)
@@ -58,37 +41,49 @@ def vf(y, *, t):  # noqa: ARG001
u0 = jnp.asarray([0.1])

tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(vf, ode_order=1, ssm=ssm)

strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
init, ibm, ssm = probdiffeq.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
ts0 = probdiffeq.constraint_ode_ts0(ode_order=1, ssm=ssm)

strategy = probdiffeq.strategy_smoother_fixedpoint(ssm=ssm)
solver = probdiffeq.solver(
vf, strategy=strategy, prior=ibm, constraint=ts0, ssm=ssm
)
errorest = probdiffeq.errorest_local_residual_cached(prior=ibm, ssm=ssm)
solve_adaptive = ivpsolve.solve_adaptive_terminal_values(
solver=solver, errorest=errorest, while_loop=while_loop
)

def simulate(init_val):
"""Evaluate the parameter-to-solution function."""
sol = ivpsolve.solve_adaptive_terminal_values(
init_val, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
)
sol = solve_adaptive(init_val, t0=t0, t1=t1, atol=1e-3, rtol=1e-3)

# Any scalar function of the IVP solution would do
return jnp.dot(sol.u[0], sol.u[0])
# Try the log-marginal-likelihood losses (see the other tutorials).
return jnp.dot(sol.u.mean[0], sol.u.mean[0])

return simulate, init


# This is the default behaviour
solve, x = solution_routine(jax.lax.while_loop)

try:
solve, x = solution_routine()
solution, gradient = jax.value_and_grad(solve)(x)
solution, gradient = jax.jit(jax.value_and_grad(solve))(x)
except ValueError as err:
print(f"Caught error:\n\t {err}")

with context_compute_gradient:
# Construct the solution routine inside the context
solve, x = solution_routine()
# This while-loop makes the solver differentiable


def while_loop_func(*a, **kw):
"""Evaluate a bounded while loop."""
return equinox.internal.while_loop(*a, **kw, kind="bounded", max_steps=100)


solve, x = solution_routine(while_loop=while_loop_func)

# Compute gradients
solution, gradient = jax.value_and_grad(solve)(x)
# Compute gradients
solution, gradient = jax.jit(jax.value_and_grad(solve))(x)

print(solution)
print(gradient)
print(solution)
print(gradient)
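The essential idea of this example, replacing an unbounded while loop with one that has a fixed step budget so that reverse-mode differentiation becomes possible, can be mimicked in plain Python. This is an illustration of the semantics only; the example itself relies on `equinox.internal.while_loop(..., kind="bounded", max_steps=...)`:

```python
def bounded_while_loop(cond_fun, body_fun, init_val, *, max_steps):
    """Mimic the semantics of a bounded while loop.

    Runs body_fun while cond_fun holds, but never more than max_steps
    times. A fixed iteration bound is what allows frameworks such as
    Equinox to checkpoint the loop and differentiate through it in
    reverse mode, which jax.lax.while_loop does not support.
    """
    val = init_val
    for _ in range(max_steps):
        if not cond_fun(val):
            return val
        val = body_fun(val)
    return val  # budget exhausted; return the last state
```

For example, counting up from 0 while the state stays below 10 terminates at 10 as long as `max_steps` is large enough; with a smaller budget, the loop stops early at the budget.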
36 changes: 21 additions & 15 deletions docs/examples_advanced/neural_ode.py
@@ -23,17 +23,17 @@
import matplotlib.pyplot as plt
import optax

from probdiffeq import ivpsolve, ivpsolvers, stats
from probdiffeq import ivpsolve, probdiffeq


def main(num_data=100, epochs=1_000, print_every=100, hidden=(20,), lr=0.2):
def main(num_data=100, epochs=500, print_every=50, hidden=(20,), lr=0.2):
"""Train a neural ODE using diffusion tempering."""
# Create some data and construct a neural ODE
grid = jnp.linspace(0, 1, num=num_data)
data = jnp.sin(2.5 * jnp.pi * grid) * jnp.pi * grid
stdev = 1e-1
output_scale = 1e2
vf, u0, (t0, t1), f_args = vf_neural_ode(hidden=hidden, t0=0.0, t1=1)
output_scale = 1e4
vf, u0, (t0, _t1), f_args = vf_neural_ode(hidden=hidden, t0=0.0, t1=1)

# Create a loss (this is where probabilistic numerics enters!)
loss = loss_log_marginal_likelihood(vf=vf, t0=t0)
Expand All @@ -44,7 +44,7 @@ def main(num_data=100, epochs=1_000, print_every=100, hidden=(20,), lr=0.2):
# Plot the data and the initial guess
plt.title(f"Initial estimate | Loss: {loss0:.2f}")
plt.plot(grid, data, "x", label="Data", color="C0")
plt.plot(grid, info0["sol"].u[0], "-", label="Estimate", color="C1")
plt.plot(grid, info0["sol"].u.mean[0], "-", label="Estimate", color="C1")
plt.legend()
plt.show()

@@ -79,8 +79,8 @@ def main(num_data=100, epochs=1_000, print_every=100, hidden=(20,), lr=0.2):
# Plot the results
plt.title(f"Final estimate | Loss: {info['loss']:.2f}")
plt.plot(grid, data, "x", label="Data", color="C0")
plt.plot(grid, info0["sol"].u[0], "-", label="Initial estimate", color="C1")
plt.plot(grid, info["sol"].u[0], "-", label="Final estimate", color="C2")
plt.plot(grid, info0["sol"].u.mean[0], "-", label="Initial estimate", color="C1")
plt.plot(grid, info["sol"].u.mean[0], "-", label="Final estimate", color="C2")
plt.legend()
plt.show()

@@ -149,22 +149,28 @@ def loss(
"""Loss function: log-marginal likelihood of the data."""
# Build a solver
tcoeffs = (*u0, vf(*u0, t=t0, p=p))
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
init, ibm, ssm = probdiffeq.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(lambda *a, **kw: vf(*a, **kw, p=p), ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
ts0 = probdiffeq.constraint_ode_ts0(ssm=ssm)
strategy = probdiffeq.strategy_smoother_fixedinterval(ssm=ssm)
solver_ts0 = probdiffeq.solver(
lambda *a, **kw: vf(*a, **kw, p=p),
strategy=strategy,
prior=ibm,
constraint=ts0,
ssm=ssm,
)

# Solve
sol = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver_ts0, ssm=ssm)
solve = ivpsolve.solve_fixed_grid(solver=solver_ts0)
sol = solve(init, grid=grid)

# Evaluate loss
marginal_likelihood = stats.log_marginal_likelihood(
marginal_likelihood = strategy.log_marginal_likelihood(
data[:, None],
standard_deviation=jnp.ones_like(grid) * stdev,
posterior=sol.posterior,
ssm=sol.ssm,
posterior=sol.solution_full,
)
return -1 * marginal_likelihood, {"sol": sol}
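For intuition about this loss, consider the simplified case of independent Gaussian observations with known predictive means. This is a plain-Python sketch and a deliberate simplification: the solver's `log_marginal_likelihood` used above additionally accounts for correlations between time points through filtering.

```python
import math


def gaussian_log_likelihood(data, mean, stdev):
    """Log-likelihood of data under independent Gaussian observations.

    Simplified stand-in for the filtering-based marginal likelihood:
    each observation d_i is modelled as N(mean_i, stdev**2), and the
    terms sum because the observations are assumed independent.
    """
    ll = 0.0
    for d, m in zip(data, mean):
        ll += -0.5 * math.log(2 * math.pi * stdev**2)
        ll += -0.5 * ((d - m) / stdev) ** 2
    return ll
```

Minimising the negative of this quantity (as the training loop above does with the full marginal likelihood) pulls the predictive means towards the data.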
