
"""
from typing import Callable
+from typing import Optional

import jax
from jax import Array
+from jax import grad

from sympint.functional import nest
from sympint.functional import fold

from sympint.yoshida import sequence

-def midpoint(H:Callable[..., Array], ns:int=1) -> Callable[..., Array]:
+def midpoint(H:Callable[..., Array],
+             ns:int=1,
+             gradient:Optional[Callable[..., Array]]=None) -> Callable[..., Array]:
    """
    Generate implicit midpoint integrator

@@ -25,15 +29,18 @@ def midpoint(H:Callable[..., Array], ns:int=1) -> Callable[..., Array]:
        Hamiltonian function H(q, p, dt, t, *args)
    ns: int, default=1
        number of Newton iteration steps
+    gradient: Optional[Callable[..., Array]], default=None
+        gradient function (defaults to jax.grad)

    Returns
    -------
    Callable[[Array, *Any], Array]
        integrator(qp, dt, t, *args)

    """
-    dHdq = jax.grad(H, argnums=0)
-    dHdp = jax.grad(H, argnums=1)
+    gradient = grad if gradient is None else gradient
+    dHdq = gradient(H, argnums=0)
+    dHdp = gradient(H, argnums=1)
    def integrator(state: Array, dt: Array, t: Array, *args: Array) -> Array:
        q, p = jax.numpy.reshape(state, (2, -1))
        t_m = t + 0.5 * dt
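
For reference, a minimal usage sketch of the new `gradient` hook on `midpoint`, not part of the patch: the import path `sympint.integrators` and the harmonic-oscillator Hamiltonian are assumptions for illustration, and any callable exposing the `jax.grad(f, argnums=...)` interface can be passed.

import jax
from jax import Array

from sympint.integrators import midpoint  # assumed import path (file path not shown in the diff)

def H(q: Array, p: Array, *args: Array) -> Array:
    # hypothetical harmonic oscillator; extra arguments (dt, t, ...) are accepted
    # to match the documented signature H(q, p, dt, t, *args)
    return 0.5 * jax.numpy.sum(p**2 + q**2)

def jitted_grad(f, argnums=0):
    # hypothetical custom gradient: same interface as jax.grad, but jit-compiled
    return jax.jit(jax.grad(f, argnums=argnums))

step = midpoint(H, ns=4)                            # gradient=None falls back to jax.grad
step_jit = midpoint(H, ns=4, gradient=jitted_grad)  # custom gradient callable

qp = jax.numpy.array([1.0, 0.0])                    # stacked (q, p)
out = step_jit(qp, 0.01, 0.0)                       # integrator(qp, dt, t, *args)
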
@@ -53,7 +60,9 @@ def newton(state: Array) -> Array:
    return integrator


-def tao(H:Callable[..., Array], binding:float=0.0) -> Callable[..., Array]:
+def tao(H:Callable[..., Array],
+        binding:float=0.0,
+        gradient:Optional[Callable[..., Array]]=None) -> Callable[..., Array]:
    """
    Generate Tao integrator

@@ -63,15 +72,18 @@ def tao(H:Callable[..., Array], binding:float=0.0) -> Callable[..., Array]:
        Hamiltonian function H(q, p, dt, *args)
    binding: float, default=0.0
        binding factor
+    gradient: Optional[Callable[..., Array]], default=None
+        gradient function (defaults to jax.grad)

    Returns
    -------
    Callable[[Array, *Any], Array]
        integrator(qp, dt, *args)

    """
-    dHdq = jax.grad(H, argnums=0)
-    dHdp = jax.grad(H, argnums=1)
+    gradient = grad if gradient is None else gradient
+    dHdq = gradient(H, argnums=0)
+    dHdp = gradient(H, argnums=1)
    def fa(state: Array, dt: Array, *args: Array) -> Array:
        q, p, Q, P = state.reshape(4, -1)
        return jax.numpy.concatenate([q, p - dt * dHdq(q, P, *args), Q + dt * dHdp(q, P, *args), P])
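
A matching sketch for `tao`, again not part of the patch and under the same assumed import path: Tao's method evolves an extended phase space holding two copies of (q, p) (the (q, p, Q, P) seen in `fa`), and `binding` weights the term that ties the two copies together.

import jax
from jax import Array

from sympint.integrators import tao  # assumed import path (file path not shown in the diff)

def H(q: Array, p: Array, *args: Array) -> Array:
    # same hypothetical harmonic oscillator; extra arguments are accepted
    # to match the documented signature H(q, p, dt, *args)
    return 0.5 * jax.numpy.sum(p**2 + q**2)

step = tao(H, binding=2.0)                # default gradient: jax.grad
step_jit = tao(H, binding=2.0,
               gradient=lambda f, argnums=0: jax.jit(jax.grad(f, argnums=argnums)))

qp = jax.numpy.array([1.0, 0.0])          # stacked (q, p); the extended copy is internal
out = step(qp, 0.01)                      # integrator(qp, dt, *args); no time argument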