Skip to content

Commit abe27c1

Browse files
committed
Tidy up tests and (slightly) improve num_iter handling.
1 parent 93b89f6 commit abe27c1

File tree

2 files changed

+142
-192
lines changed

2 files changed

+142
-192
lines changed

src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,10 @@ def compute_histogram_crosscorrelation(
268268
center_bin = np.floor((num_bins * 2 - 1) / 2).astype(int)
269269

270270
# Create the (num windows, num_bins) matrix for this pair of sessions
271-
num_iter = num_bins * 2 - 1 if not num_shifts else num_shifts * 2
272-
shifts_array = np.arange(-(num_iter // 2), num_iter // 2 + 1)
271+
if num_shifts is None:
272+
num_shifts = num_bins - 1
273+
shifts_array = np.arange(-(num_shifts), num_shifts + 1)
274+
num_iter = shifts_array.size
273275

274276
for i in range(num_sessions):
275277
for j in range(i, num_sessions):
@@ -291,7 +293,7 @@ def compute_histogram_crosscorrelation(
291293
window_i = windowed_histogram_i - np.mean(windowed_histogram_i, axis=1)[:, np.newaxis]
292294
window_j = windowed_histogram_j - np.mean(windowed_histogram_j, axis=1)[:, np.newaxis]
293295

294-
xcorr = np.zeros(num_iter + 1)
296+
xcorr = np.zeros(num_iter)
295297

296298
for idx, shift in enumerate(shifts_array):
297299
shifted_i = shift_array_fill_zeros(window_i, shift)
@@ -309,7 +311,7 @@ def compute_histogram_crosscorrelation(
309311
mode="full",
310312
)
311313
if num_shifts:
312-
window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts)
314+
window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts + 1)
313315
xcorr = xcorr[window_indices]
314316

315317
xcorr_matrix[win_idx, :] = xcorr
@@ -322,7 +324,6 @@ def compute_histogram_crosscorrelation(
322324
if num_windows > 1 and smoothing_sigma_window:
323325
xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_window, axes=0)
324326

325-
shifts_array = np.arange(-(num_iter // 2), num_iter // 2 + 1) # TODO: double check
326327
# Upsample the cross-correlation
327328
if interpolate:
328329

@@ -337,44 +338,16 @@ def compute_histogram_crosscorrelation(
337338
kriging_d,
338339
)
339340

340-
# breakpoint()
341-
342-
xcorr_matrix_old = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)])
343-
xcorr_matrix_ = np.zeros(
344-
(xcorr_matrix.shape[0], shifts_upsampled.size)
345-
) # TODO: check in nonlinear case
346-
for i_ in range(xcorr_matrix.shape[0]):
347-
xcorr_matrix_[i_, :] = np.matmul(xcorr_matrix[i_, :], K)
348-
349-
# breakpoint()
350-
351-
plt.plot(shifts_array, xcorr_matrix.T)
352-
plt.show
353-
plt.plot(shifts_upsampled, xcorr_matrix_.T)
354-
plt.show()
355-
356-
xcorr_matrix = xcorr_matrix_
357-
358-
# plt.plot(xcorr_matrix.T)
359-
# plt.plot(xcorr_matrix_old.T)
360-
# plt.show()
361-
#
341+
xcorr_matrix = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)])
362342

363343
xcorr_peak = np.argmax(xcorr_matrix, axis=1)
364344
shift = shifts_upsampled[xcorr_peak]
365-
366-
# breakpoint()
367-
368345
else:
369346
xcorr_peak = np.argmax(xcorr_matrix, axis=1)
370347
shift = shifts_array[xcorr_peak]
371348

372-
# x=i;y=j
373-
# breakpoint()
374349
shift_matrix[i, j, :] = shift
375350

376-
breakpoint()
377-
378351
# As xcorr shifts are symmetric, the shift matrix is skew symmetric, so fill
379352
# the (empty) lower triangular with the negative (already computed) upper triangular to save computation
380353
for k in range(shift_matrix.shape[2]):

0 commit comments

Comments
 (0)