Skip to content

Commit c0b2d4b

Browse files
committed
13_03_2025: added symplectic tao integrator for non-separatable hamiltonians
1 parent a166a5f commit c0b2d4b

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

sympint/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
'coefficients',
1414
'table',
1515
'sequence',
16-
'midpoint'
16+
'midpoint',
17+
'tao'
1718
]
1819

1920
from sympint.functional import nest
@@ -26,4 +27,5 @@
2627
from sympint.yoshida import table
2728
from sympint.yoshida import sequence
2829

29-
from sympint.integrators import midpoint
30+
from sympint.integrators import midpoint
31+
from sympint.integrators import tao

sympint/integrators.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from jax import Array
1212

1313
from sympint.functional import nest
14+
from sympint.functional import fold
15+
16+
from sympint.yoshida import sequence
1417

1518
def midpoint(H:Callable[..., Array], ns:int=1) -> Callable[..., Array]:
1619
"""
@@ -47,4 +50,50 @@ def newton(state: Array) -> Array:
4750
delta, *_ = jax.numpy.linalg.lstsq(jacobian, -error)
4851
return state + delta
4952
return nest(ns, newton)(state)
50-
return integrator
53+
return integrator
54+
55+
56+
def tao(H:Callable[..., Array], binding:float=0.0) -> Callable[..., Array]:
57+
"""
58+
Generate Tao integrator
59+
60+
Parameters
61+
----------
62+
H: Callable[[Array, *Any], Array]
63+
Hamiltonian function H(q, p, dt, *args)
64+
binding: float, default=0.0
65+
binding factor
66+
67+
Returns
68+
-------
69+
Callable[[Array, *Any], Array]
70+
integrator(qp, dt, *args)
71+
72+
"""
73+
dHdq = jax.grad(H, argnums=0)
74+
dHdp = jax.grad(H, argnums=1)
75+
def fa(state:Array, dt:Array, *args:Array) -> Array:
76+
q, p, Q, P = state.reshape(4, -1)
77+
return jax.numpy.concatenate([q, p - dt*dHdq(q, P, *args), Q + dt*dHdp(q, P, *args), P])
78+
def fb(state:Array, dt:Array, *args:Array) -> Array:
79+
q, p, Q, P = state.reshape(4, -1)
80+
return jax.numpy.concatenate([q + dt*dHdp(Q, p, *args), p, Q, P - dt*dHdq(Q, p, *args)])
81+
def fc(state:Array, dt:Array, *args:Array) -> Array:
82+
q, p, Q, P = state.reshape(4, -1)
83+
omega = 2*binding*dt
84+
cos = jax.numpy.cos(omega)
85+
sin = jax.numpy.sin(omega)
86+
dq = q - Q
87+
dp = p - P
88+
return jax.numpy.concatenate([
89+
0.5*(q + Q + cos*dq + sin*dp),
90+
0.5*(p + P - sin*dq + cos*dp),
91+
0.5*(q + q - cos*dq - sin*dp),
92+
0.5*(p + P + sin*dq - cos*dp)
93+
])
94+
step = fold(sequence(0, 0, [fa, fb, fc] if binding != 0.0 else [fa, fb], merge=True))
95+
def integrator(state:Array, dt:Array, *args:Array) -> Array:
96+
local = step(jax.numpy.concatenate([state, state]), dt, *args)
97+
q, p, *_ = local.reshape(4, -1)
98+
return jax.numpy.concatenate([q, p])
99+
return integrator

0 commit comments

Comments
 (0)