diff --git a/docs/examples_advanced/solve_pde.py b/docs/examples_advanced/solve_pde.py new file mode 100644 index 000000000..c12419970 --- /dev/null +++ b/docs/examples_advanced/solve_pde.py @@ -0,0 +1,133 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# # Solve a PDE +# +# This tutorial replicates Figure 1 from https://arxiv.org/abs/2110.11812, +# but uses some advanced features in Probdiffeq, namely, solving matrix-valued problems +# and adaptive simulation with fixedpoint smoothing. + +# + +"""Solve a PDE.""" + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +from probdiffeq import ivpsolve, ivpsolvers, taylor + +jax.config.update("jax_enable_x64", True) + + +def main(): + """Simulate a PDE.""" + key = jax.random.PRNGKey(1) + f, (u0,), (t0, t1) = fhn_2d(key, num=40, t1=10.0) + + @jax.jit + def vf(y, *, t): # noqa: ARG001 + """Evaluate the dynamics of the PDE.""" + return f(y) + + print("Problem dimension:", u0.size) + + # Set up a state-space model + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1) + init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag") + + # Build a solver + ts = ivpsolvers.correction_ts1(vf, ssm=ssm) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver_dynamic( + ssm=ssm, strategy=strategy, prior=ibm, correction=ts + ) + adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm) + + # Solve the ODE + save_at = jnp.linspace(t0, t1, num=5, endpoint=True) + simulate = simulator(save_at=save_at, adaptive_solver=adaptive_solver, ssm=ssm) + (u, u_std) = simulate(init) + + fig, axes = plt.subplots( + nrows=2, ncols=len(u), figsize=(2 * len(u), 3), tight_layout=True + ) + for t_i, u_i, std_i, ax_i in zip(save_at, u, u_std, axes.T): + ax_i[0].set_title(f"t = {t_i:.1f}") + img = ax_i[0].imshow(u_i[0], cmap="copper", vmin=-1, vmax=1) + plt.colorbar(img) + + uncertainty = jnp.log10(jnp.abs(std_i[0]) + 1e-10) + img = ax_i[1].imshow(uncertainty, cmap="bone", vmin=-7, vmax=-3) + plt.colorbar(img) + + ax_i[0].set_xticks(()) + ax_i[1].set_xticks(()) + ax_i[0].set_yticks(()) + ax_i[1].set_yticks(()) + + axes[0][0].set_ylabel("PDE solution") + axes[1][0].set_ylabel("log(stdev)") + plt.show() + + +def simulator(save_at, adaptive_solver, ssm): + """Simulate a PDE.""" + + @jax.jit + def solve(init): + solution = ivpsolve.solve_adaptive_save_at( + init, save_at=save_at, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm + ) + return (solution.u[0], solution.u_std[0]) + + return solve + + +def fhn_2d(prng_key, *, num, t1, t0=0.0, a=2.8e-4, b=5e-3, k=-0.005, tau=1.0): + """Construct the FitzHugh-Nagumo PDE. + + Source: https://github.com/pnkraemer/tornadox/blob/main/tornadox/ivp.py + + But simplified since Probdiffeq can handle matrix-valued ODEs. + Here, we also set tau = 1.0 to make the example quick to execute. + """ + y0 = jax.random.uniform(prng_key, shape=(2, num, num)) + + @jax.jit + def fhn_2d(x): + u, v = x + du = _laplace_2d(u, dx=1.0 / num) + dv = _laplace_2d(v, dx=1.0 / num) + u_new = a * du + u - u**3 - v + k + v_new = (b * dv + u - v) / tau + return jnp.stack((u_new, v_new)) + + return fhn_2d, (y0,), (t0, t1) + + +def _laplace_2d(grid, dx): + """2D Laplace operator on a vectorized 2d grid.""" + # Set the boundary values to the nearest interior node + # This enforces Neumann conditions. + padded_grid = jnp.pad(grid, pad_width=1, mode="edge") + + # Laplacian via convolve2d() + kernel = jnp.array([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]]) + kernel /= dx**2 + grid = jax.scipy.signal.convolve2d(padded_grid, kernel, mode="same") + return grid[1:-1, 1:-1] + + +if __name__ == "__main__": + main() diff --git a/mkdocs.yml b/mkdocs.yml index 0e44fa1f1..4bc145324 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -89,6 +89,7 @@ nav: - examples_advanced/parameter_estimation_blackjax.ipynb - examples_advanced/neural_ode.ipynb - examples_advanced/equinox_while_loop.ipynb + - examples_advanced/solve_pde.ipynb - API DOCUMENTATION: - ivpsolve: api_docs/ivpsolve.md - ivpsolvers: api_docs/ivpsolvers.md diff --git a/probdiffeq/impl/_linearise.py b/probdiffeq/impl/_linearise.py index a6bd407c0..7a05f65d7 100644 --- a/probdiffeq/impl/_linearise.py +++ b/probdiffeq/impl/_linearise.py @@ -346,6 +346,7 @@ def init(): def step(fun, rv, state): del state + mean = rv.mean fx = tree_util.ravel_pytree(fun(*self.unravel(mean)[:ode_order]))[0] diff --git a/probdiffeq/impl/_normal.py b/probdiffeq/impl/_normal.py index b77125a66..b94aaac42 100644 --- a/probdiffeq/impl/_normal.py +++ b/probdiffeq/impl/_normal.py @@ -61,7 +61,8 @@ def from_tcoeffs(self, tcoeffs: list, damp: float = 0.0): c_sqrtm0_corrected = linalg.diagonal_matrix(damp**powers) leaves, _ = tree_util.tree_flatten(tcoeffs) - m0_corrected = np.stack(leaves) + leaves_flat = tree_util.tree_map(lambda s: tree_util.ravel_pytree(s)[0], leaves) + m0_corrected = np.stack(leaves_flat) return Normal(m0_corrected, c_sqrtm0_corrected) def preconditioner_apply(self, rv, p, /): @@ -83,7 +84,8 @@ def from_tcoeffs(self, tcoeffs: list, damp: float = 0.0): cholesky = np.ones((*self.ode_shape, 1, 1)) * cholesky[None, ...] leaves, _ = tree_util.tree_flatten(tcoeffs) - mean = np.stack(leaves).T + leaves_flat = tree_util.tree_map(lambda s: tree_util.ravel_pytree(s)[0], leaves) + mean = np.stack(leaves_flat).T return Normal(mean, cholesky) def preconditioner_apply(self, rv, p, /): diff --git a/probdiffeq/impl/impl.py b/probdiffeq/impl/impl.py index 98cfc2abe..238df628b 100644 --- a/probdiffeq/impl/impl.py +++ b/probdiffeq/impl/impl.py @@ -75,7 +75,13 @@ def _select_isotropic(*, tcoeffs_like) -> FactImpl: tcoeffs_tree_only = tree_util.tree_map(lambda *_a: 0.0, tcoeffs_like) _, unravel_tree = tree_util.ravel_pytree(tcoeffs_tree_only) - unravel = functools.vmap(unravel_tree, in_axes=1, out_axes=0) + + leaves, _ = tree_util.tree_flatten(tcoeffs_like) + _, unravel_leaf = tree_util.ravel_pytree(leaves[0]) + + def unravel(z): + tree = functools.vmap(unravel_tree, in_axes=1, out_axes=0)(z) + return tree_util.tree_map(unravel_leaf, tree) prototypes = _prototypes.IsotropicPrototype(ode_shape=ode_shape) normal = _normal.IsotropicNormal(ode_shape=ode_shape) @@ -102,7 +108,13 @@ def _select_blockdiag(*, tcoeffs_like) -> FactImpl: tcoeffs_tree_only = tree_util.tree_map(lambda *_a: 0.0, tcoeffs_like) _, unravel_tree = tree_util.ravel_pytree(tcoeffs_tree_only) - unravel = functools.vmap(unravel_tree) + + leaves, _ = tree_util.tree_flatten(tcoeffs_like) + _, unravel_leaf = tree_util.ravel_pytree(leaves[0]) + + def unravel(z): + tree = functools.vmap(unravel_tree, in_axes=0, out_axes=0)(z) + return tree_util.tree_map(unravel_leaf, tree) prototypes = _prototypes.BlockDiagPrototype(ode_shape=ode_shape) normal = _normal.BlockDiagNormal(ode_shape=ode_shape) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 701aa875a..9426d347f 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -449,20 +449,13 @@ class _Correction: ssm: Any linearize: Any vector_field: Callable + re_linearize: bool def init(self, x, /): """Initialise the state from the solution.""" jac = self.linearize.init() return x, jac - def correct(self, rv, correction_state, /, t): - """Perform the correction step.""" - f_wrapped = functools.partial(self.vector_field, t=t) - cond, correction_state = self.linearize.update(f_wrapped, rv, correction_state) - observed, reverted = self.ssm.conditional.revert(rv, cond) - corrected = reverted.noise - return corrected, observed, correction_state - def estimate_error(self, rv, correction_state, /, t): """Estimate the error.""" f_wrapped = functools.partial(self.vector_field, t=t) @@ -474,7 +467,21 @@ def estimate_error(self, rv, correction_state, /, t): stdev = self.ssm.stats.standard_deviation(observed) error_estimate_unscaled = np.squeeze(stdev) error_estimate = output_scale * error_estimate_unscaled - return error_estimate, observed, correction_state + return error_estimate, observed, (correction_state, cond) + + def correct(self, rv, correction_state, /, t): + """Perform the correction step.""" + linearization_state, cond = correction_state + + if self.re_linearize: + f_wrapped = functools.partial(self.vector_field, t=t) + cond, linearization_state = self.linearize.update( + f_wrapped, rv, linearization_state + ) + + observed, reverted = self.ssm.conditional.revert(rv, cond) + corrected = reverted.noise + return corrected, observed, linearization_state def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction: @@ -486,6 +493,7 @@ def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor ode_order=ode_order, ssm=ssm, linearize=linearize, + re_linearize=False, ) @@ -512,6 +520,7 @@ def correction_ts1( ode_order=ode_order, ssm=ssm, linearize=linearize, + re_linearize=False, ) @@ -526,6 +535,7 @@ def correction_slr0( ode_order=1, linearize=linearize, name="SLR0", + re_linearize=True, ) @@ -540,6 +550,7 @@ def correction_slr1( ode_order=1, linearize=linearize, name="SLR1", + re_linearize=True, )