Refactor API: Separate correction vs strategy vs solver, error estimation vs time-stepping, static vs dynamic args, and more#846
Merged
pnkraemer merged 61 commits intoFeb 16, 2026
Conversation
…probdiffeq.ivpsolvers
… the only place where probabilistic numerics happens now (the other modules are helpers)
…cause it is shorter)
…number of backends
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
I've used
probdiffeqregularly in recent times, and parts of the API made my life surprisingly hard. While working with it, I refactored several components to make the library easier to extend by users (including myself). The result is (hopefully) a more readable and maintainable codebase.Warning: Big refactor. Description below is still a work in progress.
This PR addresses several long-standing API issues:
ivpsolve.solve_adaptive*madejit/vmap/differentiation unnecessarily tedious.All of these are resolved now 🎉 (Is v1.0.0 getting closer?)
Notable changes
Adaptive time-stepping is now independent of solvers. Instead of
solve(ivp, adaptive(solver, control, ...)), one now callssolve(ivp, solver, errorest, control, ...).This flattened hierarchy makes each component smaller and thus easier to maintain. Controllers and adaptive logic are now completely solver-agnostic and live in
ivpsolve.py.stats.pyremoved. Its functionality has been moved into the respective strategy implementations. Since they share SSM and solution/argument types, co-locating them simplifies the structure. Offgrid marginals are solver methods for the same reason.ivpsolvers.pyrenamed toprobdiffeq.py. After the refactor, it contains the core probabilistic solver code (and essentially no non-probnum logic except quadrature rules). The new name makes the core contribution clearer. (ivpsolve.pyandtaylor.pycould, in principle, be standalone libraries.)Renamed smoothers for clarity.
smootheris nowsmoother_fixedinterval.fixedpointis nowsmoother_fixedpoint. Their relationship is now more explicit.Error estimation decoupled from solvers. Error estimators are now constructed via
probdiffeq.errorest_*()and are only required for adaptive runs. Corrections no longer implement error estimation. Currently, two error estimators are provided.Corrections renamed. Since they no longer “correct” in the previous sense, they are now named
constraint_ode_ts0/ts1/....Builder pattern made explicit. Strategies and solvers already followed a builder-style design; this is now more visible (e.g., explicit
finalize()methods). The code is closer to textbook OOP — still fully compatible with JAX — and easier to maintain and explain.solve_and_save_every_step deprecated The non-jittable adaptive solver has been deprecated, which in this case it's been moved to probdiffeq.util.test_util. It remains a valuable tool for unittests, but I don't recommend using it instead of jittable solve_adaptive_save_at or other methods.
IVPSolution type removed: Now, the outputs of the ivpsolve routines depend on the solution type of the solver. since the probabilistic solvers operate on ProbabilisticSolution types (during the forward pass), these are returned. They are almost identical to the previous IVPSolution type.
Documentation improved: Massive overhaul of docstrings and typing (clearer structure is easier to type). As a part of this effort, all base classes (like
MarkovStrategy) have been made public to communicate the code structure more openlyCalibration built-in: Removed
stats.calibrate; calibration now happens in thefinalize()methods of strategies.ivpsolve-methods refactored: The solution routines in probdiffeq.ivpsolve.py now cleanly separate static (strings, classes, callable, integers, etc.) from dynamic arguments (pytrees of arrays): instead of solve(dynamic1, dynamic2, static1, static2), the solution routines implement solve = solve_factory(static1, static2); solution = solve(dynamic1, dynamic2), which makes the code much easier to jit/vmap/differentiate.
While loops more accessible: While-loop selection (eg equinox's bounded while loop) is now governed by an argument to the solve-routines. No more context managers.
Controller types public: The controller types have been made public to clarify the API
Backend modules renamed Some of the more verbose backend module names have been shortened to streamling the src a bit
What gives?
According to the benchmarks, the code is a little bit slower than before, presumably because some of the "user-friendliness" is now built in (e.g., computing marginals of the smoothing solution, or always returning standard deviations at each step during the forward pass instead of computing them once afterwards). However, the code still has the same "big-Oh" complexity, and the wall-time differences shouldn't really be noticeable outside of dedicated benchmarks.
To make up for that: