1+ """
2+ Integrator
3+ ----------
4+
5+ Collection of symplectic (JAX composable) integrators
6+
7+ """
8+ from typing import Callable
9+
10+ import jax
11+ from jax import Array
12+
13+ from sympint .functional import nest
14+
15+ def midpoint (H :Callable [..., Array ], ns :int = 1 ) -> Callable [..., Array ]:
16+ """
17+ Generate implicit midpoint integrator
18+
19+ Parameters
20+ ----------
21+ H: Callable[[Array, *Any], Array]
22+ Hamiltonian function H(q, p, dt, t, *args)
23+ ns: int, default=1
24+ number of Newton iteration steps
25+
26+ Returns
27+ -------
28+ Callable[[Array, *Any], Array]
29+ integrator(qp, dt, t, *args)
30+
31+ """
32+ dHdq = jax .grad (H , argnums = 0 )
33+ dHdp = jax .grad (H , argnums = 1 )
34+ def integrator (state : Array , dt : Array , t : Array , * args : Array ) -> Array :
35+ q , p = jax .numpy .reshape (state , (2 , - 1 ))
36+ t_m = t + 0.5 * dt
37+ def residual (state : Array ) -> tuple [Array , Array ]:
38+ Q , P = jax .numpy .reshape (state , (2 , - 1 ))
39+ q_m = 0.5 * (q + Q )
40+ p_m = 0.5 * (p + P )
41+ dq = Q - q - dt * dHdp (q_m , p_m , t_m , * args )
42+ dp = P - p + dt * dHdq (q_m , p_m , t_m , * args )
43+ state = jax .numpy .concatenate ([dq , dp ])
44+ return state , state
45+ def newton (state : Array ) -> Array :
46+ jacobian , error = jax .jacrev (residual , has_aux = True )(state )
47+ delta , * _ = jax .numpy .linalg .lstsq (jacobian , - error )
48+ return state + delta
49+ return nest (ns , newton )(state )
50+ return integrator
0 commit comments