|
1 | 1 | """Probabilistic IVP solvers.""" |
2 | 2 |
|
3 | | -from probdiffeq import stats |
| 3 | +from probdiffeq import ivpsolve, stats |
4 | 4 | from probdiffeq.backend import ( |
5 | 5 | containers, |
6 | 6 | control_flow, |
@@ -694,13 +694,25 @@ def is_suitable_for_save_at(self): |
694 | 694 | def is_suitable_for_save_every_step(self): |
695 | 695 | return self.strategy.is_suitable_for_save_every_step |
696 | 696 |
|
697 | | - def init(self, t, initial_condition) -> _State: |
698 | | - posterior, output_scale = initial_condition |
| 697 | + def initial_condition(self) -> ivpsolve.IVPSolution: |
| 698 | + """Construct an initial condition.""" |
| 699 | + posterior = self.strategy.initial_condition(prior=self.prior) |
| 700 | + return ivpsolve.IVPSolution( |
| 701 | + t=None, |
| 702 | + u=None, |
| 703 | + u_std=None, |
| 704 | + output_scale=self.prior.output_scale, |
| 705 | + marginals=None, |
| 706 | + posterior=posterior, |
| 707 | + num_steps=None, |
| 708 | + ssm=None, |
| 709 | + ) |
699 | 710 |
|
700 | | - rv, extra = self.strategy.init(posterior) |
| 711 | + def init(self, t, initial_condition: ivpsolve.IVPSolution) -> _State: |
| 712 | + rv, extra = self.strategy.init(initial_condition.posterior) |
701 | 713 | rv, corr = self.correction.init(rv) |
702 | 714 |
|
703 | | - calib_state = self.calibration.init(output_scale) |
| 715 | + calib_state = self.calibration.init(initial_condition.output_scale) |
704 | 716 | return _State(t=t, hidden=rv, aux_extra=extra, output_scale=calib_state) |
705 | 717 |
|
706 | 718 | def step(self, state: _State, *, vector_field, dt): |
@@ -766,11 +778,6 @@ def _state(t_, s, scale): |
766 | 778 | acc = _state(t, step_from_, interp_to.output_scale) |
767 | 779 | return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) |
768 | 780 |
|
769 | | - def initial_condition(self): |
770 | | - """Construct an initial condition.""" |
771 | | - posterior = self.strategy.initial_condition(prior=self.prior) |
772 | | - return posterior, self.prior.output_scale |
773 | | - |
774 | 781 |
|
775 | 782 | def solver_mle(strategy, *, correction, prior, ssm): |
776 | 783 | """Create a solver that calibrates the output scale via maximum-likelihood. |
|
0 commit comments