diff --git a/docs/benchmarks/hires/plot.md b/docs/benchmarks/hires/plot.md index c71afb8f6..7149d0c72 100644 --- a/docs/benchmarks/hires/plot.md +++ b/docs/benchmarks/hires/plot.md @@ -24,7 +24,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import jax -from probdiffeq.util.doc_util import notebook jax.config.update("jax_platform_name", "cpu") ``` @@ -91,11 +90,6 @@ def plot_solution(axis, ts, ys, yscale="linear"): return axis ``` -```python -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) -``` - ```python layout = [ ["benchmark", "benchmark", "solution"], diff --git a/docs/benchmarks/lotkavolterra/plot.md b/docs/benchmarks/lotkavolterra/plot.md index 2f4121748..a2b4d1bb2 100644 --- a/docs/benchmarks/lotkavolterra/plot.md +++ b/docs/benchmarks/lotkavolterra/plot.md @@ -24,8 +24,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import jax -from probdiffeq.util.doc_util import notebook - jax.config.update("jax_platform_name", "cpu") ``` @@ -95,11 +93,6 @@ def plot_solution(axis, ts, ys, yscale="linear"): return axis ``` -```python -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) -``` - ```python layout = [ ["benchmark", "benchmark", "solution"], diff --git a/docs/benchmarks/pleiades/plot.md b/docs/benchmarks/pleiades/plot.md index 09b8462d9..e63204db2 100644 --- a/docs/benchmarks/pleiades/plot.md +++ b/docs/benchmarks/pleiades/plot.md @@ -24,8 +24,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import jax -from probdiffeq.util.doc_util import notebook - jax.config.update("jax_platform_name", "cpu") ``` @@ -101,10 +99,6 @@ def plot_solution(axis, ys, yscale="linear"): return axis ``` -```python -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) -``` ```python layout = [ diff --git a/docs/benchmarks/taylor_fitzhughnagumo/plot.md b/docs/benchmarks/taylor_fitzhughnagumo/plot.md index bed198cac..761ca492a 100644 --- a/docs/benchmarks/taylor_fitzhughnagumo/plot.md +++ b/docs/benchmarks/taylor_fitzhughnagumo/plot.md @@ -24,8 +24,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import jax -from probdiffeq.util.doc_util import notebook - jax.config.update("jax_platform_name", "cpu") ``` @@ -88,11 +86,6 @@ def _adaptive_repeat(xs, ys): return jnp.asarray(zs) ``` -```python -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) -``` - ```python fig, (axis_perform, axis_compile) = plt.subplots( ncols=2, figsize=(8, 3), dpi=150, sharex=True diff --git a/docs/benchmarks/taylor_node/plot.md b/docs/benchmarks/taylor_node/plot.md index 37737cf92..01e6d9b93 100644 --- a/docs/benchmarks/taylor_node/plot.md +++ b/docs/benchmarks/taylor_node/plot.md @@ -22,8 +22,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import jax -from probdiffeq.util.doc_util import notebook - jax.config.update("jax_platform_name", "cpu") ``` @@ -85,10 +83,6 @@ def _adaptive_repeat(xs, ys): return jnp.asarray(zs) ``` -```python -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) -``` ```python fig, (axis_perform, axis_compile) = plt.subplots( diff --git a/docs/benchmarks/taylor_pleiades/plot.md b/docs/benchmarks/taylor_pleiades/plot.md index 4f06e3f2c..95bd56fab 100644 --- a/docs/benchmarks/taylor_pleiades/plot.md +++ b/docs/benchmarks/taylor_pleiades/plot.md @@ -24,7 +24,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import jax -from probdiffeq.util.doc_util import notebook jax.config.update("jax_platform_name", "cpu") ``` @@ -66,11 +65,6 @@ def plot_results(axis_compile, axis_perform, results): return axis_compile, axis_perform ``` -```python -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) -``` - ```python fig, (axis_perform, axis_compile) = plt.subplots( ncols=2, dpi=150, sharex=True, figsize=(8, 3) diff --git a/docs/benchmarks/vanderpol/plot.md b/docs/benchmarks/vanderpol/plot.md index c7300f3aa..fd5f555ec 100644 --- a/docs/benchmarks/vanderpol/plot.md +++ b/docs/benchmarks/vanderpol/plot.md @@ -24,8 +24,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import jax -from probdiffeq.util.doc_util import notebook - jax.config.update("jax_platform_name", "cpu") ``` @@ -114,10 +112,6 @@ def plot_solution(axis, ts, ys, yscale="linear"): return axis ``` -```python -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) -``` ```python layout = [ diff --git a/docs/getting_started/choosing_a_solver.md b/docs/choosing_a_solver.md similarity index 100% rename from docs/getting_started/choosing_a_solver.md rename to docs/choosing_a_solver.md diff --git a/docs/examples_advanced/use_equinox_bounded_while_loop.py b/docs/examples_advanced/equinox_while_loop.py similarity index 100% rename from docs/examples_advanced/use_equinox_bounded_while_loop.py rename to docs/examples_advanced/equinox_while_loop.py diff --git a/docs/examples_advanced/neural_ode.py b/docs/examples_advanced/neural_ode.py index bf4ce3fb6..0cf5480b1 100644 --- a/docs/examples_advanced/neural_ode.py +++ b/docs/examples_advanced/neural_ode.py @@ -12,198 +12,188 @@ # name: python3 # --- -# # Neural ODEs +# # Diffusion tempering & NODEs # -# We can use the parameter estimation functionality -# to fit a neural ODE to a time series data set. # + -"""Train a neural ODE with ProbDiffEq and Optax.""" +"""Train a neural ODE with ProbDiffEq and Optax using diffusion tempering.""" import jax import jax.numpy as jnp import matplotlib.pyplot as plt import optax -from diffeqzoo import backend, ivps from probdiffeq import ivpsolve, ivpsolvers, stats -from probdiffeq.util.doc_util import notebook -# - -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) +def main(num_data=100, epochs=1_000, print_every=100, hidden=(20,), lr=0.2): + """Train a neural ODE using diffusion tempering.""" + # Create some data and construct a neural ODE + grid = jnp.linspace(0, 1, num=num_data) + data = jnp.sin(2.5 * jnp.pi * grid) * jnp.pi * grid + stdev = 1e-1 + output_scale = 1e2 + vf, u0, (t0, t1), f_args = vf_neural_ode(hidden=hidden, t0=0.0, t1=1) -# + -if not backend.has_been_selected: - backend.select("jax") # ivp examples in jax + # Create a loss (this is where probabilistic numerics enters!) + loss = loss_log_marginal_likelihood(vf=vf, t0=t0) + loss0, info0 = loss( + f_args, u0=u0, grid=grid, data=data, stdev=stdev, output_scale=output_scale + ) + + # Plot the data and the initial guess + plt.title(f"Initial estimate | Loss: {loss0:.2f}") + plt.plot(grid, data, "x", label="Data", color="C0") + plt.plot(grid, info0["sol"].u[0], "-", label="Estimate", color="C1") + plt.legend() + plt.show() + + # Construct an optimiser + optim = optax.adam(lr) + train_step = train_step_optax(optim, loss=loss) + + # Train the model + state = optim.init(f_args) + print("Loss after...") + for i in range(epochs): + (f_args, state), info = train_step( + f_args, + state, + u0=u0, + grid=grid, + data=data, + stdev=stdev, + output_scale=output_scale, + ) -# Catch NaN gradients in CI -# Disable to improve speed -jax.config.update("jax_debug_nans", True) + # Print progressbar + if i % print_every == print_every - 1: + print(f"...{(i + 1)} epochs: loss={info['loss']:.3e}") -jax.config.update("jax_platform_name", "cpu") -# - + # Diffusion tempering: https://arxiv.org/abs/2402.12231 + # To all users: Adjust this tempering and + # see how it affects parameter estimation. + if i % 100 == 0: + output_scale /= 10.0 + # Plot the results + plt.title(f"Final estimate | Loss: {info['loss']:.2f}") + plt.plot(grid, data, "x", label="Data", color="C0") + plt.plot(grid, info0["sol"].u[0], "-", label="Initial estimate", color="C1") + plt.plot(grid, info["sol"].u[0], "-", label="Final estimate", color="C2") + plt.legend() + plt.show() -# To keep the problem nice and small, assume that the data set is a -# trigonometric function (which solve differential equations). -# + -grid = jnp.linspace(0, 1, num=100) -data = jnp.sin(5 * jnp.pi * grid) +def vf_neural_ode(*, hidden: tuple, t0: float, t1: float): + """Build a neural ODE.""" + f_args, mlp = model_mlp(hidden=hidden, shape_in=(2,), shape_out=(1,)) + u0 = jnp.asarray([0.0]) + + @jax.jit + def vf(y, *, t, p): + """Evaluate the neural ODE vector field.""" + y_and_t = jnp.concatenate([y, t[None]]) + return mlp(p, y_and_t) + + return vf, (u0,), (t0, t1), f_args + + +def model_mlp( + *, hidden: tuple, shape_in: tuple = (), shape_out: tuple = (), activation=jnp.tanh +): + """Construct an MLP.""" + assert len(shape_in) <= 1 + assert len(shape_out) <= 1 + + shape_prev = shape_in + weights = [] + for h in hidden: + W = jnp.empty((h, *shape_prev)) + b = jnp.empty((h,)) + shape_prev = (h,) + weights.append((W, b)) -plt.plot(grid, data, ".-", label="Data") -plt.legend() -plt.show() + W = jnp.empty((*shape_out, *shape_prev)) + b = jnp.empty(shape_out) + weights.append((W, b)) + p_flat, unravel = jax.flatten_util.ravel_pytree(weights) -# - + def fwd(w, x): + for A, b in w[:-1]: + x = jnp.dot(A, x) + b + x = activation(x) + A, b = w[-1] + return jnp.dot(A, x) + b -def build_loss_fn(vf, initial_values, solver, *, standard_deviation=1e-2): + key = jax.random.PRNGKey(1) + p_init = jax.random.normal(key, shape=p_flat.shape, dtype=p_flat.dtype) + return unravel(p_init), fwd + + +def loss_log_marginal_likelihood(vf, *, t0): """Build a loss function from an ODE problem.""" @jax.jit - def loss_fn(parameters): + def loss( + p: jax.Array, + *, + u0: tuple, + grid: jax.Array, + data: jax.Array, + stdev: jax.Array, + output_scale: jax.Array, + ): """Loss function: log-marginal likelihood of the data.""" - tcoeffs = (*initial_values, vf(*initial_values, t=t0, p=parameters)) - ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic") + # Build a solver + tcoeffs = (*u0, vf(*u0, t=t0, p=p)) + ibm, ssm = ivpsolvers.prior_ibm( + tcoeffs, output_scale=output_scale, ssm_fact="isotropic" + ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ssm=ssm) solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) - init = solver_ts0.initial_condition() + # Solve + init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( - lambda *a, **kw: vf(*a, **kw, p=parameters), + lambda *a, **kw: vf(*a, **kw, p=p), init, grid=grid, - solver=solver, + solver=solver_ts0, ssm=ssm, ) - observation_std = jnp.ones_like(grid) * standard_deviation + # Evaluate loss marginal_likelihood = stats.log_marginal_likelihood( data[:, None], - standard_deviation=observation_std, + standard_deviation=jnp.ones_like(grid) * stdev, posterior=sol.posterior, ssm=sol.ssm, ) - return -1 * marginal_likelihood + return -1 * marginal_likelihood, {"sol": sol} - return loss_fn + return loss -def build_update_fn(*, optimizer, loss_fn): - """Build a function for executing a single step in the optimization.""" +def train_step_optax(optimizer, loss): + """Implement a training step using Optax.""" @jax.jit - def update(params, opt_state): + def update(params, opt_state, **loss_kwargs): """Update the optimiser state.""" - _loss, grads = jax.value_and_grad(loss_fn)(params) + value_and_grad = jax.value_and_grad(loss, argnums=0, has_aux=True) + (value, info), grads = value_and_grad(params, **loss_kwargs) updates, opt_state = optimizer.update(grads, opt_state) params = optax.apply_updates(params, updates) - return params, opt_state - return update - - -# ## Construct an MLP with tanh activation -# -# Let's start with the example given in the -# [implicit layers tutorial](http://implicit-layers-tutorial.org/neural_odes/). -# The vector field is provided by [DiffEqZoo](https://diffeqzoo.readthedocs.io/). - -# + -f, u0, (t0, t1), f_args = ivps.neural_ode_mlp(layer_sizes=(2, 20, 1)) - - -@jax.jit -def vf(y, *, t, p): - """Evaluate the MLP.""" - return f(y, t, *p) - - -# Make a solver -tcoeffs = (u0, vf(u0, t=t0, p=f_args)) -ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") -ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) -init = solver_ts0.initial_condition() - -# + -sol = ivpsolve.solve_fixed_grid( - lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0, ssm=ssm -) - -plt.plot(sol.t, sol.u[0], ".-", label="Initial estimate") -plt.plot(grid, data, ".-", label="Data") -plt.legend() -plt.show() -# - - -# ## Set up a loss function and an optimiser -# -# Like in the other tutorials, we use [Optax](https://optax.readthedocs.io/en/latest/index.html). - -loss_fn = build_loss_fn(vf=vf, initial_values=(u0,), solver=solver_ts0) -optim = optax.adam(learning_rate=2e-2) -update_fn = build_update_fn(optimizer=optim, loss_fn=loss_fn) - -p = f_args -state = optim.init(p) -chunk_size = 25 -for i in range(chunk_size): - for _ in range(chunk_size**2): - p, state = update_fn(p, state) - print( - "Negative log-marginal-likelihood after " - f"{(i + 1) * chunk_size**2}/{chunk_size**3} steps:", - loss_fn(p), - ) - -# + -plt.plot(sol.t, data, "-", linewidth=5, alpha=0.5, label="Data") -tcoeffs = (u0, vf(u0, t=t0, p=f_args)) -ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") -ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) -init = solver_ts0.initial_condition() - -sol = ivpsolve.solve_fixed_grid( - lambda *a, **kw: vf(*a, **kw, p=p), init, grid=grid, solver=solver_ts0, ssm=ssm -) - - -plt.plot(sol.t, sol.u[0], ".-", label="Final guess") - -tcoeffs = (u0, vf(u0, t=t0, p=f_args)) -ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") -ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) -init = solver_ts0.initial_condition() - -sol = ivpsolve.solve_fixed_grid( - lambda *a, **kw: vf(*a, **kw, p=f_args), init, grid=grid, solver=solver_ts0, ssm=ssm -) -plt.plot(sol.t, sol.u[0], ".-", label="Initial guess") + info["loss"] = value + return (params, opt_state), info + return update -plt.legend() -plt.show() -# - -# ## What's next -# -# -# The same example can be constructed with deep learning libraries -# such as [Equinox](https://docs.kidger.site/equinox/), -# [Haiku](https://dm-haiku.readthedocs.io/en/latest/), or -# [Flax](https://flax.readthedocs.io/en/latest/getting_started.html). -# To do so, define a corresponding vector field and a parameter set, -# build a new loss function and repeat. -# -# +if __name__ == "__main__": + main() diff --git a/docs/examples_advanced/physics_enhanced_regression_2.py b/docs/examples_advanced/parameter_estimation_blackjax.py similarity index 98% rename from docs/examples_advanced/physics_enhanced_regression_2.py rename to docs/examples_advanced/parameter_estimation_blackjax.py index 500da49d7..52203aa82 100644 --- a/docs/examples_advanced/physics_enhanced_regression_2.py +++ b/docs/examples_advanced/parameter_estimation_blackjax.py @@ -133,7 +133,6 @@ from diffeqzoo import backend, ivps from probdiffeq import ivpsolve, ivpsolvers, stats, taylor -from probdiffeq.util.doc_util import notebook # + # x64 precision @@ -146,9 +145,6 @@ if not backend.has_been_selected: backend.select("jax") -# Nice-looking plots -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) # - diff --git a/docs/examples_advanced/physics_enhanced_regression_1.py b/docs/examples_advanced/parameter_estimation_optax.py similarity index 90% rename from docs/examples_advanced/physics_enhanced_regression_1.py rename to docs/examples_advanced/parameter_estimation_optax.py index c30bf36c9..13a334f52 100644 --- a/docs/examples_advanced/physics_enhanced_regression_1.py +++ b/docs/examples_advanced/parameter_estimation_optax.py @@ -14,16 +14,13 @@ # # Parameter estimation (Optax) # -# **Time-series data and optimization with ``optax``** -# -# We create some fake-observational data, -# compute the marginal likelihood of this fake data _under the ODE posterior_ +# We create some data, +# compute the marginal likelihood of this data _under the ODE posterior_ # (which is something you cannot do with non-probabilistic solvers!), # and optimize the parameters with `optax`. # +# Link to paper: https://arxiv.org/abs/2202.01287 # -# Tronarp, Bosch, and Hennig call this "physics-enhanced regression" -# ([link to paper](https://arxiv.org/abs/2202.01287)). # + """Estimate ODE parameters with ProbDiffEq and Optax.""" @@ -35,12 +32,6 @@ from diffeqzoo import backend, ivps from probdiffeq import ivpsolve, ivpsolvers, stats -from probdiffeq.util.doc_util import notebook - -# - - -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) # + if not backend.has_been_selected: diff --git a/docs/examples_basic/conditioning-on-zero-residual.py b/docs/examples_basic/conditioning-on-zero-residual.py index 0dbe57cdf..c639731e1 100644 --- a/docs/examples_basic/conditioning-on-zero-residual.py +++ b/docs/examples_basic/conditioning-on-zero-residual.py @@ -27,12 +27,6 @@ from diffeqzoo import backend from probdiffeq import ivpsolve, ivpsolvers, stats, taylor -from probdiffeq.util.doc_util import notebook - -# - - -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) # + if not backend.has_been_selected: diff --git a/docs/examples_basic/dynamic_output_scales.py b/docs/examples_basic/dynamic_output_scales.py index 4724bd7f6..cb3abf356 100644 --- a/docs/examples_basic/dynamic_output_scales.py +++ b/docs/examples_basic/dynamic_output_scales.py @@ -40,12 +40,6 @@ from diffeqzoo import backend, ivps from probdiffeq import ivpsolve, ivpsolvers -from probdiffeq.util.doc_util import notebook - -# - - -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) # + if not backend.has_been_selected: diff --git a/docs/examples_basic/posterior_uncertainties.py b/docs/examples_basic/posterior_uncertainties.py index 2d2345f26..96c594756 100644 --- a/docs/examples_basic/posterior_uncertainties.py +++ b/docs/examples_basic/posterior_uncertainties.py @@ -17,167 +17,76 @@ # + """Display the marginal uncertainties of filters and smoothers.""" -import jax import jax.numpy as jnp import matplotlib.pyplot as plt -from diffeqzoo import backend, ivps from probdiffeq import ivpsolve, ivpsolvers, stats, taylor -from probdiffeq.util.doc_util import notebook -# - +# Set up the ODE -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) -# + -if not backend.has_been_selected: - backend.select("jax") # ivp examples in jax - -jax.config.update("jax_enable_x64", True) -jax.config.update("jax_platform_name", "cpu") -# - - -# Set an example problem. -# -# Solve the problem on a low resolution and -# short time-span to achieve large uncertainty. - -# + -f, u0, (t0, t1), f_args = ivps.lotka_volterra() - - -@jax.jit -def vf(*ys, t): # noqa: ARG001 +def vf(y, *, t): # noqa: ARG001 """Evaluate the Lotka-Volterra vector field.""" - return f(*ys, *f_args) - + y0, y1 = y[0], y[1] -# - + y0_new = 0.5 * y0 - 0.05 * y0 * y1 + y1_new = -0.5 * y1 + 0.05 * y0 * y1 + return jnp.asarray([y0_new, y1_new]) -# ## Filter -# + -tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) -ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") -ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_filter(ssm=ssm) -solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) -adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) +t0 = 0.0 +t1 = 2.0 +u0 = jnp.asarray([20.0, 20.0]) -ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) -# + -dt0 = ivpsolve.dt0(lambda y: vf(y, t=t0), (u0,)) +# Set up a solver +# To all users: Try replacing the fixedpoint-smoother with a filter! +tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=3) +ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") +ts = ivpsolvers.correction_ts1(ssm=ssm) +strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) +solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts, ssm=ssm) +adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-1, rtol=1e-1, ssm=ssm) +# Solve the ODE +ts = jnp.linspace(t0, t1, endpoint=True, num=50) init = solver.initial_condition() sol = ivpsolve.solve_adaptive_save_at( - vf, init, save_at=ts, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm + vf, init, save_at=ts, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm ) +# Calibrate marginals = stats.calibrate(sol.marginals, output_scale=sol.output_scale, ssm=ssm) -# - +std = ssm.stats.standard_deviation(marginals) +u_std = ssm.stats.qoi_from_sample(std) # Plot the solution - -# + -_, num_derivatives, _ = marginals.mean.shape - - -fig, axes_all = plt.subplots( - nrows=2, ncols=num_derivatives, sharex=True, tight_layout=True, figsize=(8, 3) -) - -for i, axes_cols in enumerate(axes_all.T): - ms = marginals.mean[:, i, :] - ls = marginals.cholesky[:, i, :] - stds = jnp.sqrt(jnp.einsum("jn,jn->j", ls, ls)) - - if i == 1: - axes_cols[0].set_title(f"{i}st deriv.") +fig, axes = plt.subplots(nrows=2, ncols=len(tcoeffs), tight_layout=True, figsize=(8, 3)) +for i, (u_i, std_i, ax_i) in enumerate(zip(sol.u, u_std, axes.T)): + # Set up titles and axis descriptions + if i == 0: + ax_i[0].set_title("State") + ax_i[0].set_ylabel("Predators") + ax_i[1].set_ylabel("Prey") + elif i == 1: + ax_i[0].set_title(f"{i}st deriv.") elif i == 2: - axes_cols[0].set_title(f"{i}nd deriv.") + ax_i[0].set_title(f"{i}nd deriv.") elif i == 3: - axes_cols[0].set_title(f"{i}rd deriv.") + ax_i[0].set_title(f"{i}rd deriv.") else: - axes_cols[0].set_title(f"{i}th deriv.") - - axes_cols[0].plot(sol.t, ms, marker="None") - for m in ms.T: - axes_cols[0].fill_between(sol.t, m - 1.96 * stds, m + 1.96 * stds, alpha=0.3) - - axes_cols[1].semilogy(sol.t, stds, marker="None") - -plt.show() -# - + ax_i[0].set_title(f"{i}th deriv.") -# ## Smoother - -# + -ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") -ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) -solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) -adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) - -ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) - -# + -init = solver.initial_condition() -sol = ivpsolve.solve_adaptive_save_at( - vf, init, save_at=ts, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm -) - -marginals = stats.calibrate(sol.marginals, output_scale=sol.output_scale, ssm=ssm) -posterior = stats.calibrate(sol.posterior, output_scale=sol.output_scale, ssm=ssm) -posterior = stats.markov_select_terminal(posterior) -# - - -key = jax.random.PRNGKey(seed=1) -samples, _init = stats.markov_sample(key, posterior, shape=(2,), reverse=True, ssm=ssm) - -# + -_, num_derivatives, _ = marginals.mean.shape - - -fig, axes_all = plt.subplots( - nrows=2, ncols=num_derivatives, sharex=True, tight_layout=True, figsize=(8, 3) -) - -for i, axes_cols in enumerate(axes_all.T): - samps = samples[i] - ms = ssm.stats.qoi_from_sample(marginals.mean)[i] - ls = ssm.stats.qoi_from_sample(marginals.cholesky)[i] - stds = jnp.sqrt(jnp.einsum("jn,jn->j", ls, ls)) - - if i == 1: - axes_cols[0].set_title(f"{i}st deriv.") - elif i == 2: - axes_cols[0].set_title(f"{i}nd deriv.") - elif i == 3: - axes_cols[0].set_title(f"{i}rd deriv.") - else: - axes_cols[0].set_title(f"{i}th deriv.") + ax_i[1].set_xlabel("Time") - axes_cols[0].plot(sol.t, ms, marker="None") - for s in samps: - axes_cols[0].plot( - sol.t[:-1], s[..., 0], color="C0", linewidth=0.35, marker="None" - ) - axes_cols[0].plot( - sol.t[:-1], s[..., 1], color="C1", linewidth=0.35, marker="None" - ) - for m in ms.T: - axes_cols[0].fill_between(sol.t, m - 1.96 * stds, m + 1.96 * stds, alpha=0.3) + for m, std, ax in zip(u_i.T, std_i.T, ax_i): + # Plot the mean + ax.plot(sol.t, m) - axes_cols[1].semilogy(sol.t, stds, marker="None") + # Plot the standard deviation + lower, upper = m - 1.96 * std, m + 1.96 * std + ax.fill_between(sol.t, lower, upper, alpha=0.3) + ax.set_xlim((jnp.amin(ts), jnp.amax(ts))) +fig.align_ylabels() plt.show() -# - - -# The marginal standard deviations (bottom row) -# show how the filter is forward-only, -# whereas the smoother is a global estimate. -# -# This is why you should use a filter for -# terminal-value simulation and a smoother if you want "global" solutions. diff --git a/docs/examples_basic/second_order_problems.py b/docs/examples_basic/second_order_problems.py index b91d60c23..bff53dba7 100644 --- a/docs/examples_basic/second_order_problems.py +++ b/docs/examples_basic/second_order_problems.py @@ -23,12 +23,6 @@ from diffeqzoo import backend, ivps from probdiffeq import ivpsolve, ivpsolvers, taylor -from probdiffeq.util.doc_util import notebook - -# - - -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) # + if not backend.has_been_selected: diff --git a/docs/examples_basic/taylor_coefficients.py b/docs/examples_basic/taylor_coefficients.py index afe7be9ba..80803f227 100644 --- a/docs/examples_basic/taylor_coefficients.py +++ b/docs/examples_basic/taylor_coefficients.py @@ -27,14 +27,9 @@ import jax import jax.numpy as jnp -import matplotlib.pyplot as plt from diffeqzoo import backend, ivps from probdiffeq import ivpsolve, ivpsolvers, stats, taylor -from probdiffeq.util.doc_util import notebook - -plt.rcParams.update(notebook.plot_style()) -plt.rcParams.update(notebook.plot_sizes()) if not backend.has_been_selected: backend.select("jax") # ivp examples in jax diff --git a/docs/examples_quickstart/easy_example.py b/docs/examples_quickstart/easy_example.py deleted file mode 100644 index 575e41a8e..000000000 --- a/docs/examples_quickstart/easy_example.py +++ /dev/null @@ -1,90 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: light -# format_version: '1.5' -# jupytext_version: 1.15.2 -# kernelspec: -# display_name: Python 3 (ipykernel) -# language: python -# name: python3 -# --- - -# # Quickstart -# -# Let's have a look at an easy example. - -# + -"""Solve the logistic equation.""" - -import jax -import jax.numpy as jnp - -from probdiffeq import ivpsolve, ivpsolvers, taylor - -jax.config.update("jax_platform_name", "cpu") - - -# - - -# Create a problem: - - -# + -@jax.jit -def vf(y, *, t): # noqa: ARG001 - """Evaluate the vector field.""" - return 2.0 * y * (1 - y) - - -u0 = jnp.asarray([0.1]) -t0, t1 = 0.0, 5.0 -# - - -# Configuring a probabilistic IVP solver is a little more -# involved than configuring your favourite Runge-Kutta method: -# we must choose a prior distribution and a correction scheme, -# then we put them together as a filter or smoother, -# wrap everything into a solver, and (finally) make the solver adaptive. -# - -# + - -# Set up a state-space model -tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) -ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") -ts0 = ivpsolvers.correction_ts1(ode_order=1, ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ssm=ssm) - -# Build a solver -solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) -adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm) -# - - - -# Other software packages that implement -# probabilistic IVP solvers do a lot of this work -# implicitly; probdiffeq enforces that -# the user makes these decisions, not only because -# it simplifies the solver implementations -# (quite a lot, actually), -# but it also shows how easily we can -# build a custom solver for our favourite problem -# (consult the other tutorials for examples). - -# From here on, the rest is standard ODE-solver machinery: - -# + -# Solve the ODE -init = solver.initial_condition() -dt0 = 0.1 -solution = ivpsolve.solve_adaptive_save_every_step( - vf, init, t0=t0, t1=t1, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm -) - -# Look at the solution -print(f"u = {jax.tree.map(jnp.shape, solution.u)}") # Taylor coefficients -print(f"solution = {jax.tree.map(jnp.shape, solution)}") # IVP solution -# - diff --git a/docs/examples_quickstart/quickstart.py b/docs/examples_quickstart/quickstart.py new file mode 100644 index 000000000..5120fe7a0 --- /dev/null +++ b/docs/examples_quickstart/quickstart.py @@ -0,0 +1,61 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# # Quickstart +# +# Let's have a look at an easy example. + +# + +"""Solve the logistic equation.""" + +import jax +import jax.numpy as jnp + +from probdiffeq import ivpsolve, ivpsolvers, taylor + +# Define a differential equation + + +@jax.jit +def vf(y, *, t): # noqa: ARG001 + """Evaluate the dynamics of the logistic ODE.""" + return 2 * y * (1 - y) + + +u0 = jnp.asarray([0.1]) +t0, t1 = 0.0, 5.0 + + +# Set up a state-space model +tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1) +ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") + + +# Build a solver +ts = ivpsolvers.correction_ts1(ssm=ssm, ode_order=1) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +solver = ivpsolvers.solver_mle(ssm=ssm, strategy=strategy, prior=ibm, correction=ts) +adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm) + + +# Solve the ODE +# To all users: Try different solution routines. +init = solver.initial_condition() +solution = ivpsolve.solve_adaptive_save_every_step( + vf, init, t0=t0, t1=t1, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm +) + +# Look at the solution +print(f"\ninitial = {jax.tree.map(jnp.shape, init)}") +print(f"\nsolution = {jax.tree.map(jnp.shape, solution)}") diff --git a/docs/getting_started/transitioning_from_other_packages.md b/docs/migration_guide.md similarity index 99% rename from docs/getting_started/transitioning_from_other_packages.md rename to docs/migration_guide.md index 75fac28e1..a5456d1c0 100644 --- a/docs/getting_started/transitioning_from_other_packages.md +++ b/docs/migration_guide.md @@ -1,4 +1,4 @@ -# Transitioning +# Migration guide Here is how you get started with ProbDiffEq for solving ordinary differential equations (ODEs) if you already have experience with other (probabilistic) ODE solver packages in Python and Julia. diff --git a/docs/getting_started/troubleshooting.md b/docs/troubleshooting.md similarity index 100% rename from docs/getting_started/troubleshooting.md rename to docs/troubleshooting.md diff --git a/mkdocs.yml b/mkdocs.yml index 70cf5f755..ac4e0790f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -74,10 +74,10 @@ extra: generator: false nav: - Probabilistic solvers for differential equations in JAX: index.md - - An easy example: examples_quickstart/easy_example.ipynb - - Transitioning from other packages: getting_started/transitioning_from_other_packages.md - - Choosing a solver: getting_started/choosing_a_solver.md - - Troubleshooting: getting_started/troubleshooting.md + - An easy example: examples_quickstart/quickstart.ipynb + - Migration guide: migration_guide.md + - Choosing a solver: choosing_a_solver.md + - Troubleshooting: troubleshooting.md - EXAMPLES | BASIC: - examples_basic/conditioning-on-zero-residual.ipynb - examples_basic/posterior_uncertainties.ipynb @@ -85,10 +85,10 @@ nav: - examples_basic/second_order_problems.ipynb - examples_basic/taylor_coefficients.ipynb - EXAMPLES | ADVANCED: - - examples_advanced/physics_enhanced_regression_1.ipynb - - examples_advanced/physics_enhanced_regression_2.ipynb + - examples_advanced/parameter_estimation_optax.ipynb + - examples_advanced/parameter_estimation_blackjax.ipynb - examples_advanced/neural_ode.ipynb - - examples_advanced/use_equinox_bounded_while_loop.ipynb + - examples_advanced/equinox_while_loop.ipynb - API DOCUMENTATION: - ivpsolve: api_docs/ivpsolve.md - ivpsolvers: api_docs/ivpsolvers.md diff --git a/probdiffeq/impl/_conditional.py b/probdiffeq/impl/_conditional.py index 5f231d84e..c8a391628 100644 --- a/probdiffeq/impl/_conditional.py +++ b/probdiffeq/impl/_conditional.py @@ -17,7 +17,7 @@ class Conditional(containers.NamedTuple): """Conditional distributions.""" - matmul: Array # or anything with a __matmul__ implementation + matmul: Array noise: Any # Usually a random-variable type diff --git a/probdiffeq/impl/_stats.py b/probdiffeq/impl/_stats.py index 596274373..ee9400362 100644 --- a/probdiffeq/impl/_stats.py +++ b/probdiffeq/impl/_stats.py @@ -245,7 +245,7 @@ def qoi_from_sample(self, sample, /): return functools.vmap(self.qoi_from_sample)(sample) return self.unravel(sample) - def update_mean(self, mean, x, /, num): + def update_mean(self, mean, x, /, num): # TODO rename: update_mean_estimate if np.ndim(mean) > 0: assert np.shape(mean) == np.shape(x) return functools.vmap(self.update_mean, in_axes=(0, 0, None))(mean, x, num) diff --git a/probdiffeq/util/doc_util/notebook.py b/probdiffeq/util/doc_util/notebook.py deleted file mode 100644 index 15640358b..000000000 --- a/probdiffeq/util/doc_util/notebook.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Benchmark utils.""" - -import numpy as np -from tueplots import axes, cycler, fontsizes, markers - - -def plot_style(): - colors = ["cornflowerblue", "salmon", "mediumseagreen", "crimson", "darkorchid"] - markers_ = ["o", "v", "P", "^", "X", "d"] - return { - **axes.color(base="black"), - **axes.lines(base_width=0.5), - **axes.tick_direction(x="inout", y="inout"), - **axes.legend(), - **axes.grid(grid_linestyle="dotted"), - **cycler.cycler( - marker=np.tile(markers_, 9)[:15], color=np.tile(colors, 10)[:15] - ), - **markers.with_edge(), - **{"figure.dpi": 100}, - } - - -def plot_sizes(): - return fontsizes.beamer()