Skip to content

Commit 87f4178

Browse files
authored
Implement support for pytree-valued differential equations (#833)
* Taylor series estimation supports arbitrary pytrees now * Run Taylor tests for all state-space model factorisations * Start making the solvers pytree-compatible * I need a break... * Only left to update the stats (and the other factorisations) * Update stats * Fix terminal-value-LML tests * Update LML code * Update remaining stats modules * Remove print statements from the src * Fix the last dense test * Update isotropic implementations * All tests pass for isotropic & dense implementations * Update blockdiagonal implementation * Rerun benchmarks * Change taylor-coeff notebook to logistic ODE to demonstrate scalar-valued problems * Delete outdated comments * Clean up comments * Delete dead code * Improve formatting
1 parent edca937 commit 87f4178

39 files changed

Lines changed: 383 additions & 330 deletions

docs/benchmarks/hires/results.npy

0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

docs/examples_basic/taylor_coefficients.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
# We start by defining an ODE.
4141

4242
# +
43-
f, u0, (t0, t1), f_args = ivps.rigid_body()
43+
f, u0, (t0, t1), f_args = ivps.logistic()
4444

4545

4646
def vf(*y, t): # noqa: ARG001

makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ format-and-lint:
22
pre-commit run --all-files
33

44
test:
5-
pytest -n auto -v -Werror # parallelise, verbose output, warnings as errors
5+
pytest -n auto -v -Werror # parallelise, verbose output, warnings as errors
66

77
quickstart:
88
# Run some code without installing any of the optional dependencies

probdiffeq/backend/numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,11 @@ def eye(n, m=None, /):
133133

134134

135135
def save(path, arr, /):
136-
return jnp.save(path, arr)
136+
return jnp.save(path, arr, allow_pickle=True)
137137

138138

139139
def load(path, /):
140-
return jnp.load(path)
140+
return jnp.load(path, allow_pickle=True)
141141

142142

143143
def allclose(a, b, *, atol=1e-8, rtol=1e-5):

0 commit comments

Comments
 (0)