-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathequinox_while_loop.py
More file actions
104 lines (69 loc) · 2.47 KB
/
equinox_while_loop.py
File metadata and controls
104 lines (69 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.15.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
# # Equinox's while-loops
#
# Use [Equinox's](https://docs.kidger.site/equinox/)
# bounded while loop to enable reverse-mode differentiation of adaptive IVP solvers.
# +
"""Use Equinox's while loop to compute gradients of `simulate_terminal_values`."""
import equinox
import jax
import jax.numpy as jnp
from probdiffeq import ivpsolve, probdiffeq, taylor
def solution_routine(while_loop):
    """Build a scalar function of an adaptive IVP solution, plus its input.

    The IVP is the logistic ODE ``dy/dt = 0.5 * y * (1 - y)`` with
    ``y(0) = [0.1]`` on the interval ``[0, 1]``. It is solved adaptively with
    a fixed-point smoother, an isotropic integrated-Wiener prior and a
    first-order TS0 ODE constraint.

    Args:
        while_loop: Loop primitive forwarded to
            ``ivpsolve.solve_adaptive_terminal_values``. Pass
            ``jax.lax.while_loop`` or a reverse-mode-differentiable drop-in.

    Returns:
        A pair ``(simulate, init)``:

        - ``simulate(init_val)`` runs the adaptive solver from ``init_val`` and
          returns a scalar.
        - ``init`` is the first output of
          ``probdiffeq.prior_wiener_integrated``. It is presumably the initial
          solver state; confirm against the probdiffeq docs.
    """
    t0, t1 = 0.0, 1.0
    y0 = jnp.asarray([0.1])

    @jax.jit
    def vector_field(y, *, t):  # noqa: ARG001
        """Return the logistic growth rate at ``y`` (time-independent)."""
        return 0.5 * y * (1 - y)

    def vector_field_at_t0(y):
        """Freeze the time argument so Taylor coefficients can be computed."""
        return vector_field(y, t=t0)

    taylor_coeffs = taylor.odejet_padded_scan(vector_field_at_t0, (y0,), num=1)
    initial_state, prior, state_space = probdiffeq.prior_wiener_integrated(
        taylor_coeffs, ssm_fact="isotropic"
    )

    constraint = probdiffeq.constraint_ode_ts0(ode_order=1, ssm=state_space)
    smoother = probdiffeq.strategy_smoother_fixedpoint(ssm=state_space)
    ode_solver = probdiffeq.solver(
        vector_field,
        strategy=smoother,
        prior=prior,
        constraint=constraint,
        ssm=state_space,
    )
    error_estimator = probdiffeq.errorest_local_residual_cached(
        prior=prior, ssm=state_space
    )
    solve_to_terminal = ivpsolve.solve_adaptive_terminal_values(
        solver=ode_solver, errorest=error_estimator, while_loop=while_loop
    )

    def simulate(init_val):
        """Solve the IVP from ``init_val`` and reduce the result to a scalar."""
        solution = solve_to_terminal(init_val, t0=t0, t1=t1, atol=1e-3, rtol=1e-3)
        # Squared norm of the leading entry of the posterior mean. Any scalar
        # function of the solution would work here, e.g. a log-marginal-
        # likelihood loss as in the other tutorials.
        leading_mean = solution.u.mean[0]
        return jnp.dot(leading_mean, leading_mean)

    return simulate, initial_state
# -
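# By default, the solver uses `jax.lax.while_loop`. Reverse-mode differentiation
# through it fails, so the error below is caught and printed.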
# +
# Build the solver on top of JAX's own (unbounded) while loop.
solve, x = solution_routine(while_loop=jax.lax.while_loop)
try:
    # Reverse-mode autodiff through `jax.lax.while_loop` is expected to
    # raise here; the error is reported instead of crashing the notebook.
    solution, gradient = jax.jit(jax.value_and_grad(solve))(x)
except ValueError as err:
    print(f"Caught error:\n\t {err}")
# -
# Equinox's bounded while loop makes the adaptive solver reverse-mode differentiable.
# +
def while_loop_func(*a, max_steps=100, **kw):
    """Evaluate a bounded while loop.

    Drop-in replacement for ``jax.lax.while_loop`` that delegates to
    ``equinox.internal.while_loop`` with ``kind="bounded"``. This is the loop
    kind that supports reverse-mode differentiation.

    Args:
        *a: Positional arguments forwarded unchanged, i.e. ``cond_fun``,
            ``body_fun`` and ``init_val`` as for ``jax.lax.while_loop``.
        max_steps: Upper bound on the number of loop iterations. Defaults to
            100, the previously hard-coded value. It is now overridable so
            longer adaptive solves can raise the cap.
        **kw: Any other keyword arguments, forwarded unchanged.

    Returns:
        The final loop carry, as returned by ``equinox.internal.while_loop``.
    """
    # NOTE(review): if the loop needs more than `max_steps` iterations, it
    # presumably stops early without an error; confirm against the equinox
    # docs and increase `max_steps` for long or stiff problems.
    return equinox.internal.while_loop(
        *a, **kw, kind="bounded", max_steps=max_steps
    )
# Rebuild the solver with the differentiable (bounded) while loop.
solve, x = solution_routine(while_loop=while_loop_func)
# Value and gradient in a single jitted reverse-mode pass; this now succeeds.
solution, gradient = jax.jit(jax.value_and_grad(solve))(x)
print(solution, gradient, sep="\n")
# -