Skip to content

Commit 0481dca

Browse files
committed
12_03_2025: added symplectic implicit midpoint integrator
1 parent 58b854e commit 0481dca

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

sympint/integrator.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
2+
Integrator
3+
----------
4+
5+
Collection of symplectic (JAX composable) integrators
6+
7+
"""
8+
from typing import Callable
9+
10+
import jax
11+
from jax import Array
12+
13+
from sympint.functional import nest
14+
15+
def midpoint(H:Callable[..., Array], ns:int=1) -> Callable[..., Array]:
16+
"""
17+
Generate implicit midpoint integrator
18+
19+
Parameters
20+
----------
21+
H: Callable[[Array, *Any], Array]
22+
Hamiltonian function H(q, p, dt, t, *args)
23+
ns: int, default=1
24+
number of Newton iteration steps
25+
26+
Returns
27+
-------
28+
Callable[[Array, *Any], Array]
29+
integrator(qp, dt, t, *args)
30+
31+
"""
32+
dHdq = jax.grad(H, argnums=0)
33+
dHdp = jax.grad(H, argnums=1)
34+
def integrator(state: Array, dt: Array, t: Array, *args: Array) -> Array:
35+
q, p = jax.numpy.reshape(state, (2, -1))
36+
t_m = t + 0.5*dt
37+
def residual(state: Array) -> tuple[Array, Array]:
38+
Q, P = jax.numpy.reshape(state, (2, -1))
39+
q_m = 0.5*(q + Q)
40+
p_m = 0.5*(p + P)
41+
dq = Q - q - dt*dHdp(q_m, p_m, t_m, *args)
42+
dp = P - p + dt*dHdq(q_m, p_m, t_m, *args)
43+
state = jax.numpy.concatenate([dq, dp])
44+
return state, state
45+
def newton(state: Array) -> Array:
46+
jacobian, error = jax.jacrev(residual, has_aux=True)(state)
47+
delta, *_ = jax.numpy.linalg.lstsq(jacobian, -error)
48+
return state + delta
49+
return nest(ns, newton)(state)
50+
return integrator

0 commit comments

Comments
 (0)