Skip to content

Commit 0b46bad

Browse files
authored
Improve some documentation and make minor usability adjustments (#816)
* Add an error message for non-vector-valued initial conditions * Improve the visuals of the ODE in the quickstart example * Update the ruff-pre-commit hook * Treat warnings as errors (and ditch Diffrax from tests because it raises warnings) * Allow passing strategies to solvers as keyword arguments
1 parent 635669f commit 0b46bad

10 files changed

Lines changed: 45 additions & 74 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ repos:
1313
- id: yamlfix
1414
- repo: https://github.com/astral-sh/ruff-pre-commit
1515
# Ruff version.
16-
rev: v0.12.10
16+
rev: v0.12.11
1717
hooks:
1818
# Run the linter.
19-
- id: ruff
19+
- id: ruff-check
2020
args: [--fix]
2121
# Run the formatter.
2222
- id: ruff-format

docs/examples_quickstart/easy_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@
3636
@jax.jit
3737
def vf(y, *, t): # noqa: ARG001
3838
"""Evaluate the vector field."""
39-
return 0.5 * y * (1 - y)
39+
return 2.0 * y * (1 - y)
4040

4141

4242
u0 = jnp.asarray([0.1])
43-
t0, t1 = 0.0, 1.0
43+
t0, t1 = 0.0, 5.0
4444
# -
4545

4646
# Configuring a probabilistic IVP solver is a little more

makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ format-and-lint:
22
pre-commit run --all-files
33

44
test:
5-
pytest -n auto -v # parallelise, verbose output
5+
pytest -n auto -v -Werror # parallelise, verbose output, warnings as errors
66

77
quickstart:
88
# Run some code without installing any of the optional dependencies

probdiffeq/backend/ode.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import jax
44
import jax.experimental.ode
5-
import jax.numpy as jnp
65

76

87
def odeint_and_save_at(vf, y0: tuple, /, save_at, *, atol, rtol):
@@ -15,40 +14,6 @@ def vf_wrapped(y, t):
1514
return jax.experimental.ode.odeint(vf_wrapped, *y0, save_at, atol=atol, rtol=rtol)
1615

1716

18-
def odeint_dense(vf, y0: tuple, /, t0, t1, *, atol, rtol):
19-
# Local import because diffrax is not an official dependency
20-
import diffrax
21-
22-
assert isinstance(y0, tuple | list)
23-
assert len(y0) == 1
24-
25-
@diffrax.ODETerm
26-
@jax.jit
27-
def vf_wrapped(t, y, _args):
28-
return vf(y, t=t)
29-
30-
solution_object = diffrax.diffeqsolve(
31-
vf_wrapped,
32-
diffrax.Dopri5(),
33-
t0=t0,
34-
t1=t1,
35-
dt0=0.1,
36-
y0=y0[0],
37-
saveat=diffrax.SaveAt(dense=True),
38-
stepsize_controller=diffrax.PIDController(atol=atol, rtol=rtol),
39-
)
40-
41-
def solution(t):
42-
# Automatic batching
43-
if jnp.ndim(t) > 0:
44-
return jax.vmap(solution)(t)
45-
46-
# Interpolate
47-
return solution_object.evaluate(t)
48-
49-
return solution
50-
51-
5217
def ivp_lotka_volterra():
5318
# Local imports because diffeqzoo is not an official dependency
5419
from diffeqzoo import backend, ivps

probdiffeq/impl/impl.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""State-space model implementations."""
22

33
from probdiffeq.backend import containers, functools, tree_util
4+
from probdiffeq.backend import numpy as np
45
from probdiffeq.impl import _conditional, _linearise, _normal, _prototypes, _stats
56

67

@@ -26,6 +27,12 @@ def __eq__(self, other):
2627

2728
def choose(which: str, /, *, tcoeffs_like) -> FactImpl:
2829
"""Choose a state-space model implementation."""
30+
u0 = np.asarray(tcoeffs_like[0])
31+
if u0.ndim != 1:
32+
msg = "'tcoeffs' expected to have shape=(d,), "
33+
msg += f"but shape={u0.shape} received."
34+
raise ValueError(msg)
35+
2936
if which == "dense":
3037
return _select_dense(tcoeffs_like=tcoeffs_like)
3138
if which == "isotropic":

probdiffeq/ivpsolvers.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def initial_condition(self):
766766
return posterior, self.prior.output_scale
767767

768768

769-
def solver_mle(extrapolation, /, *, correction, prior, ssm):
769+
def solver_mle(strategy, *, correction, prior, ssm):
770770
"""Create a solver that calibrates the output scale via maximum-likelihood.
771771
772772
Warning: needs to be combined with a call to stats.calibrate()
@@ -777,15 +777,15 @@ def step_mle(state, /, *, dt, vector_field, calibration):
777777
output_scale_prior, _calibrated = calibration.extract(state.output_scale)
778778

779779
prior_discretized = prior.discretize(dt)
780-
hidden, extra = extrapolation.begin(
780+
hidden, extra = strategy.begin(
781781
state.hidden, state.aux_extra, prior_discretized=prior_discretized
782782
)
783783
t = state.t + dt
784784
error, _, corr = correction.estimate_error(
785785
hidden, vector_field=vector_field, t=t
786786
)
787787

788-
hidden, extra = extrapolation.complete(
788+
hidden, extra = strategy.complete(
789789
hidden, extra, output_scale=output_scale_prior
790790
)
791791
hidden, observed = correction.complete(hidden, corr)
@@ -800,7 +800,7 @@ def step_mle(state, /, *, dt, vector_field, calibration):
800800
prior=prior,
801801
calibration=_calibration_running_mean(ssm=ssm),
802802
step_implementation=step_mle,
803-
extrapolation=extrapolation,
803+
extrapolation=strategy,
804804
correction=correction,
805805
requires_rescaling=True,
806806
)
@@ -829,12 +829,12 @@ def extract(state, /):
829829
return _Calibration(init=init, update=update, extract=extract)
830830

831831

832-
def solver_dynamic(extrapolation, *, correction, prior, ssm):
832+
def solver_dynamic(strategy, *, correction, prior, ssm):
833833
"""Create a solver that calibrates the output scale dynamically."""
834834

835835
def step_dynamic(state, /, *, dt, vector_field, calibration):
836836
prior_discretized = prior.discretize(dt)
837-
hidden, extra = extrapolation.begin(
837+
hidden, extra = strategy.begin(
838838
state.hidden, state.aux_extra, prior_discretized=prior_discretized
839839
)
840840
t = state.t + dt
@@ -845,7 +845,7 @@ def step_dynamic(state, /, *, dt, vector_field, calibration):
845845
output_scale = calibration.update(state.output_scale, observed=observed)
846846

847847
prior_, _calibrated = calibration.extract(output_scale)
848-
hidden, extra = extrapolation.complete(hidden, extra, output_scale=prior_)
848+
hidden, extra = strategy.complete(hidden, extra, output_scale=prior_)
849849
hidden, corr = correction.complete(hidden, corr)
850850

851851
# Return solution
@@ -855,7 +855,7 @@ def step_dynamic(state, /, *, dt, vector_field, calibration):
855855
return _ProbabilisticSolver(
856856
prior=prior,
857857
ssm=ssm,
858-
extrapolation=extrapolation,
858+
extrapolation=strategy,
859859
correction=correction,
860860
calibration=_calibration_most_recent(ssm=ssm),
861861
name="Dynamic probabilistic solver",
@@ -877,22 +877,22 @@ def extract(state, /):
877877
return _Calibration(init=init, update=update, extract=extract)
878878

879879

880-
def solver(extrapolation, /, *, correction, prior, ssm):
880+
def solver(strategy, *, correction, prior, ssm):
881881
"""Create a solver that does not calibrate the output scale automatically."""
882882

883883
def step(state: _State, *, vector_field, dt, calibration):
884884
del calibration # unused
885885

886886
prior_discretized = prior.discretize(dt)
887-
hidden, extra = extrapolation.begin(
887+
hidden, extra = strategy.begin(
888888
state.hidden, state.aux_extra, prior_discretized=prior_discretized
889889
)
890890
t = state.t + dt
891891
error, _, corr = correction.estimate_error(
892892
hidden, vector_field=vector_field, t=t
893893
)
894894

895-
hidden, extra = extrapolation.complete(
895+
hidden, extra = strategy.complete(
896896
hidden, extra, output_scale=state.output_scale
897897
)
898898
hidden, corr = correction.complete(hidden, corr)
@@ -906,7 +906,7 @@ def step(state: _State, *, vector_field, dt, calibration):
906906
return _ProbabilisticSolver(
907907
ssm=ssm,
908908
prior=prior,
909-
extrapolation=extrapolation,
909+
extrapolation=strategy,
910910
correction=correction,
911911
calibration=_calibration_none(),
912912
step_implementation=step,

tests/test_ivpsolve/test_save_every_step.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,4 @@ def python_loop_solution(ivp, *, fact, strategy_fun):
5959

6060
def reference_solution(ivp, ts):
6161
vf, u0, (t0, t1) = ivp
62-
sol = ode.odeint_dense(vf, u0, t0=t0, t1=t1, atol=1e-10, rtol=1e-10)
63-
return sol(ts)
62+
return ode.odeint_and_save_at(vf, u0, save_at=ts, atol=1e-10, rtol=1e-10)

tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@ def test_exponential_approximated_well(fact):
2626
solver_kwargs = {"grid": grid, "solver": solver, "ssm": ssm}
2727
approximation = ivpsolve.solve_fixed_grid(*problem_args, **solver_kwargs)
2828

29-
solution = ode.odeint_dense(vf, u0, t0=t0, t1=t1, atol=1e-5, rtol=1e-5)
30-
rmse = _rmse(approximation.u[0][-1], solution(t1))
29+
solution = ode.odeint_and_save_at(
30+
vf, u0, save_at=np.asarray([t0, t1]), atol=1e-5, rtol=1e-5
31+
)
32+
rmse = _rmse(approximation.u[0][-1], solution[-1])
3133
assert rmse < 0.1
3234

3335

tests/test_ivpsolvers/test_corrections.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,15 @@ def fixture_solution(correction_impl, fact):
5858
)
5959

6060

61-
@testing.fixture(name="reference_solution")
62-
def fixture_reference_solution():
63-
vf, (u0,), (t0, t1) = ode.ivp_lotka_volterra()
64-
return ode.odeint_dense(vf, (u0,), t0=t0, t1=t1, atol=1e-10, rtol=1e-10)
65-
66-
67-
def test_terminal_value_simulation_matches_reference(solution, reference_solution):
61+
def test_terminal_value_simulation_matches_reference(solution):
6862
expected = reference_solution(solution.t)
6963
received = solution.u[0]
70-
7164
assert np.allclose(received, expected, rtol=1e-2)
65+
66+
67+
@functools.jit
68+
def reference_solution(t1):
69+
vf, (u0,), (t0, t1) = ode.ivp_lotka_volterra()
70+
ts = np.asarray([t0, t1])
71+
sol = ode.odeint_and_save_at(vf, (u0,), save_at=ts, atol=1e-10, rtol=1e-10)
72+
return sol[-1]

tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,10 @@ def fixture_smoother_solution(solver_setup):
4343
)
4444

4545

46-
@testing.fixture(name="reference_solution")
47-
def fixture_reference_solution():
48-
vf, (u0,), (t0, t1) = ode.ivp_lotka_volterra()
49-
return ode.odeint_dense(vf, (u0,), t0=t0, t1=t1, atol=1e-10, rtol=1e-10)
50-
51-
52-
def test_compare_filter_smoother_rmse(
53-
filter_solution, smoother_solution, reference_solution
54-
):
46+
def test_compare_filter_smoother_rmse(filter_solution, smoother_solution):
5547
assert np.allclose(filter_solution.t, smoother_solution.t) # sanity check
5648

57-
reference = reference_solution(filter_solution.t)
49+
reference = _reference_solution(filter_solution.t)
5850
filter_rmse = _rmse(filter_solution.u[0], reference)
5951
smoother_rmse = _rmse(smoother_solution.u[0], reference)
6052

@@ -66,5 +58,10 @@ def test_compare_filter_smoother_rmse(
6658
assert filter_rmse < 0.01
6759

6860

61+
def _reference_solution(ts):
62+
vf, (u0,), (t0, t1) = ode.ivp_lotka_volterra()
63+
return ode.odeint_and_save_at(vf, (u0,), save_at=ts, atol=1e-10, rtol=1e-10)
64+
65+
6966
def _rmse(a, b):
7067
return linalg.vector_norm((a - b) / b) / np.sqrt(b.size)

0 commit comments

Comments
 (0)