@@ -268,8 +268,10 @@ def compute_histogram_crosscorrelation(
268268 center_bin = np .floor ((num_bins * 2 - 1 ) / 2 ).astype (int )
269269
270270 # Create the (num windows, num_bins) matrix for this pair of sessions
271- num_iter = num_bins * 2 - 1 if not num_shifts else num_shifts * 2
272- shifts_array = np .arange (- (num_iter // 2 ), num_iter // 2 + 1 )
271+ if num_shifts is None :
272+ num_shifts = num_bins - 1
273+ shifts_array = np .arange (- (num_shifts ), num_shifts + 1 )
274+ num_iter = shifts_array .size
273275
274276 for i in range (num_sessions ):
275277 for j in range (i , num_sessions ):
@@ -291,7 +293,7 @@ def compute_histogram_crosscorrelation(
291293 window_i = windowed_histogram_i - np .mean (windowed_histogram_i , axis = 1 )[:, np .newaxis ]
292294 window_j = windowed_histogram_j - np .mean (windowed_histogram_j , axis = 1 )[:, np .newaxis ]
293295
294- xcorr = np .zeros (num_iter + 1 )
296+ xcorr = np .zeros (num_iter )
295297
296298 for idx , shift in enumerate (shifts_array ):
297299 shifted_i = shift_array_fill_zeros (window_i , shift )
@@ -309,7 +311,7 @@ def compute_histogram_crosscorrelation(
309311 mode = "full" ,
310312 )
311313 if num_shifts :
312- window_indices = np .arange (center_bin - num_shifts , center_bin + num_shifts )
314+ window_indices = np .arange (center_bin - num_shifts , center_bin + num_shifts + 1 )
313315 xcorr = xcorr [window_indices ]
314316
315317 xcorr_matrix [win_idx , :] = xcorr
@@ -322,7 +324,6 @@ def compute_histogram_crosscorrelation(
322324 if num_windows > 1 and smoothing_sigma_window :
323325 xcorr_matrix = gaussian_filter (xcorr_matrix , smoothing_sigma_window , axes = 0 )
324326
325- shifts_array = np .arange (- (num_iter // 2 ), num_iter // 2 + 1 ) # TODO: double check
326327 # Upsample the cross-correlation
327328 if interpolate :
328329
@@ -337,44 +338,16 @@ def compute_histogram_crosscorrelation(
337338 kriging_d ,
338339 )
339340
340- # breakpoint()
341-
342- xcorr_matrix_old = np .matmul (xcorr_matrix , K , axes = [(- 2 , - 1 ), (- 2 , - 1 ), (- 2 , - 1 )])
343- xcorr_matrix_ = np .zeros (
344- (xcorr_matrix .shape [0 ], shifts_upsampled .size )
345- ) # TODO: check in nonlinear case
346- for i_ in range (xcorr_matrix .shape [0 ]):
347- xcorr_matrix_ [i_ , :] = np .matmul (xcorr_matrix [i_ , :], K )
348-
349- # breakpoint()
350-
351- plt .plot (shifts_array , xcorr_matrix .T )
352- plt .show
353- plt .plot (shifts_upsampled , xcorr_matrix_ .T )
354- plt .show ()
355-
356- xcorr_matrix = xcorr_matrix_
357-
358- # plt.plot(xcorr_matrix.T)
359- # plt.plot(xcorr_matrix_old.T)
360- # plt.show()
361- #
341+ xcorr_matrix = np .matmul (xcorr_matrix , K , axes = [(- 2 , - 1 ), (- 2 , - 1 ), (- 2 , - 1 )])
362342
363343 xcorr_peak = np .argmax (xcorr_matrix , axis = 1 )
364344 shift = shifts_upsampled [xcorr_peak ]
365-
366- # breakpoint()
367-
368345 else :
369346 xcorr_peak = np .argmax (xcorr_matrix , axis = 1 )
370347 shift = shifts_array [xcorr_peak ]
371348
372- # x=i;y=j
373- # breakpoint()
374349 shift_matrix [i , j , :] = shift
375350
376- breakpoint ()
377-
378351 # As xcorr shifts are symmetric, the shift matrix is skew symmetric, so fill
379352 # the (empty) lower triangular with the negative (already computed) upper triangular to save computation
380353 for k in range (shift_matrix .shape [2 ]):
0 commit comments