
Commit 0b68ada (parent 1f1a15d)

07_17_2025: added option to pass gradient (default is jax.grad)

File tree: 2 files changed, +19 −7 lines
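Judging from the calls gradient(H, argnums=...) introduced in the diff below, the contract for the new keyword appears to be any callable with a jax.grad-like (fun, argnums) signature. A minimal sketch of such a replacement (the jit_grad wrapper is a hypothetical example, not part of sympint):

import jax

def jit_grad(fun, argnums=0):
    # same (fun, argnums) contract as jax.grad, with jit applied on top
    return jax.jit(jax.grad(fun, argnums=argnums))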

sympint/integrators.py

Lines changed: 18 additions & 6 deletions

@@ -6,16 +6,20 @@
 
 """
 
 from typing import Callable
+from typing import Optional
 
 import jax
 from jax import Array
+from jax import grad
 
 from sympint.functional import nest
 from sympint.functional import fold
 
 from sympint.yoshida import sequence
 
-def midpoint(H:Callable[..., Array], ns:int=1) -> Callable[..., Array]:
+def midpoint(H:Callable[..., Array],
+             ns:int=1,
+             gradient:Optional[Callable[..., Array]] = None) -> Callable[..., Array]:
     """
     Generate implicit midpoint integrator
 
@@ -25,15 +29,18 @@ def midpoint(H:Callable[..., Array], ns:int=1) -> Callable[..., Array]:
         Hamiltonian function H(q, p, dt, t, *args)
     ns: int, default=1
         number of Newton iteration steps
+    gradient: Optional[Callable[..., Array]], default=None
+        gradient function (defaults to jax.grad)
 
     Returns
     -------
     Callable[[Array, *Any], Array]
         integrator(qp, dt, t, *args)
 
     """
-    dHdq = jax.grad(H, argnums=0)
-    dHdp = jax.grad(H, argnums=1)
+    gradient = grad if gradient is None else gradient
+    dHdq = gradient(H, argnums=0)
+    dHdp = gradient(H, argnums=1)
     def integrator(state: Array, dt: Array, t: Array, *args: Array) -> Array:
         q, p = jax.numpy.reshape(state, (2, -1))
         t_m = t + 0.5*dt
@@ -53,7 +60,9 @@ def newton(state: Array) -> Array:
     return integrator
 
 
-def tao(H:Callable[..., Array], binding:float=0.0) -> Callable[..., Array]:
+def tao(H:Callable[..., Array],
+        binding:float=0.0,
+        gradient:Optional[Callable[..., Array]] = None) -> Callable[..., Array]:
     """
     Generate Tao integrator
 
@@ -63,15 +72,18 @@ def tao(H:Callable[..., Array], binding:float=0.0) -> Callable[..., Array]:
         Hamiltonian function H(q, p, dt, *args)
     binding: float, default=0.0
         binding factor
+    gradient: Optional[Callable[..., Array]], default=None
+        gradient function (defaults to jax.grad)
 
     Returns
     -------
     Callable[[Array, *Any], Array]
         integrator(qp, dt, *args)
 
     """
-    dHdq = jax.grad(H, argnums=0)
-    dHdp = jax.grad(H, argnums=1)
+    gradient = grad if gradient is None else gradient
+    dHdq = gradient(H, argnums=0)
+    dHdp = gradient(H, argnums=1)
     def fa(state:Array, dt:Array, *args:Array) -> Array:
         q, p, Q, P = state.reshape(4, -1)
         return jax.numpy.concatenate([q, p - dt*dHdq(q, P, *args), Q + dt*dHdp(q, P, *args), P])
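A hedged usage sketch of the changed midpoint() signature, assuming the call pattern documented in its docstring above; the harmonic-oscillator Hamiltonian and numbers are illustrative, and jax.jacrev is shown as a drop-in alternative only because it shares jax.grad's (fun, argnums) interface and returns the gradient for scalar-valued functions. tao() accepts the same gradient keyword.

import jax
import jax.numpy as jnp
from sympint.integrators import midpoint

def H(q, p, dt, t):
    # toy harmonic oscillator: H = (p**2 + q**2)/2
    return 0.5*jnp.sum(p**2 + q**2)

step = midpoint(H, ns=4)                           # default: jax.grad
step_alt = midpoint(H, ns=4, gradient=jax.jacrev)  # custom gradient factory

qp = jnp.array([1.0, 0.0])    # flat state [q, p]
dt = jnp.asarray(0.01)
t = jnp.asarray(0.0)
print(step(qp, dt, t))        # one implicit midpoint step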

sympint/yoshida.py

Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ def sequence(ni:int,
     Output sequence mappings have (x, dt, *args) singnatures
 
     """
-    indices, weights = table(len(mappings), ni, nf, merge)
+    indices, weights, *_ = table(len(mappings), ni, nf, merge)
     parameters = [[] for _ in range(len(mappings))] if parameters is None else parameters
     parameters = [parameters[i] for i in indices]
     def wrapper(mapping, weight, parameter):
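For context on the one-line fix above: the starred target absorbs surplus return values, which suggests table() now returns more than the two values sequence() consumes; the old two-name unpacking would raise a ValueError. A minimal illustration with made-up values:

def table_like():
    # stand-in for table(); the third return value is hypothetical
    return [0, 1], [0.5, 0.5], 'extra'

indices, weights, *_ = table_like()   # surplus values land in _ and are ignored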
