diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 5644a0458..56aae40b6 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -10,12 +10,19 @@ tree_util, ) from probdiffeq.backend import numpy as np -from probdiffeq.backend.typing import Any, Array, Callable, Generic, NamedArg, TypeVar +from probdiffeq.backend.typing import ( + Any, + ArrayLike, + Callable, + Generic, + NamedArg, + TypeVar, +) from probdiffeq.impl import impl def prior_wiener_integrated( - tcoeffs, *, ssm_fact: str, output_scale: Array | None = None, damp: float = 0.0 + tcoeffs, *, ssm_fact: str, output_scale: ArrayLike | None = None, damp: float = 0.0 ): """Construct an adaptive(/continuous-time), multiply-integrated Wiener process.""" ssm = impl.choose(ssm_fact, tcoeffs_like=tcoeffs) @@ -78,8 +85,8 @@ class _InterpRes(Generic[R]): class _PositiveCubatureRule(containers.NamedTuple): """Cubature rule with positive weights.""" - points: Array - weights_sqrtm: Array + points: ArrayLike + weights_sqrtm: ArrayLike def cubature_third_order_spherical(input_shape) -> _PositiveCubatureRule: @@ -538,6 +545,13 @@ class _State(containers.NamedTuple): output_scale: Any +@tree_util.register_dataclass +@containers.dataclass +class _ErrorEstimate: + estimate: ArrayLike + reference: ArrayLike + + @containers.dataclass class _ProbabilisticSolver: name: str @@ -671,6 +685,8 @@ def solver_mle(strategy, *, correction, prior, ssm): """ def step_mle(state, /, *, dt, calibration): + u_step_from = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0] + # Estimate the error output_scale_prior, _calibrated = calibration.extract(state.output_scale) transition = prior(dt, output_scale_prior) @@ -689,8 +705,13 @@ def step_mle(state, /, *, dt, calibration): # Calibrate the output scale output_scale = calibration.update(state.output_scale, observed=observed) + + # Normalise the error state = _State(t=t, rv=hidden, strategy_state=extra, output_scale=output_scale) - return dt * error, state + u_proposed = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0] + reference = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) + error = _ErrorEstimate(dt * error, reference=reference) + return error, state return _ProbabilisticSolver( ssm=ssm, @@ -726,6 +747,8 @@ def solver_dynamic(strategy, *, correction, prior, ssm): """Create a solver that calibrates the output scale dynamically.""" def step_dynamic(state, /, *, dt, calibration): + u_step_from = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0] + # Estimate error and calibrate the output scale ones = np.ones_like(ssm.prototypes.output_scale()) transition = prior(dt, ones) @@ -748,7 +771,12 @@ def step_dynamic(state, /, *, dt, calibration): # Return solution state = _State(t=t, rv=hidden, strategy_state=extra, output_scale=output_scale) - return dt * error, state + + # Normalise the error + u_proposed = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0] + reference = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) + error = _ErrorEstimate(dt * error, reference=reference) + return error, state return _ProbabilisticSolver( prior=prior, @@ -780,6 +808,8 @@ def solver(strategy, *, correction, prior, ssm): def step(state: _State, *, dt, calibration): del calibration # unused + u_step_from = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0] + # Estimate the error transition = prior(dt, state.output_scale) mean = ssm.stats.mean(state.rv) @@ -794,12 +824,15 @@ def step(state: _State, *, dt, calibration): # Do the full correction step hidden, corr = correction.correct(hidden, t=t) - - # Extract and return solution state = _State( t=t, rv=hidden, strategy_state=extra, output_scale=state.output_scale ) - return dt * error, state + + # Normalise the error + u_proposed = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0] + reference = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) + error = _ErrorEstimate(dt * error, reference=reference) + return error, state return _ProbabilisticSolver( ssm=ssm, @@ -960,18 +993,9 @@ def body_fn(state: _RejectionState) -> _RejectionState: error_estimate, state_proposed = self.solver.step( state=state.step_from, dt=dt ) - # Normalise the error - u_proposed = tree_util.ravel_pytree( - self.ssm.unravel(state_proposed.rv.mean)[0] - )[0] - u_step_from = tree_util.ravel_pytree( - self.ssm.unravel(state.step_from.rv.mean)[0] - )[0] - - u = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) - error_power = _error_scale_and_normalize(error_estimate, u=u) # Propose a new step + error_power = self._error_scale_and_normalize(error_estimate) dt, state_control = self.control.apply( dt, state.control, error_power=error_power ) @@ -983,13 +1007,6 @@ def body_fn(state: _RejectionState) -> _RejectionState: step_from=state.step_from, ) - def _error_scale_and_normalize(error_estimate, u): - error_relative = error_estimate / (self.atol + self.rtol * np.abs(u)) - dim = np.atleast_1d(u).size - error_norm = linalg.vector_norm(error_relative, order=self.norm_ord) - error_norm_rel = error_norm / np.sqrt(dim) - return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate) - def extract(s: _RejectionState) -> _AdaState: num_steps = state0.stats + 1.0 # TODO: track step attempts as well return _AdaState( @@ -1004,6 +1021,16 @@ def extract(s: _RejectionState) -> _AdaState: state_new = control_flow.while_loop(cond_fn, body_fn, init_val) return extract(state_new) + def _error_scale_and_normalize(self, error: _ErrorEstimate): + assert isinstance(error, _ErrorEstimate) + normalize = self.atol + self.rtol * np.abs(error.reference) + error_relative = error.estimate / normalize + + dim = np.atleast_1d(error.reference).size + error_norm = linalg.vector_norm(error_relative, order=self.norm_ord) + error_norm_rel = error_norm / np.sqrt(dim) + return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate) + def extract_before_t1(self, state: _AdaState, t): del t solution_solver = self.solver.extract(state.step_from)