Skip to content

Commit 326401e

Browse files
committed
Undo some solution typing changes
1 parent 6be9096 commit 326401e

8 files changed

Lines changed: 44 additions & 48 deletions

File tree

docs/examples_advanced/neural_ode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def loss(
170170
marginal_likelihood = strategy.log_marginal_likelihood(
171171
data[:, None],
172172
standard_deviation=jnp.ones_like(grid) * stdev,
173-
posterior=sol.full_solution,
173+
posterior=sol.solution_full,
174174
)
175175
return -1 * marginal_likelihood, {"sol": sol}
176176

docs/examples_basic/conditioning_on_zero_residual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def vector_field(y, t): # noqa: ARG001
8080
dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,))
8181
solve = ivpsolve.solve_adaptive_save_at(solver=solver, errorest=errorest)
8282
sol = solve(init, save_at=ts, dt0=dt0, atol=1e-1, rtol=1e-1)
83-
markov_seq_posterior = sol.full_solution
83+
markov_seq_posterior = sol.solution_full
8484

8585

8686
# +

probdiffeq/ivpsolve.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,9 @@ class Solution(Protocol, Generic[S]):
3838
num_steps: int
3939
"""The number of steps taken by the solver."""
4040

41-
full_solution: Any
41+
solution_full: Any
4242
"""A full description of the solution (beyond 'u', e.g. for dense outputs)."""
4343

44-
hyperparams: Any
45-
"""A description of (calibrated) hyperparameters."""
46-
4744

4845
# Revisit this dependent typing one Python >=3.12 is enforced
4946
# Concretely, Something like Solver[T, S: Solution[T]](Protocol):...

probdiffeq/probdiffeq.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -513,10 +513,10 @@ class ProbabilisticSolution(Generic[C, T]):
513513
u: TaylorCoeffTarget[C, T]
514514
"""The current ODE solution estimate."""
515515

516-
full_solution: T | MarkovSequence[T]
516+
solution_full: T | MarkovSequence[T]
517517
"""The current posterior estimate."""
518518

519-
# Todo: merge 'output_scale' and 'auxiliary'?
519+
# Todo: merge 'output_scale' and 'auxiliary' and "fun_evals"?
520520
output_scale: Any
521521
"""The current output scale."""
522522

@@ -634,7 +634,7 @@ def offgrid_marginals(self, t, *, solution):
634634
def _extract_previous(pytree):
635635
return tree.tree_map(lambda s: s[index - 1, ...], pytree)
636636

637-
posterior_t0 = _extract_previous(solution.full_solution)
637+
posterior_t0 = _extract_previous(solution.solution_full)
638638
t0 = _extract_previous(solution.t)
639639

