Skip to content

Commit a52b4bc

Browse files
committed
15_08_2025: updated midpoint integrator (newton solver and jacobian)
1 parent 49441ce commit a52b4bc

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

sympint/integrators.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
def midpoint(H:Callable[..., Array],
2121
ns:int=1,
22-
gradient:Optional[Callable[..., Array]] = None) -> Callable[..., Array]:
22+
gradient:Optional[Callable[..., Array]] = None,
23+
jacobian:Optional[Callable[..., Array]] = None,
24+
solve:Optional[Callable[[Array, Array], Array]] = None) -> Callable[..., Array]:
2325
"""
2426
Generate implicit midpoint integrator
2527
@@ -31,6 +33,10 @@ def midpoint(H:Callable[..., Array],
3133
number of Newton iteration steps
3234
gradient: Optional[Callable[..., Array]], default=None
3335
gradient function (defaults to jax.grad)
36+
jacobian: Optional[Callable]
37+
jax.jacfwd or jax.jacrev (default)
38+
solve: Optional[Callable]
39+
linear solver(matrix, vector)
3440
3541
Returns
3642
-------
@@ -39,6 +45,10 @@ def midpoint(H:Callable[..., Array],
3945
4046
"""
4147
gradient = grad if gradient is None else gradient
48+
jacobian = jax.jacrev if jacobian is None else jacobian
49+
if solve is None:
50+
def solve(matrix:Array, vector:Array) -> Array:
51+
return jax.numpy.linalg.solve(matrix, vector)
4252
dHdq = gradient(H, argnums=0)
4353
dHdp = gradient(H, argnums=1)
4454
def integrator(state: Array, dt: Array, t: Array, *args: Array) -> Array:
@@ -53,9 +63,8 @@ def residual(state: Array) -> tuple[Array, Array]:
5363
state = jax.numpy.concatenate([dq, dp])
5464
return state, state
5565
def newton(state: Array) -> Array:
56-
jacobian, error = jax.jacrev(residual, has_aux=True)(state)
57-
delta, *_ = jax.numpy.linalg.lstsq(jacobian, -error)
58-
return state + delta
66+
matrix, error = jacobian(residual, has_aux=True)(state)
67+
return state + solve(matrix, -error)
5968
return nest(ns, newton)(state)
6069
return integrator
6170

0 commit comments

Comments
 (0)