diff --git a/docs/examples_basic/posterior_uncertainties.py b/docs/examples_basic/posterior_uncertainties.py index c1341bd69..3049f13e6 100644 --- a/docs/examples_basic/posterior_uncertainties.py +++ b/docs/examples_basic/posterior_uncertainties.py @@ -42,7 +42,7 @@ def vf(y, *, t): # noqa: ARG001 # Set up a solver # To all users: Try replacing the fixedpoint-smoother with a filter! tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3) -init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense") +init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag") ts = ivpsolvers.correction_ts1(vf, ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts, ssm=ssm) @@ -60,13 +60,20 @@ def vf(y, *, t): # noqa: ARG001 u_std = ssm.stats.qoi_from_sample(std) # Plot the solution -fig, axes = plt.subplots(nrows=2, ncols=len(tcoeffs), tight_layout=True, figsize=(8, 3)) +fig, axes = plt.subplots( + nrows=3, + ncols=len(tcoeffs), + sharex="col", + tight_layout=True, + figsize=(len(u_std) * 2, 5), +) for i, (u_i, std_i, ax_i) in enumerate(zip(sol.u, u_std, axes.T)): # Set up titles and axis descriptions if i == 0: ax_i[0].set_title("State") - ax_i[0].set_ylabel("Predators") - ax_i[1].set_ylabel("Prey") + ax_i[0].set_ylabel("Prey") + ax_i[1].set_ylabel("Predators") + ax_i[2].set_ylabel("Std.-dev.") elif i == 1: ax_i[0].set_title(f"{i}st deriv.") elif i == 2: @@ -76,7 +83,7 @@ def vf(y, *, t): # noqa: ARG001 else: ax_i[0].set_title(f"{i}th deriv.") - ax_i[1].set_xlabel("Time") + ax_i[-1].set_xlabel("Time") for m, std, ax in zip(u_i.T, std_i.T, ax_i): # Plot the mean @@ -87,5 +94,9 @@ def vf(y, *, t): # noqa: ARG001 ax.fill_between(sol.t, lower, upper, alpha=0.3) ax.set_xlim((jnp.amin(ts), jnp.amax(ts))) + ax_i[2].semilogy(sol.t, std_i[:, 0], label="Prey") + ax_i[2].semilogy(sol.t, std_i[:, 1], label="Predators") + ax_i[2].legend(fontsize="x-small") + fig.align_ylabels() plt.show() diff --git a/probdiffeq/backend/linalg.py b/probdiffeq/backend/linalg.py index f501eccde..805285d53 100644 --- a/probdiffeq/backend/linalg.py +++ b/probdiffeq/backend/linalg.py @@ -67,6 +67,10 @@ def diagonal(arr, /): return jnp.diagonal(arr) +def trace(arr, /): + return jnp.trace(arr) + + def diagonal_matrix(arr, /): return jnp.diag(arr) diff --git a/probdiffeq/backend/random.py b/probdiffeq/backend/random.py index 6cca5789a..eb6ce7300 100644 --- a/probdiffeq/backend/random.py +++ b/probdiffeq/backend/random.py @@ -7,5 +7,13 @@ def prng_key(*, seed): return jax.random.PRNGKey(seed=seed) +def split(key, num): + return jax.random.split(key, num=num) + + def normal(key, /, shape): return jax.random.normal(key, shape=shape) + + +def rademacher(key, /, shape, dtype): + return jax.random.rademacher(key, shape=shape, dtype=dtype) diff --git a/probdiffeq/impl/_linearise.py b/probdiffeq/impl/_linearise.py index 781d18132..a6bd407c0 100644 --- a/probdiffeq/impl/_linearise.py +++ b/probdiffeq/impl/_linearise.py @@ -1,29 +1,37 @@ -from probdiffeq.backend import abc, functools, tree_util +from probdiffeq.backend import abc, containers, functools, linalg, random, tree_util from probdiffeq.backend import numpy as np from probdiffeq.backend.typing import Callable from probdiffeq.impl import _conditional, _normal from probdiffeq.util import cholesky_util +@containers.dataclass +class _Linearization: + """Linearisation API.""" + + init: Callable + update: Callable + + class LinearisationBackend(abc.ABC): @abc.abstractmethod 
- def ode_taylor_0th(self, ode_order: int, damp: float) -> _normal.Normal: + def ode_taylor_0th(self, ode_order: int, damp: float) -> _Linearization: raise NotImplementedError @abc.abstractmethod - def ode_taylor_1st(self, ode_order: int, damp: float) -> _normal.Normal: + def ode_taylor_1st(self, ode_order: int, damp: float) -> _Linearization: raise NotImplementedError @abc.abstractmethod def ode_statistical_1st( self, cubature_fun: Callable, damp: float - ) -> _normal.Normal: + ) -> _Linearization: raise NotImplementedError @abc.abstractmethod def ode_statistical_0th( self, cubature_fun: Callable, damp: float - ) -> _normal.Normal: + ) -> _Linearization: raise NotImplementedError @@ -32,30 +40,45 @@ def __init__(self, ode_shape, unravel): self.ode_shape = ode_shape self.unravel = unravel - def ode_taylor_0th(self, ode_order, damp: float): - def linearise_fun_wrapped(fun, rv): - mean = rv.mean + def ode_taylor_0th(self, ode_order, damp: float) -> _Linearization: + def init(): + return None + + def step(fun, rv, state): + del state def a1(m): + """Select the 'n'-th derivative.""" return tree_util.ravel_pytree(self.unravel(m)[ode_order])[0] - fx = tree_util.ravel_pytree(fun(*self.unravel(mean)[:ode_order]))[0] - linop = functools.jacrev(a1)(mean) - + fx = tree_util.ravel_pytree(fun(*self.unravel(rv.mean)[:ode_order]))[0] + linop = functools.jacrev(a1)(rv.mean) cov_lower = damp * np.eye(len(fx)) bias = _normal.Normal(-fx, cov_lower) to_latent = np.ones(linop.shape[1]) to_observed = np.ones(linop.shape[0]) - return _conditional.LatentCond( + cond = _conditional.LatentCond( linop, bias, to_latent=to_latent, to_observed=to_observed ) + return cond, None + + return _Linearization(init, step) + + def ode_taylor_1st( + self, ode_order, damp, jvp_probes: int, jvp_probes_seed: int + ) -> _Linearization: + del jvp_probes + del jvp_probes_seed - return linearise_fun_wrapped + def init(): + return None - def ode_taylor_1st(self, ode_order, damp): - def new(fun, rv, /): + def step(fun, rv, state): + del state mean = rv.mean + # TODO: expose this function somehow. This way, we can + # implement custom information operators easily. def constraint(m): a1 = tree_util.ravel_pytree(self.unravel(m)[ode_order])[0] a0 = tree_util.ravel_pytree(fun(*self.unravel(m)[:ode_order]))[0] @@ -69,17 +92,25 @@ def constraint(m): bias = _normal.Normal(fx, cov_lower) to_latent = np.ones(linop.shape[1]) to_observed = np.ones(linop.shape[0]) - return _conditional.LatentCond( + cond = _conditional.LatentCond( linop, bias, to_latent=to_latent, to_observed=to_observed ) + return cond, None - return new + return _Linearization(init, step) - def ode_statistical_1st(self, cubature_fun, damp: float): + def ode_statistical_1st(self, cubature_fun, damp: float) -> _Linearization: cubature_rule = cubature_fun(input_shape=self.ode_shape) linearise_fun = functools.partial(self.slr1, cubature_rule=cubature_rule) - def new(fun, rv, /): + def init(): + return None + + def new(fun, rv, state): + del state + + # TODO: we can make this a lot more general (yet a little less efficient) + # if we mirror the TS1 implementation more closely. 
def select_0(s): return tree_util.ravel_pytree(self.unravel(s)[0]) @@ -114,17 +145,23 @@ def A(x): bias = _normal.Normal(-mean, cov_lower) to_latent = np.ones(linop.shape[1]) to_observed = np.ones(linop.shape[0]) - return _conditional.LatentCond( + cond = _conditional.LatentCond( linop, bias, to_latent=to_latent, to_observed=to_observed ) + return cond, None - return new + return _Linearization(init, new) - def ode_statistical_0th(self, cubature_fun, damp: float): + def ode_statistical_0th(self, cubature_fun, damp: float) -> _Linearization: cubature_rule = cubature_fun(input_shape=self.ode_shape) linearise_fun = functools.partial(self.slr0, cubature_rule=cubature_rule) - def new(fun, rv, /): + def init(): + return None + + def new(fun, rv, state): + del state + def select_0(s): return tree_util.ravel_pytree(self.unravel(s)[0]) @@ -156,11 +193,12 @@ def select_1(s): bias = _normal.Normal(-mean, cov_lower) to_latent = np.ones(linop.shape[1]) to_observed = np.ones(linop.shape[0]) - return _conditional.LatentCond( + cond = _conditional.LatentCond( linop, bias, to_latent=to_latent, to_observed=to_observed ) + return cond, None - return new + return _Linearization(init, new) @staticmethod def slr1(fn, x, *, cubature_rule): @@ -214,11 +252,63 @@ class IsotropicLinearisation(LinearisationBackend): def __init__(self, unravel): self.unravel = unravel - def ode_taylor_1st(self, ode_order, damp: float): - raise NotImplementedError + def ode_taylor_1st( + self, ode_order, damp: float, jvp_probes: int, jvp_probes_seed: int + ): + if ode_order > 1: + raise ValueError - def ode_taylor_0th(self, ode_order, damp: float): - def linearise_fun_wrapped(fun, rv): + def init(): + return random.prng_key(seed=jvp_probes_seed) + + def step(fun, rv, key): + mean = rv.mean + + def a1(m): + return m[[ode_order], ...] + + linop = functools.jacrev(a1)(mean[..., 0]) + + def vf_flat(u): + return tree_util.ravel_pytree(fun(unravel(u)))[0] + + def select_0(s): + return tree_util.ravel_pytree(self.unravel(s)[0]) + + # Evaluate the linearisation + m0, unravel = select_0(rv.mean) + fx, Jvp = functools.linearize(vf_flat, m0) + + # Estimate the trace using Hutchinson's estimator + # J_trace, jacobian_state = jacobian(Jvp, m0, jacobian_state) + key, subkey = random.split(key, num=2) + sample_shape = (jvp_probes, *m0.shape) + v = random.rademacher(subkey, shape=sample_shape, dtype=m0.dtype) + J_trace = functools.vmap(lambda s: linalg.vector_dot(s, Jvp(s)))(v) + J_trace = J_trace.mean(axis=0) + + # Turn fx and J_trace into an observation model + E0 = functools.jacrev(lambda s: s[[0], ...])(mean[..., 0]) + linop = linop - J_trace * E0 + fx = mean[1, ...] 
- fx + fx = fx - linop @ mean + cov_lower = damp * np.eye(1) + bias = _normal.Normal(fx, cov_lower) + to_latent = np.ones((linop.shape[1],)) + to_observed = np.ones((linop.shape[0],)) + cond = _conditional.LatentCond( + linop, bias, to_latent=to_latent, to_observed=to_observed + ) + return cond, key + + return _Linearization(init, step) + + def ode_taylor_0th(self, ode_order, damp: float) -> _Linearization: + def init(): + return None + + def step(fun, rv, state): + del state mean = rv.mean def a1(m): @@ -232,11 +322,12 @@ def a1(m): to_latent = np.ones((linop.shape[1],)) to_observed = np.ones((linop.shape[0],)) - return _conditional.LatentCond( + cond = _conditional.LatentCond( linop, bias, to_latent=to_latent, to_observed=to_observed ) + return cond, None - return linearise_fun_wrapped + return _Linearization(init, step) def ode_statistical_0th(self, cubature_fun, damp: float): raise NotImplementedError @@ -249,8 +340,12 @@ class BlockDiagLinearisation(LinearisationBackend): def __init__(self, unravel): self.unravel = unravel - def ode_taylor_0th(self, ode_order, damp: float): - def linearise_fun_wrapped(fun, rv): + def ode_taylor_0th(self, ode_order, damp: float) -> _Linearization: + def init(): + return None + + def step(fun, rv, state): + del state mean = rv.mean fx = tree_util.ravel_pytree(fun(*self.unravel(mean)[:ode_order]))[0] @@ -265,14 +360,65 @@ def a1(s): to_latent = np.ones((linop.shape[2],)) to_observed = np.ones((linop.shape[1],)) - return _conditional.LatentCond( + cond = _conditional.LatentCond( linop, bias, to_latent=to_latent, to_observed=to_observed ) + return cond, None - return linearise_fun_wrapped + return _Linearization(init, step) - def ode_taylor_1st(self, ode_order, damp: float): - raise NotImplementedError + def ode_taylor_1st( + self, ode_order, damp: float, jvp_probes: int, jvp_probes_seed: int + ): + if ode_order > 1: + raise ValueError + + def init(): + return random.prng_key(seed=jvp_probes_seed) + + def step(fun, rv, key): + mean = rv.mean + + def a1(s): + return s[[ode_order], ...] 
+ + linop = functools.vmap(functools.jacrev(a1))(mean) + + def vf_flat(u): + return tree_util.ravel_pytree(fun(unravel(u)))[0] + + def select_0(s): + return tree_util.ravel_pytree(self.unravel(s)[0]) + + # Evaluate the linearisation + m0, unravel = select_0(rv.mean) + fx, Jvp = functools.linearize(vf_flat, m0) + + key, subkey = random.split(key, num=2) + sample_shape = (jvp_probes, *m0.shape) + v = random.rademacher(subkey, shape=sample_shape, dtype=m0.dtype) + J_diag = functools.vmap(lambda s: s * Jvp(s))(v) + J_diag = J_diag.mean(axis=0) + E1 = functools.jacrev(lambda s: s[0])(rv.mean[0]) + linop = linop - J_diag[:, None, None] * E1[None, None, :] + + fx = rv.mean[:, 1] - fx + fx = fx[..., None] + diff = functools.vmap(lambda a, b: a @ b)(linop, rv.mean) + fx = fx - diff + + d, *_ = linop.shape + cov_lower = damp * np.ones((d, 1, 1)) + bias = _normal.Normal(fx, cov_lower) + + to_latent = np.ones((linop.shape[2],)) + to_observed = np.ones((linop.shape[1],)) + cond = _conditional.LatentCond( + linop, bias, to_latent=to_latent, to_observed=to_observed + ) + return cond, key + + return _Linearization(init, step) def ode_statistical_0th(self, cubature_fun, damp: float): raise NotImplementedError diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 56aae40b6..701aa875a 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -28,6 +28,8 @@ def prior_wiener_integrated( ssm = impl.choose(ssm_fact, tcoeffs_like=tcoeffs) # TODO: should the output_scale be an argument to solve()? + # TODO: should the output scale (and all 'damp'-like factors) + # mirror the pytree structure of 'tcoeffs'? if output_scale is None: output_scale = np.ones_like(ssm.prototypes.output_scale()) @@ -445,26 +447,26 @@ class _Correction: name: str ode_order: int ssm: Any - linearize: Callable + linearize: Any vector_field: Callable def init(self, x, /): """Initialise the state from the solution.""" - y = self.ssm.prototypes.observed() - return x, y + jac = self.linearize.init() + return x, jac - def correct(self, rv, /, t): + def correct(self, rv, correction_state, /, t): """Perform the correction step.""" f_wrapped = functools.partial(self.vector_field, t=t) - cond = self.linearize(f_wrapped, rv) + cond, correction_state = self.linearize.update(f_wrapped, rv, correction_state) observed, reverted = self.ssm.conditional.revert(rv, cond) corrected = reverted.noise - return corrected, observed + return corrected, observed, correction_state - def estimate_error(self, rv, /, t): + def estimate_error(self, rv, correction_state, /, t): """Estimate the error.""" f_wrapped = functools.partial(self.vector_field, t=t) - cond = self.linearize(f_wrapped, rv) + cond, correction_state = self.linearize.update(f_wrapped, rv, correction_state) observed = self.ssm.conditional.marginalise(rv, cond) zero_data = np.zeros(()) @@ -472,7 +474,7 @@ def estimate_error(self, rv, /, t): stdev = self.ssm.stats.standard_deviation(observed) error_estimate_unscaled = np.squeeze(stdev) error_estimate = output_scale * error_estimate_unscaled - return error_estimate, observed + return error_estimate, observed, correction_state def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction: @@ -487,9 +489,23 @@ def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor ) -def correction_ts1(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction: +def correction_ts1( + vector_field, + *, + ssm, + ode_order=1, + damp: float = 0.0, + jvp_probes=10, + jvp_probes_seed=1, +) 
-> _Correction: """First-order Taylor linearisation.""" - linearize = ssm.linearise.ode_taylor_1st(ode_order=ode_order, damp=damp) + assert jvp_probes > 0 + linearize = ssm.linearise.ode_taylor_1st( + ode_order=ode_order, + damp=damp, + jvp_probes=jvp_probes, + jvp_probes_seed=jvp_probes_seed, + ) return _Correction( name="TS1", vector_field=vector_field, @@ -542,6 +558,7 @@ class _State(containers.NamedTuple): t: Any rv: Any strategy_state: Any + correction_state: Any output_scale: Any @@ -611,7 +628,13 @@ def init(self, t, init) -> _State: # TODO: make the init() and extract() an interface. # Then, lots of calibration logic simplifies considerably. calib_state = self.calibration.init() - return _State(t=t, rv=rv, strategy_state=extra, output_scale=calib_state) + return _State( + t=t, + rv=rv, + strategy_state=extra, + correction_state=corr, + output_scale=calib_state, + ) def step(self, state: _State, *, dt): return self.step_implementation(state, dt=dt, calibration=self.calibration) @@ -638,12 +661,30 @@ def interpolate(self, *, t, interp_from: _State, interp_to: _State) -> _InterpRe # Turn outputs into valid states - def _state(t_, x, scale): - return _State(t=t_, rv=x[0], strategy_state=x[1], output_scale=scale) + def _state(t_, x, scale, cs): + return _State( + t=t_, + rv=x[0], + strategy_state=x[1], + correction_state=cs, + output_scale=scale, + ) - step_from = _state(interp_to.t, interp.step_from, interp_to.output_scale) - interpolated = _state(t, interp.interpolated, interp_to.output_scale) - interp_from = _state(t, interp.interp_from, interp_from.output_scale) + step_from = _state( + interp_to.t, + interp.step_from, + interp_to.output_scale, + interp_to.correction_state, + ) + interpolated = _state( + t, interp.interpolated, interp_to.output_scale, interp_to.correction_state + ) + interp_from = _state( + t, + interp.interp_from, + interp_from.output_scale, + interp_from.correction_state, + ) return _InterpRes( step_from=step_from, interpolated=interpolated, interp_from=interp_from ) @@ -667,13 +708,21 @@ def interpolate_at_t1( tmp.interp_from, ) - def _state(t_, s, scale): - return _State(t=t_, rv=s[0], strategy_state=s[1], output_scale=scale) + def _state(t_, x, scale, cs): + return _State( + t=t_, + rv=x[0], + strategy_state=x[1], + correction_state=cs, + output_scale=scale, + ) t = interp_to.t - prev = _state(t, interp_from_, interp_from.output_scale) - sol = _state(t, solution_, interp_to.output_scale) - acc = _state(t, step_from_, interp_to.output_scale) + prev = _state( + t, interp_from_, interp_from.output_scale, interp_from.correction_state + ) + sol = _state(t, solution_, interp_to.output_scale, interp_to.correction_state) + acc = _state(t, step_from_, interp_to.output_scale, interp_to.correction_state) return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) @@ -693,7 +742,9 @@ def step_mle(state, /, *, dt, calibration): mean = ssm.stats.mean(state.rv) mean_extra = ssm.conditional.apply(mean, transition) t = state.t + dt - error, _ = correction.estimate_error(mean_extra, t=t) + error, _, correction_state = correction.estimate_error( + mean_extra, state.correction_state, t=t + ) # Do the full prediction step (reuse previous discretisation) hidden, extra = strategy.extrapolate( @@ -701,13 +752,20 @@ def step_mle(state, /, *, dt, calibration): ) # Do the full correction step - hidden, observed = correction.correct(hidden, t=t) + hidden, observed, corr_state = correction.correct(hidden, correction_state, t=t) # Calibrate the output scale output_scale = 
calibration.update(state.output_scale, observed=observed) # Normalise the error - state = _State(t=t, rv=hidden, strategy_state=extra, output_scale=output_scale) + + state = _State( + t=t, + rv=hidden, + strategy_state=extra, + correction_state=corr_state, + output_scale=output_scale, + ) u_proposed = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0] reference = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) error = _ErrorEstimate(dt * error, reference=reference) @@ -756,7 +814,9 @@ def step_dynamic(state, /, *, dt, calibration): hidden = ssm.conditional.apply(mean, transition) t = state.t + dt - error, observed = correction.estimate_error(hidden, t=t) + error, observed, correction_state = correction.estimate_error( + hidden, state.correction_state, t=t + ) output_scale = calibration.update(state.output_scale, observed=observed) # Do the full extrapolation with the calibrated output scale @@ -767,10 +827,16 @@ def step_dynamic(state, /, *, dt, calibration): ) # Do the full correction step - hidden, corr = correction.correct(hidden, t=t) + hidden, _, correction_state = correction.correct(hidden, correction_state, t=t) # Return solution - state = _State(t=t, rv=hidden, strategy_state=extra, output_scale=output_scale) + state = _State( + t=t, + rv=hidden, + strategy_state=extra, + correction_state=correction_state, + output_scale=output_scale, + ) # Normalise the error u_proposed = tree_util.ravel_pytree(ssm.unravel(state.rv.mean)[0])[0] @@ -815,7 +881,9 @@ def step(state: _State, *, dt, calibration): mean = ssm.stats.mean(state.rv) hidden = ssm.conditional.apply(mean, transition) t = state.t + dt - error, _ = correction.estimate_error(hidden, t=t) + error, _, correction_state = correction.estimate_error( + hidden, state.correction_state, t=t + ) # Do the full extrapolation step (reuse the transition) hidden, extra = strategy.extrapolate( @@ -823,9 +891,13 @@ def step(state: _State, *, dt, calibration): ) # Do the full correction step - hidden, corr = correction.correct(hidden, t=t) + hidden, _, correction_state = correction.correct(hidden, correction_state, t=t) state = _State( - t=t, rv=hidden, strategy_state=extra, output_scale=state.output_scale + t=t, + rv=hidden, + strategy_state=extra, + correction_state=correction_state, + output_scale=state.output_scale, ) # Normalise the error @@ -960,15 +1032,15 @@ class _RejectionState(containers.NamedTuple): step_from: Any def init(s0: _AdaState) -> _RejectionState: - def _inf_like(tree): - return tree_util.tree_map(lambda x: np.inf() * np.ones_like(x), tree) + def _ones_like(tree): + return tree_util.tree_map(np.ones_like, tree) smaller_than_1 = 0.9 # the cond() must return True return _RejectionState( error_norm_proposed=smaller_than_1, dt=s0.dt, control=s0.control, - proposed=_inf_like(s0.step_from), + proposed=_ones_like(s0.step_from), # irrelevant step_from=s0.step_from, )
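
The isotropic and block-diagonal TS1 linearisations above replace the previous NotImplementedError with a randomised estimate of the Jacobian trace (isotropic) or diagonal (block-diagonal), built from `jvp_probes` Rademacher probe vectors. The following is a minimal, self-contained sketch of that estimator written directly against JAX rather than probdiffeq's backend wrappers; the vector field and its parameters are illustrative stand-ins for the Lotka-Volterra example in the docs, not values taken from the patch.

import jax
import jax.numpy as jnp


def vf(y):
    # Illustrative Lotka-Volterra-style right-hand side (parameters made up here).
    return jnp.asarray(
        [0.5 * y[0] - 0.05 * y[0] * y[1], -0.5 * y[1] + 0.05 * y[0] * y[1]]
    )


def hutchinson(fn, x, *, num_probes, key):
    """Estimate trace and diagonal of the Jacobian of `fn` at `x` using only JVPs."""
    _fx, jvp = jax.linearize(fn, x)  # jvp(v) evaluates J(x) @ v
    probes = jax.random.rademacher(key, shape=(num_probes, *x.shape), dtype=x.dtype)
    # For Rademacher probes v: E[v^T (J v)] = trace(J) and E[v * (J v)] = diag(J).
    trace = jax.vmap(lambda v: jnp.dot(v, jvp(v)))(probes).mean(axis=0)
    diag = jax.vmap(lambda v: v * jvp(v))(probes).mean(axis=0)
    return trace, diag


y0 = jnp.asarray([20.0, 20.0])
trace, diag = hutchinson(vf, y0, num_probes=10, key=jax.random.PRNGKey(1))

# Reference values; for small systems the exact Jacobian is cheap enough to compare.
J = jax.jacfwd(vf)(y0)
print(trace, jnp.trace(J))
print(diag, jnp.diagonal(J))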
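
A hedged usage sketch for the extended `correction_ts1` signature, reusing the variable names from the updated docs example above (`tcoeffs`, `vf`, and friends); the keyword values shown are the defaults introduced by this patch. The probe key is seeded by `jvp_probes_seed`, stored in the new `correction_state`, and split once per solver step, so consecutive steps draw fresh probes.

init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag")
ts = ivpsolvers.correction_ts1(vf, ssm=ssm, jvp_probes=10, jvp_probes_seed=1)
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts, ssm=ssm)

Increasing `jvp_probes` lowers the variance of the trace/diagonal estimate at the cost of one additional Jacobian-vector product per probe and step.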