640640
# Extract the RHS
@@ -674,8 +674,8 @@ def interpolate(
674674

675675
# Interpolate
676676
tmp = self.strategy.interpolate(
677-
posterior_t0=interp_from.full_solution,
678-
posterior_t1=interp_to.full_solution,
677+
posterior_t0=interp_from.solution_full,
678+
posterior_t1=interp_to.solution_full,
679679
transition_t0_t=transition_t0_t,
680680
transition_t_t1=transition_t_t1,
681681
)
@@ -684,7 +684,7 @@ def interpolate(
684684
step_from = ProbabilisticSolution(
685685
t=interp_to.t,
686686
# New:
687-
full_solution=step_and_interpolate_from.step_from,
687+
solution_full=step_and_interpolate_from.step_from,
688688
# Old:
689689
u=interp_to.u,
690690
output_scale=interp_to.output_scale,
@@ -696,7 +696,7 @@ def interpolate(
696696
interpolated = ProbabilisticSolution(
697697
t=t,
698698
# New:
699-
full_solution=interpolated,
699+
solution_full=interpolated,
700700
u=estimate,
701701
# Taken from the rhs point
702702
output_scale=interp_to.output_scale,
@@ -708,7 +708,7 @@ def interpolate(
708708
interp_from = ProbabilisticSolution(
709709
t=t,
710710
# New:
711-
full_solution=step_and_interpolate_from.interp_from,
711+
solution_full=step_and_interpolate_from.interp_from,
712712
# Old:
713713
u=interp_from.u,
714714
output_scale=interp_from.output_scale,
@@ -725,13 +725,13 @@ def interpolate_at_t1(
725725
):
726726
"""Interpolate the solution near a checkpoint."""
727727
del t
728-
tmp = self.strategy.interpolate_at_t1(posterior_t1=interp_to.full_solution)
728+
tmp = self.strategy.interpolate_at_t1(posterior_t1=interp_to.solution_full)
729729
(estimate, interpolated), step_and_interpolate_from = tmp
730730

731731
prev = ProbabilisticSolution(
732732
t=interp_to.t,
733733
# New
734-
full_solution=step_and_interpolate_from.interp_from,
734+
solution_full=step_and_interpolate_from.interp_from,
735735
# Old
736736
u=interp_from.u, # incorrect?
737737
output_scale=interp_from.output_scale, # incorrect?
@@ -742,7 +742,7 @@ def interpolate_at_t1(
742742
sol = ProbabilisticSolution(
743743
t=interp_to.t,
744744
# New:
745-
full_solution=interpolated,
745+
solution_full=interpolated,
746746
u=estimate,
747747
# Old:
748748
output_scale=interp_to.output_scale,
@@ -753,7 +753,7 @@ def interpolate_at_t1(
753753
acc = ProbabilisticSolution(
754754
t=interp_to.t,
755755
# New:
756-
full_solution=step_and_interpolate_from.step_from,
756+
solution_full=step_and_interpolate_from.step_from,
757757
# Old
758758
u=interp_to.u,
759759
output_scale=interp_to.output_scale,
@@ -1339,7 +1339,7 @@ def init(self, t, u) -> ProbabilisticSolution:
13391339
return ProbabilisticSolution(
13401340
t=t,
13411341
u=estimate,
1342-
full_solution=posterior,
1342+
solution_full=posterior,
13431343
auxiliary=auxiliary,
13441344
output_scale=output_scale_prior,
13451345
num_steps=0,
@@ -1353,7 +1353,7 @@ def step(self, state, *, dt: float, damp: float):
13531353

13541354
# Predict
13551355
u, prediction = self.strategy.predict(
1356-
posterior=state.full_solution, transition=transition
1356+
posterior=state.solution_full, transition=transition
13571357
)
13581358

13591359
# Linearize
@@ -1381,7 +1381,7 @@ def step(self, state, *, dt: float, damp: float):
13811381
return ProbabilisticSolution(
13821382
t=state.t + dt,
13831383
u=u,
1384-
full_solution=posterior,
1384+
solution_full=posterior,
13851385
output_scale=state.output_scale,
13861386
auxiliary=auxiliary,
13871387
num_steps=state.num_steps + 1,
@@ -1398,8 +1398,8 @@ def userfriendly_output(
13981398
ones = np.ones_like(output_scale)
13991399
output_scale = output_scale[-1]
14001400

1401-
init = solution0.full_solution
1402-
posterior = solution.full_solution
1401+
init = solution0.solution_full
1402+
posterior = solution.solution_full
14031403
estimate, posterior = self.strategy.finalize(
14041404
posterior0=init, posterior=posterior, output_scale=output_scale
14051405
)
@@ -1409,7 +1409,7 @@ def userfriendly_output(
14091409
return ProbabilisticSolution(
14101410
t=ts,
14111411
u=estimate,
1412-
full_solution=posterior,
1412+
solution_full=posterior,
14131413
output_scale=output_scale,
14141414
num_steps=solution.num_steps,
14151415
auxiliary=solution.auxiliary,
@@ -1456,7 +1456,7 @@ def init(self, t, u) -> ProbabilisticSolution:
14561456
return ProbabilisticSolution(
14571457
t=t,
14581458
u=estimate,
1459-
full_solution=posterior,
1459+
solution_full=posterior,
14601460
auxiliary=lin_state,
14611461
output_scale=output_scale,
14621462
num_steps=0,
@@ -1484,7 +1484,7 @@ def step(self, state: ProbabilisticSolution, *, dt: float, damp: float):
14841484
# (Includes re-discretisation)
14851485
transition = self.prior(dt, output_scale)
14861486
u, prediction = self.strategy.predict(
1487-
state.full_solution, transition=transition
1487+
state.solution_full, transition=transition
14881488
)
14891489

14901490
# Relinearize
@@ -1502,7 +1502,7 @@ def step(self, state: ProbabilisticSolution, *, dt: float, damp: float):
15021502
return ProbabilisticSolution(
15031503
t=state.t + dt,
15041504
u=u,
1505-
full_solution=posterior,
1505+
solution_full=posterior,
15061506
num_steps=state.num_steps + 1,
15071507
auxiliary=lin_state,
15081508
output_scale=output_scale,
@@ -1517,8 +1517,8 @@ def userfriendly_output(
15171517
ones = np.ones_like(solution.output_scale)
15181518
output_scale = ones[-1, ...]
15191519

1520-
init = solution0.full_solution
1521-
posterior = solution.full_solution
1520+
init = solution0.solution_full
1521+
posterior = solution.solution_full
15221522
estimate, posterior = self.strategy.finalize(
15231523
posterior0=init, posterior=posterior, output_scale=output_scale
15241524
)
@@ -1529,7 +1529,7 @@ def userfriendly_output(
15291529
return ProbabilisticSolution(
15301530
t=ts,
15311531
u=estimate,
1532-
full_solution=posterior,
1532+
solution_full=posterior,
15331533
output_scale=output_scale,
15341534
num_steps=solution.num_steps,
15351535
auxiliary=solution.auxiliary,
@@ -1567,7 +1567,7 @@ def init(self, t: Array, u: TaylorCoeffTarget) -> ProbabilisticSolution:
15671567
return ProbabilisticSolution(
15681568
t=t,
15691569
u=u,
1570-
full_solution=posterior,
1570+
solution_full=posterior,
15711571
num_steps=0,
15721572
auxiliary=correction_state,
15731573
output_scale=output_scale,
@@ -1581,7 +1581,7 @@ def step(self, state: ProbabilisticSolution, *, dt, damp):
15811581

15821582
# Predict
15831583
u, prediction = self.strategy.predict(
1584-
state.full_solution, transition=transition
1584+
state.solution_full, transition=transition
15851585
)
15861586

15871587
# Linearize
@@ -1598,7 +1598,7 @@ def step(self, state: ProbabilisticSolution, *, dt, damp):
15981598
return ProbabilisticSolution(
15991599
t=state.t + dt,
16001600
u=u,
1601-
full_solution=posterior,
1601+
solution_full=posterior,
16021602
output_scale=output_scale,
16031603
auxiliary=auxiliary,
16041604
num_steps=state.num_steps + 1,
@@ -1614,8 +1614,8 @@ def userfriendly_output(
16141614
ones = np.ones_like(solution.output_scale)
16151615
output_scale = np.ones_like(solution.output_scale[-1])
16161616

1617-
init = solution0.full_solution
1618-
posterior = solution.full_solution
1617+
init = solution0.solution_full
1618+
posterior = solution.solution_full
16191619
u, posterior = self.strategy.finalize(
16201620
posterior0=init, posterior=posterior, output_scale=output_scale
16211621
)
@@ -1626,7 +1626,7 @@ def userfriendly_output(
16261626
return ProbabilisticSolution(
16271627
t=ts,
16281628
u=u,
1629-
full_solution=posterior,
1629+
solution_full=posterior,
16301630
output_scale=output_scale,
16311631
num_steps=solution.num_steps,
16321632
auxiliary=solution.auxiliary,

tests/test_probdiffeq/test_log_marginal_likelihood.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_output_is_a_scalar_and_not_nan_and_not_inf(solution):
3131
data = tree.tree_map(lambda s: s + 0.005, sol.u.mean[0])
3232
std = tree.tree_map(lambda _s: np.ones_like(sol.t), sol.u.std[0])
3333
lml = strategy.log_marginal_likelihood(
34-
data, standard_deviation=std, posterior=sol.full_solution
34+
data, standard_deviation=std, posterior=sol.solution_full
3535
)
3636
assert lml.shape == ()
3737
assert not np.isnan(lml)
@@ -49,7 +49,7 @@ def test_that_function_raises_error_for_wrong_std_shape_too_many(solution):
4949

5050
with testing.raises(ValueError, match="does not match"):
5151
_ = strategy.log_marginal_likelihood(
52-
data, standard_deviation=std, posterior=sol.full_solution
52+
data, standard_deviation=std, posterior=sol.solution_full
5353
)
5454

5555

@@ -63,7 +63,7 @@ def test_raises_error_for_terminal_values(solution):
6363
data = tree.tree_map(lambda s: s[-1] + 0.005, sol.u.mean[0])
6464
std = tree.tree_map(lambda _s: np.ones_like(sol.t[-1]), sol.u.std[0])
6565

66-
posterior_t1 = tree.tree_map(lambda s: s[-1], sol.full_solution)
66+
posterior_t1 = tree.tree_map(lambda s: s[-1], sol.solution_full)
6767
with testing.raises(ValueError, match="expected"):
6868
_ = strategy.log_marginal_likelihood(
6969
data, standard_deviation=std, posterior=posterior_t1
@@ -91,7 +91,7 @@ def test_raises_error_for_filter(fact):
9191
std = tree.tree_map(np.ones_like, sol.u.std[0])
9292
with testing.raises(TypeError, match="ilter"):
9393
_ = strategy.log_marginal_likelihood(
94-
data, standard_deviation=std, posterior=sol.full_solution
94+
data, standard_deviation=std, posterior=sol.solution_full
9595
)
9696

9797

@@ -102,5 +102,5 @@ def test_raise_error_if_structures_dont_match(solution):
102102

103103
with testing.raises(ValueError, match="tree structure"):
104104
_ = strategy.log_marginal_likelihood(
105-
data, standard_deviation=std, posterior=sol.full_solution
105+
data, standard_deviation=std, posterior=sol.solution_full
106106
)

tests/test_probdiffeq/test_log_marginal_likelihood_terminal_values.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_output_is_scalar_and_not_inf_and_not_nan(solution):
4949
std = tree.tree_map(lambda _s: 1e-2 * np.ones(()), sol.u.std[0])
5050

5151
mll = strategy.log_marginal_likelihood_terminal_values(
52-
data, standard_deviation=std, posterior=sol.full_solution
52+
data, standard_deviation=std, posterior=sol.solution_full
5353
)
5454

5555
assert mll.shape == ()
@@ -64,5 +64,5 @@ def test_raise_error_if_structures_dont_match(solution):
6464

6565
with testing.raises(ValueError, match="structure"):
6666
_ = strategy.log_marginal_likelihood_terminal_values(
67-
data, standard_deviation=std, posterior=sol.full_solution
67+
data, standard_deviation=std, posterior=sol.solution_full
6868
)

tests/test_probdiffeq/test_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_sample_shape(approximation_and_strategy, shape):
3030

3131
key = random.prng_key(seed=15)
3232
samples = strategy.markov_sample(
33-
key, approximation.full_solution, shape=shape, reverse=True
33+
key, approximation.solution_full, shape=shape, reverse=True
3434
)
3535
for s, u in zip(samples, approximation.u.mean):
3636
s_shape = tree.tree_map(lambda x: x.shape, s)

tests/test_probdiffeq/test_strategy_smoother_fixedinterval_vs_fixedpoint.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,16 @@ def test_fixedpoint_smoother_equivalent_same_grid(solver_setup, solution_smoothe
5252
assert testing.allclose(sol_fp.u.mean, sol_sm.u.mean)
5353
assert testing.allclose(sol_fp.u.std, sol_sm.u.std)
5454
assert testing.allclose(sol_fp.u.marginals, sol_sm.u.marginals)
55-
assert testing.allclose(sol_fp.output_scale, sol_sm.output_scale)
5655
assert testing.allclose(sol_fp.num_steps, sol_sm.num_steps)
5756
assert testing.allclose(
58-
sol_fp.full_solution.marginal, sol_sm.full_solution.marginal
57+
sol_fp.solution_full.marginal, sol_sm.solution_full.marginal
5958
)
6059

6160
# The backward conditionals use different parametrisations
6261
# but implement the same transitions
6362
cond_fp, cond_sm = (
64-
sol_fp.full_solution.conditional,
65-
sol_sm.full_solution.conditional,
63+
sol_fp.solution_full.conditional,
64+
sol_sm.solution_full.conditional,
6665
)
6766
cond_fp = func.vmap(ssm.conditional.preconditioner_apply)(cond_fp)
6867
cond_sm = func.vmap(ssm.conditional.preconditioner_apply)(cond_sm)

0 commit comments

Comments
 (0)