Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install pandoc
pip install --upgrade pip
pip install .[cpu,format-and-lint]
- name: Apply linter
Expand Down Expand Up @@ -111,7 +110,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install pandoc
pip install --upgrade pip
pip install .[cpu,doc]
- name: Build the HTML docs
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/doc-publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install pandoc
pip install --upgrade pip
pip install .[cpu,doc]
- name: Build the HTML docs
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
- id: end-of-file-fixer
- id: check-merge-conflict
- repo: https://github.com/lyz-code/yamlfix/
rev: 1.17.0
rev: 1.18.0
hooks:
- id: yamlfix
- repo: https://github.com/astral-sh/ruff-pre-commit
Expand Down
5 changes: 1 addition & 4 deletions docs/benchmarks/hires/run_hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def param_to_solution(tol):
# Build a solver
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
Expand All @@ -98,9 +98,6 @@ def param_to_solution(tol):
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True
)

# Initial state
init = solver.initial_condition()

# Solve
dt0 = ivpsolve.dt0(vf_auto, (u0,))
solution = ivpsolve.solve_adaptive_terminal_values(
Expand Down
7 changes: 3 additions & 4 deletions docs/benchmarks/lotkavolterra/run_lotkavolterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def param_to_solution(tol):
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)

ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=implementation)
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, ssm_fact=implementation
)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
corr = correction(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm)
Expand All @@ -89,9 +91,6 @@ def param_to_solution(tol):
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm
)

# Initial state
init = solver.initial_condition()

# Solve
dt0 = ivpsolve.dt0(vf_auto, (u0,))
solution = ivpsolve.solve_adaptive_terminal_values(
Expand Down
7 changes: 3 additions & 4 deletions docs/benchmarks/pleiades/run_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def param_to_solution(tol):
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1)

ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, ssm_fact="isotropic"
)
ts0_or_ts1 = correction_fun(ssm=ssm, ode_order=2)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver_dynamic(
Expand All @@ -110,9 +112,6 @@ def param_to_solution(tol):
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm
)

# Initial state
init = solver.initial_condition()

# Solve
dt0 = ivpsolve.dt0(vf_auto, (u0, du0))
solution = ivpsolve.solve_adaptive_terminal_values(
Expand Down
5 changes: 1 addition & 4 deletions docs/benchmarks/vanderpol/run_vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def param_to_solution(tol):
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1)

ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
ts0_or_ts1 = ivpsolvers.correction_ts1(ode_order=2, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)

Expand All @@ -93,9 +93,6 @@ def param_to_solution(tol):
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True
)

# Initial state
init = solver.initial_condition()

# Solve
dt0 = ivpsolve.dt0(vf_auto, (u0, du0))
solution = ivpsolve.solve_adaptive_terminal_values(
Expand Down
3 changes: 1 addition & 2 deletions docs/examples_advanced/equinox_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,12 @@ def vf(y, *, t): # noqa: ARG001
u0 = jnp.asarray([0.1])

tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm)

strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
init = solver.initial_condition()

def simulate(init_val):
"""Evaluate the parameter-to-solution function."""
Expand Down
3 changes: 1 addition & 2 deletions docs/examples_advanced/neural_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,14 @@ def loss(
"""Loss function: log-marginal likelihood of the data."""
# Build a solver
tcoeffs = (*u0, vf(*u0, t=t0, p=p))
ibm, ssm = ivpsolvers.prior_ibm(
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)

# Solve
init = solver_ts0.initial_condition()
sol = ivpsolve.solve_fixed_grid(
lambda *a, **kw: vf(*a, **kw, p=p),
init,
Expand Down
7 changes: 2 additions & 5 deletions docs/examples_advanced/parameter_estimation_blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,12 @@ def solve_fixed(theta, *, ts):
# Create a probabilistic solver
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2)
output_scale = 10.0
ibm, ssm = ivpsolvers.prior_ibm(
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
init = solver.initial_condition()
return ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm)


Expand All @@ -199,15 +198,13 @@ def solve_adaptive(theta, *, save_at):
# Create a probabilistic solver
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2)
output_scale = 10.0
ibm, ssm = ivpsolvers.prior_ibm(
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)

init = solver.initial_condition()
return ivpsolve.solve_adaptive_save_at(
vf, init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
)
Expand Down
4 changes: 1 addition & 3 deletions docs/examples_advanced/parameter_estimation_optax.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,12 @@ def solve(p):
"""Evaluate the parameter-to-solution map."""
tcoeffs = (u0, vf(u0, t0, p=p))
output_scale = 10.0
ibm, ssm = ivpsolvers.prior_ibm(
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)

init = solver.initial_condition()
return ivpsolve.solve_fixed_grid(
lambda y, t: vf(y, t, p=p), init, grid=ts, solver=solver, ssm=ssm
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def vector_field(y, t): # noqa: ARG001
NUM_DERIVATIVES = 2
tcoeffs_like = [u0] * (NUM_DERIVATIVES + 1)
ts = jnp.linspace(t0, t1, num=500, endpoint=True)
(init_raw, transitions), ssm = ivpsolvers.prior_ibm_discrete(
init_raw, transitions, ssm = ivpsolvers.prior_wiener_integrated_discrete(
ts, tcoeffs_like=tcoeffs_like, output_scale=100.0, ssm_fact="dense"
)

Expand All @@ -71,19 +71,18 @@ def vector_field(y, t): # noqa: ARG001
# +
# Compute the posterior

ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="dense"
)
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts1, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-2, ssm=ssm)

dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))

