diff --git a/probdiffeq/ivpsolve.py b/probdiffeq/ivpsolve.py index 7c413f644..72dfa2e49 100644 --- a/probdiffeq/ivpsolve.py +++ b/probdiffeq/ivpsolve.py @@ -4,7 +4,6 @@ from probdiffeq.backend import ( containers, control_flow, - functools, linalg, tree_array_util, tree_util, @@ -89,42 +88,21 @@ def solve_adaptive_terminal_values( vector_field, initial_condition, t0, t1, adaptive_solver, dt0, *, ssm ) -> IVPSolution: """Simulate the terminal values of an initial value problem.""" - save_at = np.asarray([t1]) - (_t, solution_save_at), _, num_steps = _solve_adaptive_save_at( - tree_util.Partial(vector_field), - t0, + save_at = np.asarray([t0, t1]) + solution = solve_adaptive_save_at( + vector_field, initial_condition, save_at=save_at, adaptive_solver=adaptive_solver, dt0=dt0, - ) - # "squeeze"-type functionality (there is only a single state!) - squeeze_fun = functools.partial(np.squeeze_along_axis, axis=0) - solution_save_at = tree_util.tree_map(squeeze_fun, solution_save_at) - num_steps = tree_util.tree_map(squeeze_fun, num_steps) - - # I think the user expects marginals, so we compute them here - # todo: do this in IVPSolution.* methods? - posterior, output_scale = solution_save_at - marginals = posterior.init if isinstance(posterior, stats.MarkovSeq) else posterior - - u = ssm.stats.qoi_from_sample(marginals.mean) - std = ssm.stats.standard_deviation(marginals) - u_std = ssm.stats.qoi_from_sample(std) - return IVPSolution( - t=t1, - u=u, - u_std=u_std, ssm=ssm, - marginals=marginals, - posterior=posterior, - output_scale=output_scale, - num_steps=num_steps, + warn=False, # Turn off warnings because any solver goes for terminal values ) + return tree_util.tree_map(lambda s: s[-1], solution) def solve_adaptive_save_at( - vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm + vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm, warn=True ) -> IVPSolution: r"""Solve an initial value problem and return the solution at a pre-determined grid. @@ -152,7 +130,7 @@ def solve_adaptive_save_at( } ``` """ - if not adaptive_solver.solver.is_suitable_for_save_at: + if not adaptive_solver.solver.is_suitable_for_save_at and warn: msg = ( f"Strategy {adaptive_solver.solver} should not " f"be used in solve_adaptive_save_at. " @@ -170,7 +148,7 @@ def solve_adaptive_save_at( # I think the user expects the initial condition to be part of the state # (as well as marginals), so we compute those things here - posterior_t0, *_ = initial_condition + posterior_t0 = initial_condition.posterior posterior_save_at, output_scale = solution_save_at _tmp = _userfriendly_output( posterior=posterior_save_at, posterior_t0=posterior_t0, ssm=ssm @@ -194,41 +172,37 @@ def solve_adaptive_save_at( def _solve_adaptive_save_at( vector_field, t, initial_condition, *, save_at, adaptive_solver, dt0 ): - advance_func = functools.partial( - _advance_and_interpolate, - vector_field=vector_field, - adaptive_solver=adaptive_solver, - ) - - state = adaptive_solver.init(t, initial_condition, dt=dt0, num_steps=0.0) - _, solution = control_flow.scan(advance_func, init=state, xs=save_at, reverse=False) - return solution - - -def _advance_and_interpolate(state, t_next, *, vector_field, adaptive_solver): - # Advance until accepted.t >= t_next. - # Note: This could already be the case and we may not loop (just interpolate) - def cond_fun(s): - # Terminate the loop if - # the difference from s.t to t_next is smaller than a constant factor - # (which is a "small" multiple of the current machine precision) - # or if s.t > t_next holds. - return s.step_from.t + 10 * np.finfo_eps(float) < t_next + def advance(state, t_next): + # Advance until accepted.t >= t_next. + # Note: This could already be the case and we may not loop (just interpolate) + def cond_fun(s): + # Terminate the loop if + # the difference from s.t to t_next is smaller than a constant factor + # (which is a "small" multiple of the current machine precision) + # or if s.t > t_next holds. + return s.step_from.t + adaptive_solver.eps < t_next + + def body_fun(s): + return adaptive_solver.rejection_loop( + s, vector_field=vector_field, t1=t_next + ) - def body_fun(s): - return adaptive_solver.rejection_loop(s, vector_field=vector_field, t1=t_next) + state = control_flow.while_loop(cond_fun, body_fun, init=state) - state = control_flow.while_loop(cond_fun, body_fun, init=state) + # Either interpolate (t > t_next) or "finalise" (t == t_next) + is_after_t1 = state.step_from.t > t_next + adaptive_solver.eps + state, solution = control_flow.cond( + is_after_t1, + adaptive_solver.extract_after_t1, + adaptive_solver.extract_at_t1, + state, + t_next, + ) + return state, solution - # Either interpolate (t > t_next) or "finalise" (t == t_next) - state, solution = control_flow.cond( - state.step_from.t > t_next + 10 * np.finfo_eps(float), - adaptive_solver.extract_after_t1_via_interpolation, - lambda s, _t: adaptive_solver.extract_at_t1(s), - state, - t_next, - ) - return state, solution + state = adaptive_solver.init(t, initial_condition, dt=dt0, num_steps=0.0) + _, solution = control_flow.scan(advance, init=state, xs=save_at, reverse=False) + return solution def solve_adaptive_save_every_step( @@ -264,7 +238,7 @@ def solve_adaptive_save_every_step( t = np.concatenate((np.atleast_1d(t0), t)) # I think the user expects marginals, so we compute them here - posterior_t0, *_ = initial_condition + posterior_t0 = initial_condition.posterior posterior, output_scale = solution_every_step _tmp = _userfriendly_output(posterior=posterior, posterior_t0=posterior_t0, ssm=ssm) marginals, posterior = _tmp @@ -292,15 +266,16 @@ def _solution_generator( while state.step_from.t < t1: state = adaptive_solver.rejection_loop(state, vector_field=vector_field, t1=t1) - if state.step_from.t < t1: - solution = adaptive_solver.extract_before_t1(state) + if state.step_from.t + adaptive_solver.eps < t1: + _, solution = adaptive_solver.extract_before_t1(state, t=t1) yield solution # Either interpolate (t > t_next) or "finalise" (t == t_next) - if state.step_from.t > t1: - _, solution = adaptive_solver.extract_after_t1_via_interpolation(state, t=t1) + is_after_t1 = state.step_from.t > t1 + adaptive_solver.eps + if is_after_t1: + _, solution = adaptive_solver.extract_after_t1(state, t=t1) else: - _, solution = adaptive_solver.extract_at_t1(state) + _, solution = adaptive_solver.extract_at_t1(state, t=t1) yield solution @@ -321,7 +296,7 @@ def body_fn(s, dt): _t, (posterior, output_scale) = solver.extract(result_state) # I think the user expects marginals, so we compute them here - posterior_t0, *_ = initial_condition + posterior_t0 = initial_condition.posterior _tmp = _userfriendly_output(posterior=posterior, posterior_t0=posterior_t0, ssm=ssm) marginals, posterior = _tmp diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 8736d14fc..8cc6e231f 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -1,6 +1,6 @@ """Probabilistic IVP solvers.""" -from probdiffeq import stats +from probdiffeq import ivpsolve, stats from probdiffeq.backend import ( containers, control_flow, @@ -197,11 +197,11 @@ def init(self, sol: stats.MarkovSeq, /): raise NotImplementedError def begin(self, rv, _extra, /, *, prior_discretized): - """Begin the extrapolation.""" + """Begin the strategy.""" raise NotImplementedError def complete(self, _ssv, extra, /, output_scale): - """Complete the extrapolation.""" + """Complete the strategy.""" raise NotImplementedError def extract(self, hidden_state, extra, /): @@ -428,6 +428,10 @@ def complete(self, _rv, extra, /, output_scale): # Gather and return return extrapolated, cond + def interpolate_at_t1(self, rv, extra, /, *, prior): + cond_identity = self.ssm.conditional.identity(prior.num_derivatives + 1) + return _InterpRes((rv, cond_identity), (rv, extra), (rv, cond_identity)) + def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale, prior): """Interpolate. @@ -497,10 +501,6 @@ def _extrapolate(self, state, extra, /, *, output_scale, prior_discretized): x, cache = self.begin(state, extra, prior_discretized=prior_discretized) return self.complete(x, cache, output_scale=output_scale) - def interpolate_at_t1(self, rv, extra, /, *, prior): - cond_identity = self.ssm.conditional.identity(prior.num_derivatives + 1) - return _InterpRes((rv, cond_identity), (rv, extra), (rv, cond_identity)) - return FixedPoint( name="Fixed-point smoother", ssm=ssm, @@ -642,7 +642,7 @@ class _ProbabilisticSolver: prior: _MarkovProcess ssm: Any - extrapolation: _Strategy + strategy: _Strategy calibration: _Calibration correction: _Correction @@ -654,11 +654,11 @@ def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_sca dt0 = t - t0 dt1 = t1 - t - rv, extra = self.extrapolation.init(posterior_t0) + rv, extra = self.strategy.init(posterior_t0) rv, corr = self.correction.init(rv) # TODO: Replace dt0, dt1, and prior with prior_dt0, and prior_dt1 - interp = self.extrapolation.interpolate( + interp = self.strategy.interpolate( state_t0=(rv, extra), marginal_t1=marginals_t1, dt0=dt0, @@ -677,23 +677,35 @@ def error_contraction_rate(self): @property def is_suitable_for_offgrid_marginals(self): - return self.extrapolation.is_suitable_for_offgrid_marginals + return self.strategy.is_suitable_for_offgrid_marginals @property def is_suitable_for_save_at(self): - return self.extrapolation.is_suitable_for_save_at + return self.strategy.is_suitable_for_save_at @property def is_suitable_for_save_every_step(self): - return self.extrapolation.is_suitable_for_save_every_step + return self.strategy.is_suitable_for_save_every_step - def init(self, t, initial_condition) -> _State: - posterior, output_scale = initial_condition + def initial_condition(self) -> ivpsolve.IVPSolution: + """Construct an initial condition.""" + posterior = self.strategy.initial_condition(prior=self.prior) + return ivpsolve.IVPSolution( + t=None, + u=None, + u_std=None, + output_scale=self.prior.output_scale, + marginals=None, + posterior=posterior, + num_steps=None, + ssm=None, + ) - rv, extra = self.extrapolation.init(posterior) + def init(self, t, initial_condition: ivpsolve.IVPSolution) -> _State: + rv, extra = self.strategy.init(initial_condition.posterior) rv, corr = self.correction.init(rv) - calib_state = self.calibration.init(output_scale) + calib_state = self.calibration.init(initial_condition.output_scale) return _State(t=t, hidden=rv, aux_extra=extra, output_scale=calib_state) def step(self, state: _State, *, vector_field, dt): @@ -703,7 +715,7 @@ def step(self, state: _State, *, vector_field, dt): def extract(self, state: _State, /): hidden = state.hidden - posterior = self.extrapolation.extract(hidden, state.aux_extra) + posterior = self.strategy.extract(hidden, state.aux_extra) t = state.t _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) @@ -718,7 +730,7 @@ def interpolate(self, t, *, interp_from: _State, interp_to: _State) -> _InterpRe def _case_interpolate(self, t, *, s0, s1, output_scale) -> _InterpRes: """Process the solution in case t>t_n.""" # Interpolate - interp = self.extrapolation.interpolate( + interp = self.strategy.interpolate( state_t0=(s0.hidden, s0.aux_extra), marginal_t1=s1.hidden, dt0=t - s0.t, @@ -741,7 +753,7 @@ def _state(t_, x, scale): def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: """Process the solution in case t=t_n.""" - tmp = self.extrapolation.interpolate_at_t1( + tmp = self.strategy.interpolate_at_t1( interp_to.hidden, interp_to.aux_extra, prior=self.prior ) step_from_, solution_, interp_from_ = ( @@ -759,11 +771,6 @@ def _state(t_, s, scale): acc = _state(t, step_from_, interp_to.output_scale) return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) - def initial_condition(self): - """Construct an initial condition.""" - posterior = self.extrapolation.initial_condition(prior=self.prior) - return posterior, self.prior.output_scale - def solver_mle(strategy, *, correction, prior, ssm): """Create a solver that calibrates the output scale via maximum-likelihood. @@ -799,7 +806,7 @@ def step_mle(state, /, *, dt, vector_field, calibration): prior=prior, calibration=_calibration_running_mean(ssm=ssm), step_implementation=step_mle, - extrapolation=strategy, + strategy=strategy, correction=correction, requires_rescaling=True, ) @@ -854,7 +861,7 @@ def step_dynamic(state, /, *, dt, vector_field, calibration): return _ProbabilisticSolver( prior=prior, ssm=ssm, - extrapolation=strategy, + strategy=strategy, correction=correction, calibration=_calibration_most_recent(ssm=ssm), name="Dynamic probabilistic solver", @@ -905,7 +912,7 @@ def step(state: _State, *, vector_field, dt, calibration): return _ProbabilisticSolver( ssm=ssm, prior=prior, - extrapolation=strategy, + strategy=strategy, correction=correction, calibration=_calibration_none(), step_implementation=step, @@ -937,11 +944,13 @@ def adaptive( control=None, norm_ord=None, clip_dt: bool = False, + eps: float | None = None, ): """Make an IVP solver adaptive.""" if control is None: control = control_proportional_integral() - + if eps is None: + eps = 10 * np.finfo_eps(float) return _AdaSolver( slvr, ssm=ssm, @@ -950,6 +959,7 @@ def adaptive( control=control, norm_ord=norm_ord, clip_dt=clip_dt, + eps=eps, ) @@ -975,6 +985,7 @@ def __init__( norm_ord, ssm, clip_dt: bool, + eps: float, ): self.solver = slvr self.atol = atol @@ -983,6 +994,7 @@ def __init__( self.norm_ord = norm_ord self.ssm = ssm self.clip_dt = clip_dt + self.eps = eps def __repr__(self): return ( @@ -1000,7 +1012,13 @@ def init(self, t, initial_condition, dt, num_steps) -> _AdaState: """Initialise the IVP solver state.""" state_solver = self.solver.init(t, initial_condition) state_control = self.control.init(dt) - return _AdaState(dt, state_solver, state_solver, state_control, num_steps) + return _AdaState( + dt=dt, + step_from=state_solver, + interp_from=state_solver, + control=state_control, + stats=num_steps, + ) @functools.jit def rejection_loop(self, state0: _AdaState, *, vector_field, t1) -> _AdaState: @@ -1021,7 +1039,7 @@ def init(s0: _AdaState) -> _RejectionState: def _inf_like(tree): return tree_util.tree_map(lambda x: np.inf() * np.ones_like(x), tree) - smaller_than_1 = 1.0 / 1.1 # the cond() must return True + smaller_than_1 = 0.9 # the cond() must return True return _RejectionState( error_norm_proposed=smaller_than_1, dt=s0.dt, @@ -1078,35 +1096,46 @@ def _error_scale_and_normalize(error_estimate, *, u): def extract(s: _RejectionState) -> _AdaState: num_steps = state0.stats + 1.0 # TODO: track step attempts as well - return _AdaState(s.dt, s.proposed, s.step_from, s.control, num_steps) + return _AdaState( + dt=s.dt, + step_from=s.proposed, + interp_from=s.step_from, + control=s.control, + stats=num_steps, + ) init_val = init(state0) state_new = control_flow.while_loop(cond_fn, body_fn, init_val) return extract(state_new) - def extract_before_t1(self, state: _AdaState): + def extract_before_t1(self, state: _AdaState, t): + del t solution_solver = self.solver.extract(state.step_from) - return solution_solver, (state.dt, state.control), state.stats + extracted = solution_solver, (state.dt, state.control), state.stats + return state, extracted - def extract_at_t1(self, state: _AdaState): + def extract_at_t1(self, state: _AdaState, t): + del t # todo: make the "at t1" decision inside interpolate(), # which collapses the next two functions together interp = self.solver.interpolate_at_t1( interp_from=state.interp_from, interp_to=state.step_from ) - state = _AdaState( - state.dt, interp.step_from, interp.interp_from, state.control, state.stats - ) + return self._extract_interpolate(interp, state) - solution_solver = self.solver.extract(interp.interpolated) - return state, (solution_solver, (state.dt, state.control), state.stats) - - def extract_after_t1_via_interpolation(self, state: _AdaState, t): + def extract_after_t1(self, state: _AdaState, t): interp = self.solver.interpolate( t, interp_from=state.interp_from, interp_to=state.step_from ) + return self._extract_interpolate(interp, state) + + def _extract_interpolate(self, interp, state): state = _AdaState( - state.dt, interp.step_from, interp.interp_from, state.control, state.stats + dt=state.dt, + step_from=interp.step_from, + interp_from=interp.interp_from, + control=state.control, + stats=state.stats, ) solution_solver = self.solver.extract(interp.interpolated) @@ -1115,7 +1144,7 @@ def extract_after_t1_via_interpolation(self, state: _AdaState, t): @staticmethod def register_pytree_node(): def _asolver_flatten(asolver): - children = (asolver.atol, asolver.rtol) + children = (asolver.atol, asolver.rtol, asolver.eps) aux = ( asolver.solver, asolver.control, @@ -1126,7 +1155,7 @@ def _asolver_flatten(asolver): return children, aux def _asolver_unflatten(aux, children): - atol, rtol = children + atol, rtol, eps = children (slvr, control, norm_ord, ssm, clip_dt) = aux return _AdaSolver( slvr, @@ -1136,6 +1165,7 @@ def _asolver_unflatten(aux, children): norm_ord=norm_ord, ssm=ssm, clip_dt=clip_dt, + eps=eps, ) tree_util.register_pytree_node( @@ -1173,7 +1203,7 @@ def init(_dt: float, /) -> float: return 1.0 def apply(dt: float, error_power_prev: float, /, *, error_power): - # error_power = error_norm ** (-1.0 / error_contraction_rate) + # Equivalent: error_power = error_norm ** (-1.0 / error_contraction_rate) a1 = error_power**power_integral_unscaled a2 = (error_power / error_power_prev) ** power_proportional_unscaled scale_factor_unclipped = safety * a1 * a2