 
 def midpoint(H:Callable[..., Array],
              ns:int = 1,
-             gradient:Optional[Callable[..., Array]] = None) -> Callable[..., Array]:
+             gradient:Optional[Callable[..., Array]] = None,
+             jacobian:Optional[Callable[..., Array]] = None,
+             solve:Optional[Callable[[Array, Array], Array]] = None) -> Callable[..., Array]:
     """
     Generate implicit midpoint integrator
 
@@ -31,6 +33,10 @@ def midpoint(H:Callable[..., Array],
         number of Newton iteration steps
     gradient: Optional[Callable[..., Array]], default=None
         gradient function (defaults to jax.grad)
+    jacobian: Optional[Callable]
+        jax.jacfwd or jax.jacrev (default)
+    solve: Optional[Callable]
+        linear solver(matrix, vector)
 
     Returns
     -------
@@ -39,6 +45,10 @@ def midpoint(H:Callable[..., Array],
 
     """
     gradient = grad if gradient is None else gradient
+    jacobian = jax.jacrev if jacobian is None else jacobian
+    if solve is None:
+        def solve(matrix:Array, vector:Array) -> Array:
+            return jax.numpy.linalg.solve(matrix, vector)
     dHdq = gradient(H, argnums=0)
     dHdp = gradient(H, argnums=1)
     def integrator(state: Array, dt: Array, t: Array, *args: Array) -> Array:
@@ -53,9 +63,8 @@ def residual(state: Array) -> tuple[Array, Array]:
             state = jax.numpy.concatenate([dq, dp])
             return state, state
         def newton(state: Array) -> Array:
-            jacobian, error = jax.jacrev(residual, has_aux=True)(state)
-            delta, *_ = jax.numpy.linalg.lstsq(jacobian, -error)
-            return state + delta
+            matrix, error = jacobian(residual, has_aux=True)(state)
+            return state + solve(matrix, -error)
         return nest(ns, newton)(state)
     return integrator
0 commit comments