Skip to content

Commit 29e9bcd

Browse files
committed
Initial condition returns an IVPSolution object now
1 parent a5746f0 commit 29e9bcd

2 files changed

Lines changed: 21 additions & 13 deletions

File tree

probdiffeq/ivpsolve.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def solve_adaptive_save_at(
148148

149149
# I think the user expects the initial condition to be part of the state
150150
# (as well as marginals), so we compute those things here
151-
posterior_t0, *_ = initial_condition
151+
posterior_t0 = initial_condition.posterior
152152
posterior_save_at, output_scale = solution_save_at
153153
_tmp = _userfriendly_output(
154154
posterior=posterior_save_at, posterior_t0=posterior_t0, ssm=ssm
@@ -238,7 +238,7 @@ def solve_adaptive_save_every_step(
238238
t = np.concatenate((np.atleast_1d(t0), t))
239239

240240
# I think the user expects marginals, so we compute them here
241-
posterior_t0, *_ = initial_condition
241+
posterior_t0 = initial_condition.posterior
242242
posterior, output_scale = solution_every_step
243243
_tmp = _userfriendly_output(posterior=posterior, posterior_t0=posterior_t0, ssm=ssm)
244244
marginals, posterior = _tmp
@@ -296,7 +296,8 @@ def body_fn(s, dt):
296296
_t, (posterior, output_scale) = solver.extract(result_state)
297297

298298
# I think the user expects marginals, so we compute them here
299-
posterior_t0, *_ = initial_condition
299+
# posterior_t0, *_ = initial_condition
300+
posterior_t0 = initial_condition.posterior
300301
_tmp = _userfriendly_output(posterior=posterior, posterior_t0=posterior_t0, ssm=ssm)
301302
marginals, posterior = _tmp
302303

probdiffeq/ivpsolvers.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Probabilistic IVP solvers."""
22

3-
from probdiffeq import stats
3+
from probdiffeq import ivpsolve, stats
44
from probdiffeq.backend import (
55
containers,
66
control_flow,
@@ -694,13 +694,25 @@ def is_suitable_for_save_at(self):
694694
def is_suitable_for_save_every_step(self):
695695
return self.strategy.is_suitable_for_save_every_step
696696

697-
def init(self, t, initial_condition) -> _State:
698-
posterior, output_scale = initial_condition
697+
def initial_condition(self) -> ivpsolve.IVPSolution:
698+
"""Construct an initial condition."""
699+
posterior = self.strategy.initial_condition(prior=self.prior)
700+
return ivpsolve.IVPSolution(
701+
t=None,
702+
u=None,
703+
u_std=None,
704+
output_scale=self.prior.output_scale,
705+
marginals=None,
706+
posterior=posterior,
707+
num_steps=None,
708+
ssm=None,
709+
)
699710

700-
rv, extra = self.strategy.init(posterior)
711+
def init(self, t, initial_condition: ivpsolve.IVPSolution) -> _State:
712+
rv, extra = self.strategy.init(initial_condition.posterior)
701713
rv, corr = self.correction.init(rv)
702714

703-
calib_state = self.calibration.init(output_scale)
715+
calib_state = self.calibration.init(initial_condition.output_scale)
704716
return _State(t=t, hidden=rv, aux_extra=extra, output_scale=calib_state)
705717

706718
def step(self, state: _State, *, vector_field, dt):
@@ -766,11 +778,6 @@ def _state(t_, s, scale):
766778
acc = _state(t, step_from_, interp_to.output_scale)
767779
return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev)
768780

769-
def initial_condition(self):
770-
"""Construct an initial condition."""
771-
posterior = self.strategy.initial_condition(prior=self.prior)
772-
return posterior, self.prior.output_scale
773-
774781

775782
def solver_mle(strategy, *, correction, prior, ssm):
776783
"""Create a solver that calibrates the output scale via maximum-likelihood.

0 commit comments

Comments
 (0)