1111from jax import Array
1212
1313from sympint .functional import nest
14+ from sympint .functional import fold
15+
16+ from sympint .yoshida import sequence
1417
1518def midpoint (H :Callable [..., Array ], ns :int = 1 ) -> Callable [..., Array ]:
1619 """
@@ -47,4 +50,50 @@ def newton(state: Array) -> Array:
4750 delta , * _ = jax .numpy .linalg .lstsq (jacobian , - error )
4851 return state + delta
4952 return nest (ns , newton )(state )
50- return integrator
53+ return integrator
54+
55+
56+ def tao (H :Callable [..., Array ], binding :float = 0.0 ) -> Callable [..., Array ]:
57+ """
58+ Generate Tao integrator
59+
60+ Parameters
61+ ----------
62+ H: Callable[[Array, *Any], Array]
63+ Hamiltonian function H(q, p, dt, *args)
64+ binding: float, default=0.0
65+ binding factor
66+
67+ Returns
68+ -------
69+ Callable[[Array, *Any], Array]
70+ integrator(qp, dt, *args)
71+
72+ """
73+ dHdq = jax .grad (H , argnums = 0 )
74+ dHdp = jax .grad (H , argnums = 1 )
75+ def fa (state :Array , dt :Array , * args :Array ) -> Array :
76+ q , p , Q , P = state .reshape (4 , - 1 )
77+ return jax .numpy .concatenate ([q , p - dt * dHdq (q , P , * args ), Q + dt * dHdp (q , P , * args ), P ])
78+ def fb (state :Array , dt :Array , * args :Array ) -> Array :
79+ q , p , Q , P = state .reshape (4 , - 1 )
80+ return jax .numpy .concatenate ([q + dt * dHdp (Q , p , * args ), p , Q , P - dt * dHdq (Q , p , * args )])
81+ def fc (state :Array , dt :Array , * args :Array ) -> Array :
82+ q , p , Q , P = state .reshape (4 , - 1 )
83+ omega = 2 * binding * dt
84+ cos = jax .numpy .cos (omega )
85+ sin = jax .numpy .sin (omega )
86+ dq = q - Q
87+ dp = p - P
88+ return jax .numpy .concatenate ([
89+ 0.5 * (q + Q + cos * dq + sin * dp ),
90+ 0.5 * (p + P - sin * dq + cos * dp ),
91+ 0.5 * (q + q - cos * dq - sin * dp ),
92+ 0.5 * (p + P + sin * dq - cos * dp )
93+ ])
94+ step = fold (sequence (0 , 0 , [fa , fb , fc ] if binding != 0.0 else [fa , fb ], merge = True ))
95+ def integrator (state :Array , dt :Array , * args :Array ) -> Array :
96+ local = step (jax .numpy .concatenate ([state , state ]), dt , * args )
97+ q , p , * _ = local .reshape (4 , - 1 )
98+ return jax .numpy .concatenate ([q , p ])
99+ return integrator
0 commit comments