Skip to content

Commit 373fe21

Browse files
authored
Rename prior_ibm to prior_wiener_integrated and simplify the initial state (#827)
* Proof of concept: new API for integrated Wiener process * Update most of the src * Fix remaining tests * Update examples * Update examples and benchmarks * Remove unused code * Update pre-commit hook * Rename variables to be more similar to previous versions * Remove Pandoc installation from linter * Remove pandoc installations
1 parent bfc8b93 commit 373fe21

40 files changed

Lines changed: 149 additions & 256 deletions

.github/workflows/ci.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ jobs:
2626
python-version: ${{ matrix.python-version }}
2727
- name: Install dependencies
2828
run: |
29-
sudo apt-get install pandoc
3029
pip install --upgrade pip
3130
pip install .[cpu,format-and-lint]
3231
- name: Apply linter
@@ -111,7 +110,6 @@ jobs:
111110
python-version: ${{ matrix.python-version }}
112111
- name: Install dependencies
113112
run: |
114-
sudo apt-get install pandoc
115113
pip install --upgrade pip
116114
pip install .[cpu,doc]
117115
- name: Build the HTML docs

.github/workflows/doc-publish.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ jobs:
2626
python-version: ${{ matrix.python-version }}
2727
- name: Install dependencies
2828
run: |
29-
sudo apt-get install pandoc
3029
pip install --upgrade pip
3130
pip install .[cpu,doc]
3231
- name: Build the HTML docs

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ repos:
88
- id: end-of-file-fixer
99
- id: check-merge-conflict
1010
- repo: https://github.com/lyz-code/yamlfix/
11-
rev: 1.17.0
11+
rev: 1.18.0
1212
hooks:
1313
- id: yamlfix
1414
- repo: https://github.com/astral-sh/ruff-pre-commit

docs/benchmarks/hires/run_hires.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def param_to_solution(tol):
8989
# Build a solver
9090
vf_auto = functools.partial(vf_probdiffeq, t=t0)
9191
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)
92-
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
92+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
9393
ts1 = ivpsolvers.correction_ts1(ssm=ssm)
9494
strategy = ivpsolvers.strategy_filter(ssm=ssm)
9595
solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm)
@@ -98,9 +98,6 @@ def param_to_solution(tol):
9898
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True
9999
)
100100

101-
# Initial state
102-
init = solver.initial_condition()
103-
104101
# Solve
105102
dt0 = ivpsolve.dt0(vf_auto, (u0,))
106103
solution = ivpsolve.solve_adaptive_terminal_values(

docs/benchmarks/lotkavolterra/run_lotkavolterra.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def param_to_solution(tol):
8080
vf_auto = functools.partial(vf_probdiffeq, t=t0)
8181
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives)
8282

83-
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=implementation)
83+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
84+
tcoeffs, ssm_fact=implementation
85+
)
8486
strategy = ivpsolvers.strategy_filter(ssm=ssm)
8587
corr = correction(ssm=ssm)
8688
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm)
@@ -89,9 +91,6 @@ def param_to_solution(tol):
8991
solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm
9092
)
9193

92-
# Initial state
93-
init = solver.initial_condition()
94-
9594
# Solve
9695
dt0 = ivpsolve.dt0(vf_auto, (u0,))
9796
solution = ivpsolve.solve_adaptive_terminal_values(

docs/benchmarks/pleiades/run_pleiades.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def param_to_solution(tol):
9999
vf_auto = functools.partial(vf_probdiffeq, t=t0)
100100
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1)
101101

102-
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic")
102+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
103+
tcoeffs, ssm_fact="isotropic"
104+
)
103105
ts0_or_ts1 = correction_fun(ssm=ssm, ode_order=2)
104106
strategy = ivpsolvers.strategy_filter(ssm=ssm)
105107
solver = ivpsolvers.solver_dynamic(
@@ -110,9 +112,6 @@ def param_to_solution(tol):
110112
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm
111113
)
112114

113-
# Initial state
114-
init = solver.initial_condition()
115-
116115
# Solve
117116
dt0 = ivpsolve.dt0(vf_auto, (u0, du0))
118117
solution = ivpsolve.solve_adaptive_terminal_values(

docs/benchmarks/vanderpol/run_vanderpol.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def param_to_solution(tol):
8181
vf_auto = functools.partial(vf_probdiffeq, t=t0)
8282
tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0, du0), num=num_derivatives - 1)
8383

84-
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense")
84+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")
8585
ts0_or_ts1 = ivpsolvers.correction_ts1(ode_order=2, ssm=ssm)
8686
strategy = ivpsolvers.strategy_filter(ssm=ssm)
8787

@@ -93,9 +93,6 @@ def param_to_solution(tol):
9393
solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm, clip_dt=True
9494
)
9595

96-
# Initial state
97-
init = solver.initial_condition()
98-
9996
# Solve
10097
dt0 = ivpsolve.dt0(vf_auto, (u0, du0))
10198
solution = ivpsolve.solve_adaptive_terminal_values(

docs/examples_advanced/equinox_while_loop.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,12 @@ def vf(y, *, t): # noqa: ARG001
6161
u0 = jnp.asarray([0.1])
6262

6363
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
64-
ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic")
64+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="isotropic")
6565
ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm)
6666

6767
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
6868
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
6969
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
70-
init = solver.initial_condition()
7170

7271
def simulate(init_val):
7372
"""Evaluate the parameter-to-solution function."""

docs/examples_advanced/neural_ode.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,14 @@ def loss(
149149
"""Loss function: log-marginal likelihood of the data."""
150150
# Build a solver
151151
tcoeffs = (*u0, vf(*u0, t=t0, p=p))
152-
ibm, ssm = ivpsolvers.prior_ibm(
152+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
153153
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
154154
)
155155
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
156156
strategy = ivpsolvers.strategy_smoother(ssm=ssm)
157157
solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
158158

159159
# Solve
160-
init = solver_ts0.initial_condition()
161160
sol = ivpsolve.solve_fixed_grid(
162161
lambda *a, **kw: vf(*a, **kw, p=p),
163162
init,

docs/examples_advanced/parameter_estimation_blackjax.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,12 @@ def solve_fixed(theta, *, ts):
183183
# Create a probabilistic solver
184184
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2)
185185
output_scale = 10.0
186-
ibm, ssm = ivpsolvers.prior_ibm(
186+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
187187
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
188188
)
189189
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
190190
strategy = ivpsolvers.strategy_filter(ssm=ssm)
191191
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
192-
init = solver.initial_condition()
193192
return ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm)
194193

195194

@@ -199,15 +198,13 @@ def solve_adaptive(theta, *, save_at):
199198
# Create a probabilistic solver
200199
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (theta,), num=2)
201200
output_scale = 10.0
202-
ibm, ssm = ivpsolvers.prior_ibm(
201+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
203202
tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
204203
)
205204
ts0 = ivpsolvers.correction_ts0(ssm=ssm)
206205
strategy = ivpsolvers.strategy_filter(ssm=ssm)
207206
solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
208207
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
209-
210-
init = solver.initial_condition()
211208
return ivpsolve.solve_adaptive_save_at(
212209
vf, init, save_at=save_at, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
213210
)

0 commit comments

Comments
 (0)