init = solver.initial_condition()
sol = ivpsolve.solve_adaptive_save_at(
vector_field, init, save_at=ts, dt0=1.0, adaptive_solver=adaptive_solver, ssm=ssm
)
# posterior = stats.calibrate(sol.posterior, sol.output_scale)
markov_seq_posterior = stats.markov_select_terminal(sol.posterior)

# +
Expand Down
12 changes: 5 additions & 7 deletions docs/examples_basic/dynamic_output_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def vf(*ys, t): # noqa: ARG001
num_derivatives = 1

tcoeffs = (u0, vf(u0, t=t0))
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="dense"
)
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
dynamic = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
Expand All @@ -77,12 +79,8 @@ def vf(*ys, t): # noqa: ARG001
ts = jnp.linspace(t0, t1, num=num_pts, endpoint=True)


init_mle = mle.initial_condition()
init_dynamic = dynamic.initial_condition()
solution_dynamic = ivpsolve.solve_fixed_grid(
vf, init_mle, grid=ts, solver=dynamic, ssm=ssm
)
solution_mle = ivpsolve.solve_fixed_grid(vf, init_dynamic, grid=ts, solver=mle, ssm=ssm)
solution_dynamic = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=dynamic, ssm=ssm)
solution_mle = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=mle, ssm=ssm)
# -

# Plot the solution.
Expand Down
3 changes: 1 addition & 2 deletions docs/examples_basic/posterior_uncertainties.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,14 @@ def vf(y, *, t): # noqa: ARG001
# Set up a solver
# To all users: Try replacing the fixedpoint-smoother with a filter!
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
ts = ivpsolvers.correction_ts1(ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts, ssm=ssm)
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-1, ssm=ssm)

# Solve the ODE
ts = jnp.linspace(t0, t1, endpoint=True, num=50)
init = solver.initial_condition()
sol = ivpsolve.solve_adaptive_save_at(
vf, init, save_at=ts, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
)
Expand Down
11 changes: 6 additions & 5 deletions docs/examples_basic/second_order_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def vf_1(y, t): # noqa: ARG001


tcoeffs = taylor.odejet_padded_scan(lambda y: vf_1(y, t=t0), (u0,), num=4)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver_1st = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
Expand All @@ -53,7 +55,6 @@ def vf_1(y, t): # noqa: ARG001

# -

init = solver_1st.initial_condition()
solution = ivpsolve.solve_adaptive_save_every_step(
vf_1, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver_1st, ssm=ssm
)
Expand All @@ -78,14 +79,14 @@ def vf_2(y, dy, t): # noqa: ARG001

# One derivative more than above because we don't transform to first order
tcoeffs = taylor.odejet_padded_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
tcoeffs, output_scale=1.0, ssm_fact="isotropic"
)
ts0 = ivpsolvers.correction_ts0(ode_order=2, ssm=ssm)
strategy = ivpsolvers.strategy_filter(ssm=ssm)
solver_2nd = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm)
adaptive_solver_2nd = ivpsolvers.adaptive(solver_2nd, atol=1e-5, rtol=1e-5, ssm=ssm)


init = solver_2nd.initial_condition()
# -

solution = ivpsolve.solve_adaptive_save_every_step(
Expand Down
4 changes: 1 addition & 3 deletions docs/examples_basic/taylor_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,10 @@ def vf(*y, t): # noqa: ARG001
# +
def solve(tc):
"""Solve the ODE."""
prior, ssm = ivpsolvers.prior_ibm(tc, ssm_fact="dense")
init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense")
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=ts0, ssm=ssm)
init = solver.initial_condition()

ts = jnp.linspace(t0, t1, endpoint=True, num=10)
adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
return ivpsolve.solve_adaptive_save_at(
Expand Down
3 changes: 1 addition & 2 deletions docs/examples_quickstart/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def vf(y, *, t): # noqa: ARG001

# Set up a state-space model
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")


# Build a solver
Expand All @@ -51,7 +51,6 @@ def vf(y, *, t): # noqa: ARG001

# Solve the ODE
# To all users: Try different solution routines.
init = solver.initial_condition()
solution = ivpsolve.solve_adaptive_save_every_step(
vf, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
)
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ nav:
- Choosing a solver: choosing_a_solver.md
- Troubleshooting: troubleshooting.md
- EXAMPLES | BASIC:
- examples_basic/conditioning-on-zero-residual.ipynb
- examples_basic/conditioning_on_zero_residual.ipynb
- examples_basic/posterior_uncertainties.ipynb
- examples_basic/dynamic_output_scales.ipynb
- examples_basic/second_order_problems.ipynb
Expand Down
7 changes: 0 additions & 7 deletions probdiffeq/backend/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ def qr_r_jvp(primals, tangents):
# All Cholesky factors are lower-triangular by default


def cholesky_factor(arr, /):
return jnp.linalg.cholesky(arr)


# All Cholesky factors are lower-triangular by default


def cholesky_solve(arr, rhs, /):
return jax.scipy.linalg.cho_solve((arr, True), rhs)

Expand Down
4 changes: 0 additions & 4 deletions probdiffeq/backend/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@ def squeeze(arr, /):
return jnp.squeeze(arr)


def squeeze_along_axis(arr, /, *, axis):
return jnp.squeeze(arr, axis=axis)


def atleast_1d(arr, /):
return jnp.atleast_1d(arr)

Expand Down
2 changes: 1 addition & 1 deletion probdiffeq/backend/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Typing module."""

from collections.abc import Callable, Sequence
from typing import Any, Generic, Optional, TypeAlias, TypeVar
from typing import Any, Generic, TypeAlias, TypeVar

import jax
from mypy_extensions import NamedArg
Expand Down
Loading