Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 53 additions & 26 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@
tree_util,
)
from probdiffeq.backend import numpy as np
from probdiffeq.backend.typing import Any, Array, Callable, Generic, NamedArg, TypeVar
from probdiffeq.backend.typing import (
Any,
ArrayLike,
Callable,
Generic,
NamedArg,
TypeVar,
)
from probdiffeq.impl import impl


def prior_wiener_integrated(
tcoeffs, *, ssm_fact: str, output_scale: Array | None = None, damp: float = 0.0
tcoeffs, *, ssm_fact: str, output_scale: ArrayLike | None = None, damp: float = 0.0
):
"""Construct an adaptive(/continuous-time), multiply-integrated Wiener process."""
ssm = impl.choose(ssm_fact, tcoeffs_like=tcoeffs)
Expand Down Expand Up @@ -78,8 +85,8 @@ class _InterpRes(Generic[R]):
class _PositiveCubatureRule(containers.NamedTuple):
"""Cubature rule with positive weights."""

points: Array
weights_sqrtm: Array
points: ArrayLike
weights_sqrtm: ArrayLike


def cubature_third_order_spherical(input_shape) -> _PositiveCubatureRule:
Expand Down Expand Up @@ -538,6 +545,13 @@ class _State(containers.NamedTuple):
output_scale: Any


@tree_util.register_dataclass
@containers.dataclass
class _ErrorEstimate:
estimate: ArrayLike
reference: ArrayLike


@containers.dataclass
class _ProbabilisticSolver:
name: str
Expand Down Expand Up @@ -671,6 +685,8 @@ def solver_mle(strategy, *, correction, prior, ssm):
"""

def step_mle(state, /, *, dt, calibration):
u_step_from = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0]

# Estimate the error
output_scale_prior, _calibrated = calibration.extract(state.output_scale)
transition = prior(dt, output_scale_prior)
Expand All @@ -689,8 +705,13 @@ def step_mle(state, /, *, dt, calibration):

# Calibrate the output scale
output_scale = calibration.update(state.output_scale, observed=observed)

# Normalise the error
state = _State(t=t, rv=hidden, strategy_state=extra, output_scale=output_scale)
return dt * error, state
u_proposed = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0]
reference = np.maximum(np.abs(u_proposed), np.abs(u_step_from))
error = _ErrorEstimate(dt * error, reference=reference)
return error, state

return _ProbabilisticSolver(
ssm=ssm,
Expand Down Expand Up @@ -726,6 +747,8 @@ def solver_dynamic(strategy, *, correction, prior, ssm):
"""Create a solver that calibrates the output scale dynamically."""

def step_dynamic(state, /, *, dt, calibration):
u_step_from = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0]

# Estimate error and calibrate the output scale
ones = np.ones_like(ssm.prototypes.output_scale())
transition = prior(dt, ones)
Expand All @@ -748,7 +771,12 @@ def step_dynamic(state, /, *, dt, calibration):

# Return solution
state = _State(t=t, rv=hidden, strategy_state=extra, output_scale=output_scale)
return dt * error, state

# Normalise the error
u_proposed = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0]
reference = np.maximum(np.abs(u_proposed), np.abs(u_step_from))
error = _ErrorEstimate(dt * error, reference=reference)
return error, state

return _ProbabilisticSolver(
prior=prior,
Expand Down Expand Up @@ -780,6 +808,8 @@ def solver(strategy, *, correction, prior, ssm):
def step(state: _State, *, dt, calibration):
del calibration # unused

u_step_from = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0]

# Estimate the error
transition = prior(dt, state.output_scale)
mean = ssm.stats.mean(state.rv)
Expand All @@ -794,12 +824,15 @@ def step(state: _State, *, dt, calibration):

# Do the full correction step
hidden, corr = correction.correct(hidden, t=t)

# Extract and return solution
state = _State(
t=t, rv=hidden, strategy_state=extra, output_scale=state.output_scale
)
return dt * error, state

# Normalise the error
u_proposed = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0]
reference = np.maximum(np.abs(u_proposed), np.abs(u_step_from))
error = _ErrorEstimate(dt * error, reference=reference)
return error, state

return _ProbabilisticSolver(
ssm=ssm,
Expand Down Expand Up @@ -960,18 +993,9 @@ def body_fn(state: _RejectionState) -> _RejectionState:
error_estimate, state_proposed = self.solver.step(
state=state.step_from, dt=dt
)
# Normalise the error
u_proposed = tree_util.ravel_pytree(
self.ssm.unravel(state_proposed.rv.mean)[0]
)[0]
u_step_from = tree_util.ravel_pytree(
self.ssm.unravel(state.step_from.rv.mean)[0]
)[0]

u = np.maximum(np.abs(u_proposed), np.abs(u_step_from))
error_power = _error_scale_and_normalize(error_estimate, u=u)

# Propose a new step
error_power = self._error_scale_and_normalize(error_estimate)
dt, state_control = self.control.apply(
dt, state.control, error_power=error_power
)
Expand All @@ -983,13 +1007,6 @@ def body_fn(state: _RejectionState) -> _RejectionState:
step_from=state.step_from,
)

def _error_scale_and_normalize(error_estimate, u):
error_relative = error_estimate / (self.atol + self.rtol * np.abs(u))
dim = np.atleast_1d(u).size
error_norm = linalg.vector_norm(error_relative, order=self.norm_ord)
error_norm_rel = error_norm / np.sqrt(dim)
return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate)

def extract(s: _RejectionState) -> _AdaState:
num_steps = state0.stats + 1.0 # TODO: track step attempts as well
return _AdaState(
Expand All @@ -1004,6 +1021,16 @@ def extract(s: _RejectionState) -> _AdaState:
state_new = control_flow.while_loop(cond_fn, body_fn, init_val)
return extract(state_new)

def _error_scale_and_normalize(self, error: _ErrorEstimate):
assert isinstance(error, _ErrorEstimate)
normalize = self.atol + self.rtol * np.abs(error.reference)
error_relative = error.estimate / normalize

dim = np.atleast_1d(error.reference).size
error_norm = linalg.vector_norm(error_relative, order=self.norm_ord)
error_norm_rel = error_norm / np.sqrt(dim)
return error_norm_rel ** (-1.0 / self.solver.error_contraction_rate)

def extract_before_t1(self, state: _AdaState, t):
del t
solution_solver = self.solver.extract(state.step_from)
Expand Down