-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtaylor_coefficients.py
More file actions
124 lines (85 loc) · 2.95 KB
/
taylor_coefficients.py
File metadata and controls
124 lines (85 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.15.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
# # Taylor coefficients
#
# To build a probabilistic solver, we need to build a specific state-space model.
# To build this specific state-space model, we interact with Taylor coefficients.
# Here are some examples of how Taylor coefficients
# play a role in Probdiffeq's solution routines.
# +
"""Demonstrate how central Taylor coefficient estimation is to Probdiffeq."""
import collections
import jax
import jax.numpy as jnp
from diffeqzoo import backend, ivps
from probdiffeq import ivpsolve, ivpsolvers, stats, taylor
# diffeqzoo's example IVPs must be built with a backend; only pick JAX if no
# backend has been chosen yet (e.g. by an earlier cell or another module).
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Pin JAX computations to the CPU platform.
jax.config.update("jax_platform_name", "cpu")
# -
# We start by defining an ODE.
# +
# Unpack the rigid-body example problem: vector field, initial value,
# time interval, and extra parameters for the vector field.
f, u0, (t0, t1), f_args = ivps.rigid_body()


def vf(*y, t):  # noqa: ARG001
    """Evaluate the vector field.

    Forwards the state arguments ``*y`` to ``f`` together with the problem's
    parameters ``f_args``. The time ``t`` is accepted to match the call
    signature used by the solver, but it is ignored because the vector field
    from ``ivps.rigid_body()`` is called without it.
    """
    return f(*y, *f_args)
# -
# Here is a wrapper around Probdiffeq's solution routine.
# +
def solve(tc):
    """Solve the ODE, starting from the Taylor coefficients ``tc``.

    The solver is assembled from the following components:

    - an integrated-Wiener-process prior with a dense state-space factorisation;
    - the ``ts0`` correction;
    - a fixed-point strategy;
    - an MLE solver.

    The result is wrapped in adaptive step-size control with
    ``atol = rtol = 1e-2``. The solution is saved at 10 equispaced points on
    ``[t0, t1]`` (both endpoints included), and integration starts with step
    size ``dt0 = 0.1``.

    Parameters
    ----------
    tc
        Taylor coefficients of the initial condition. Any list-like container
        (list, tuple, named tuple) is accepted. Its length presumably fixes
        how many derivatives the prior tracks; verify against Probdiffeq.

    Returns
    -------
    The result of ``ivpsolve.solve_adaptive_save_at``. This notebook reads its
    ``u``, ``u_std``, ``posterior`` and ``ssm`` attributes.
    """
    # The same ``ssm`` object is threaded through every component below
    # (presumably they must agree on one state-space model implementation).
    init, prior, ssm = ivpsolvers.prior_wiener_integrated(tc, ssm_fact="dense")
    ts0 = ivpsolvers.correction_ts0(ssm=ssm)
    strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=ts0, ssm=ssm)

    # Output grid: 10 equispaced points from t0 to t1 inclusive.
    ts = jnp.linspace(t0, t1, endpoint=True, num=10)
    adaptive_solver = ivpsolvers.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm)
    return ivpsolve.solve_adaptive_save_at(
        vf, init, save_at=ts, adaptive_solver=adaptive_solver, dt0=0.1, ssm=ssm
    )
# -
# It's time to solve some ODEs:
# +
# Taylor-expand the solution at t0. The vector field is evaluated at the fixed
# time t0, and num=2 requests the coefficients consumed by the named tuple
# with three fields further below (presumably u0 plus two derivatives).
tcoeffs = taylor.odejet_padded_scan(
    lambda *state: vf(*state, t=t0),
    [u0],
    num=2,
)
# Solve and show the array shapes inside the solution pytree.
solution = solve(tcoeffs)
print(jax.tree.map(jnp.shape, solution))
# -
# The type (pytree structure) of solution.u matches that of the Taylor coefficients we passed as the initial condition.
# +
# Print the shapes of the input coefficients first, then of the solution's u.
for tree in (tcoeffs, solution.u):
    print(jax.tree.map(jnp.shape, tree))
# -
# Anything that behaves like a list works.
# For example, we can use lists or tuples, but also named tuples.
# +
# Give each Taylor coefficient a name; the solver accepts any list-like
# container, so a named tuple works just as well as a plain list.
Taylor = collections.namedtuple("Taylor", "state velocity acceleration")
tcoeffs = Taylor._make(tcoeffs)

solution = solve(tcoeffs)
# Shapes of the named input, the full solution, and the solution's u.
for tree in (tcoeffs, solution, solution.u):
    print(jax.tree.map(jnp.shape, tree))
# -
# The same applies to statistical quantities that we can extract from the solution.
# For example, the standard deviation or samples from the solution object:
# +
# Restrict the posterior to its terminal value, then draw samples from it
# (reverse=True is passed through to Probdiffeq's markov_sample).
posterior = stats.markov_select_terminal(solution.posterior)
key = jax.random.PRNGKey(seed=15)
samples, samples_init = stats.markov_sample(
    key, posterior, reverse=True, ssm=solution.ssm
)

# Mean, standard deviation, and both sample outputs share the u structure.
for tree in (solution.u, solution.u_std, samples, samples_init):
    print(jax.tree.map(jnp.shape, tree))
# -