Skip to content

Commit 6924cfd

Browse files
authored
Merge pull request #55 from QTC-UMD/cyrk_flat_fix
Cyrk flat DiffEq fix
2 parents 725b16c + 54810b8 commit 6924cfd

2 files changed

Lines changed: 8 additions & 9 deletions

File tree

docs/source/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Bug Fixes
2222
- Fix `Cell.kappa` and `Cell.eta` to properly introspect `q` when first edge on the graph
2323
happens to not be dipole-allowed.
2424
- Fix `draw_diagram` to handle complex time-dependent functions correctly.
25+
- Fix `cyrk_solve` flat diffEq backend to properly handle EOM stacks.
2526

2627
Deprecations
2728
++++++++++++

src/rydiqule/stack_solvers/cyrk_solver.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,15 @@ def _derEqns(obes_base: np.ndarray, const_base: np.ndarray,
139139
obes_time_r: np.ndarray, const_r: np.ndarray,
140140
obes_time_i: np.ndarray, const_i: np.ndarray,
141141
time_inputs: Sequence[TimeFunc]
142-
) -> Callable[[float, np.ndarray, np.ndarray], None]:
142+
) -> Callable[[np.ndarray, float, np.ndarray], None]:
143143
"""
144-
Function to build the callable passed to CyRK's cyrk_ode cython solver.
144+
Function to build the callable passed to CyRK's pysolve_ivp cython solver.
145145
146146
Note that `time_inputs` functions must be njit compiled.
147147
148148
Uses the base and time matrix components of the eoms to build
149149
a function of vector and scalar time
150-
that has the expected input/output of functions passed to `cyrk.cyrk_ode()`
150+
that has the expected input/output of functions passed to `cyrk.pysolve_ivp()`
151151
"""
152152
import numba as nb
153153

@@ -185,7 +185,7 @@ def _derEqns_flat(obes_base: np.ndarray, const_base: np.ndarray,
185185
obes_time_r: np.ndarray, const_r: np.ndarray,
186186
obes_time_i: np.ndarray, const_i: np.ndarray,
187187
time_inputs: Sequence[TimeFunc]
188-
) -> Callable[[float, np.ndarray, np.ndarray], None]:
188+
) -> Callable[[np.ndarray, float, np.ndarray], None]:
189189
"""
190190
Function to build the callable passed to CyRK's pysolve_ivp cython solver.
191191
@@ -204,7 +204,6 @@ def _derEqns_flat(obes_base: np.ndarray, const_base: np.ndarray,
204204

205205
# basis dimension size
206206
b = obes_base.shape[-1]
207-
b2 = b**2
208207
# time function dimension size
209208
t_func_num = obes_time_r.shape[0]
210209
# flatten eqns arrays
@@ -233,17 +232,16 @@ def func(result_out: np.ndarray, t: float, A_flat: np.ndarray):
233232
result_out[i] += ts[idx].real*const_r[idx, const_time_idx] + ts[idx].imag*const_i[idx, const_time_idx]
234233

235234
for j in range(b):
236-
# define indeces for this step
235+
# define indices for this step
237236
obe_idx = i*b+j
238-
obe_time_idx = obe_idx%b2
239237
A_idx = (i//b)*b+j
240238
# add time-independent obe part
241239
# implements einsum('...ij,...j', obes, A)
242240
result_out[i] += obes_base[obe_idx] * A_flat[A_idx]
243241
for idx in range(t_func_num):
244242
# add time-dependent obe part
245-
result_out[i] += (ts[idx].real*obes_time_r[idx, obe_time_idx]
246-
+ ts[idx].imag*obes_time_i[idx, obe_time_idx]) * A_flat[A_idx]
243+
result_out[i] += (ts[idx].real*obes_time_r[idx, obe_idx]
244+
+ ts[idx].imag*obes_time_i[idx, obe_idx]) * A_flat[A_idx]
247245

248246
return func
249247

0 commit comments

Comments
 (0)