@@ -139,15 +139,15 @@ def _derEqns(obes_base: np.ndarray, const_base: np.ndarray,
139139 obes_time_r : np .ndarray , const_r : np .ndarray ,
140140 obes_time_i : np .ndarray , const_i : np .ndarray ,
141141 time_inputs : Sequence [TimeFunc ]
142- ) -> Callable [[float , np .ndarray , np .ndarray ], None ]:
142+ ) -> Callable [[np .ndarray , float , np .ndarray ], None ]:
143143 """
144- Function to build the callable passed to CyRK's cyrk_ode cython solver.
144+ Function to build the callable passed to CyRK's pysolve_ivp cython solver.
145145
146146 Note that `time_inputs` functions must be njit compiled.
147147
148148 Uses the base and time matrix components of the eoms to build
149149 a function of vector and scalar time
150- that has the expected input/output of functions passed to `cyrk.cyrk_ode ()`
150+ that has the expected input/output of functions passed to `cyrk.pysolve_ivp ()`
151151 """
152152 import numba as nb
153153
@@ -185,7 +185,7 @@ def _derEqns_flat(obes_base: np.ndarray, const_base: np.ndarray,
185185 obes_time_r : np .ndarray , const_r : np .ndarray ,
186186 obes_time_i : np .ndarray , const_i : np .ndarray ,
187187 time_inputs : Sequence [TimeFunc ]
188- ) -> Callable [[float , np .ndarray , np .ndarray ], None ]:
188+ ) -> Callable [[np .ndarray , float , np .ndarray ], None ]:
189189 """
190190 Function to build the callable passed to CyRK's pysolve_ivp cython solver.
191191
@@ -204,7 +204,6 @@ def _derEqns_flat(obes_base: np.ndarray, const_base: np.ndarray,
204204
205205 # basis dimension size
206206 b = obes_base .shape [- 1 ]
207- b2 = b ** 2
208207 # time function dimension size
209208 t_func_num = obes_time_r .shape [0 ]
210209 # flatten eqns arrays
@@ -233,17 +232,16 @@ def func(result_out: np.ndarray, t: float, A_flat: np.ndarray):
233232 result_out [i ] += ts [idx ].real * const_r [idx , const_time_idx ] + ts [idx ].imag * const_i [idx , const_time_idx ]
234233
235234 for j in range (b ):
236- # define indeces for this step
235+ # define indices for this step
237236 obe_idx = i * b + j
238- obe_time_idx = obe_idx % b2
239237 A_idx = (i // b )* b + j
240238 # add time-independent obe part
241239 # implements einsum('...ij,...j', obes, A)
242240 result_out [i ] += obes_base [obe_idx ] * A_flat [A_idx ]
243241 for idx in range (t_func_num ):
244242 # add time-dependent obe part
245- result_out [i ] += (ts [idx ].real * obes_time_r [idx , obe_time_idx ]
246- + ts [idx ].imag * obes_time_i [idx , obe_time_idx ]) * A_flat [A_idx ]
243+ result_out [i ] += (ts [idx ].real * obes_time_r [idx , obe_idx ]
244+ + ts [idx ].imag * obes_time_i [idx , obe_idx ]) * A_flat [A_idx ]
247245
248246 return func
249247
0 commit comments