@@ -163,17 +163,17 @@ def get_chunked_hist_median(chunked_session_histograms):
163163
164164# TODO: a good test here is to give zero shift for even and off numbered hist and check the output is zero!
165165def compute_histogram_crosscorrelation (
166- session_histogram_list : list [ np .ndarray ] ,
166+ session_histogram_list : np .ndarray ,
167167 non_rigid_windows : np .ndarray ,
168168 num_shifts : int ,
169169 interpolate : bool ,
170170 interp_factor : int ,
171171 kriging_sigma : float ,
172172 kriging_p : float ,
173173 kriging_d : float ,
174- smoothing_sigma_bin : float ,
175- smoothing_sigma_window : float ,
176- ):
174+ smoothing_sigma_bin : None | float ,
175+ smoothing_sigma_window : None | float ,
176+ ) -> tuple [ np . ndarray , np . ndarray ] :
177177 """
178178 Given a list of session activity histograms, cross-correlate
179179 all histograms returning the peak correlation shift (in indices)
@@ -185,7 +185,8 @@ def compute_histogram_crosscorrelation(
185185 Parameters
186186 ----------
187187
188- session_histogram_list : list[np.ndarray]
188+ session_histogram_list : list[np.ndarray] TODO: change name!!
189+ (num_sessions, num_bins) array of session activity histograms.
189190 non_rigid_windows : np.ndarray
190191 A (num windows x num_bins) binary of weights by which to window
191192 the activity histogram for non-rigid-registration. For example, if
@@ -258,23 +259,21 @@ def compute_histogram_crosscorrelation(
258259 """
259260 import matplotlib .pyplot as plt
260261
261- num_sessions = len ( session_histogram_list )
262+ num_sessions = session_histogram_list . shape [ 0 ]
262263 num_bins = session_histogram_list .shape [1 ] # all hists are same length
263264 num_windows = non_rigid_windows .shape [0 ]
264265
265266 shift_matrix = np .zeros ((num_sessions , num_sessions , num_windows ))
266267
267268 center_bin = np .floor ((num_bins * 2 - 1 ) / 2 ).astype (int )
268269
270+ # Create the (num windows, num_bins) matrix for this pair of sessions
271+ num_iter = num_bins * 2 - 1 if not num_shifts else num_shifts * 2
272+ shifts_array = np .arange (- (num_iter // 2 ), num_iter // 2 + 1 )
273+
269274 for i in range (num_sessions ):
270275 for j in range (i , num_sessions ):
271276
272- # Create the (num windows, num_bins) matrix for this pair of sessions
273- num_iter = (
274- num_bins * 2 - 1
275- if not num_shifts
276- else num_shifts * 2 # num_shift_block with iterative alignment is 2x, the same, make note!
277- )
278277 xcorr_matrix = np .zeros ((non_rigid_windows .shape [0 ], num_iter ))
279278
280279 # For each window, window the session histograms (`window` is binary)
@@ -292,12 +291,12 @@ def compute_histogram_crosscorrelation(
292291 window_i = windowed_histogram_i - np .mean (windowed_histogram_i , axis = 1 )[:, np .newaxis ]
293292 window_j = windowed_histogram_j - np .mean (windowed_histogram_j , axis = 1 )[:, np .newaxis ]
294293
295- xcorr = np .zeros (num_iter )
296- for idx , shift in enumerate (range (- num_iter // 2 , num_iter // 2 )):
294+ xcorr = np .zeros (num_iter + 1 )
295+
296+ for idx , shift in enumerate (shifts_array ):
297297 shifted_i = shift_array_fill_zeros (window_i , shift )
298298
299299 xcorr [idx ] = np .correlate (shifted_i .flatten (), window_j .flatten ())
300-
301300 else :
302301 # For a 1D histogram, compute the full cross-correlation and
303302 # window the desired shifts ( this is faster than manual looping).
@@ -315,11 +314,6 @@ def compute_histogram_crosscorrelation(
315314
316315 xcorr_matrix [win_idx , :] = xcorr
317316
318- if num_shifts :
319- shift_center_bin = num_shifts
320- else :
321- shift_center_bin = center_bin
322-
323317 # Smooth the cross-correlations across the bins
324318 if smoothing_sigma_bin :
325319 xcorr_matrix = gaussian_filter (xcorr_matrix , smoothing_sigma_bin , axes = 1 )
@@ -328,36 +322,67 @@ def compute_histogram_crosscorrelation(
328322 if num_windows > 1 and smoothing_sigma_window :
329323 xcorr_matrix = gaussian_filter (xcorr_matrix , smoothing_sigma_window , axes = 0 )
330324
325+ shifts_array = np .arange (- (num_iter // 2 ), num_iter // 2 + 1 ) # TODO: double check
331326 # Upsample the cross-correlation
332327 if interpolate :
333- shifts = np .arange (xcorr_matrix .shape [1 ])
334- shifts_upsampled = np .linspace (shifts [0 ], shifts [- 1 ], shifts .size * interp_factor )
328+
329+ # shifts = np.arange(xcorr_matrix.shape[1])
330+ shifts_upsampled = np .linspace (shifts_array [0 ], shifts_array [- 1 ], shifts_array .size * interp_factor )
335331
336332 K = kriging_kernel (
337- np .c_ [np .ones_like (shifts ), shifts ],
333+ np .c_ [np .ones_like (shifts_array ), shifts_array ],
338334 np .c_ [np .ones_like (shifts_upsampled ), shifts_upsampled ],
339335 kriging_sigma ,
340336 kriging_p ,
341337 kriging_d ,
342338 )
343- xcorr_matrix = np .matmul (xcorr_matrix , K , axes = [(- 2 , - 1 ), (- 2 , - 1 ), (- 2 , - 1 )])
344339
345- xcorr_peak = np .argmax (xcorr_matrix , axis = 1 ) / interp_factor
340+ # breakpoint()
341+
342+ xcorr_matrix_old = np .matmul (xcorr_matrix , K , axes = [(- 2 , - 1 ), (- 2 , - 1 ), (- 2 , - 1 )])
343+ xcorr_matrix_ = np .zeros (
344+ (xcorr_matrix .shape [0 ], shifts_upsampled .size )
345+ ) # TODO: check in nonlinear case
346+ for i_ in range (xcorr_matrix .shape [0 ]):
347+ xcorr_matrix_ [i_ , :] = np .matmul (xcorr_matrix [i_ , :], K )
348+
349+ # breakpoint()
350+
351+ plt .plot (shifts_array , xcorr_matrix .T )
352+ plt .show
353+ plt .plot (shifts_upsampled , xcorr_matrix_ .T )
354+ plt .show ()
355+
356+ xcorr_matrix = xcorr_matrix_
357+
358+ # plt.plot(xcorr_matrix.T)
359+ # plt.plot(xcorr_matrix_old.T)
360+ # plt.show()
361+ #
362+
363+ xcorr_peak = np .argmax (xcorr_matrix , axis = 1 )
364+ shift = shifts_upsampled [xcorr_peak ]
365+
366+ # breakpoint()
367+
346368 else :
347369 xcorr_peak = np .argmax (xcorr_matrix , axis = 1 )
370+ shift = shifts_array [xcorr_peak ]
348371
349- # Caclulate and save the shift for session i to j
350- shift = xcorr_peak - shift_center_bin
372+ # x=i;y= j
373+ # breakpoint()
351374 shift_matrix [i , j , :] = shift
352375
376+ breakpoint ()
377+
353378 # As xcorr shifts are symmetric, the shift matrix is skew symmetric, so fill
354379 # the (empty) lower triangular with the negative (already computed) upper triangular to save computation
355380 for k in range (shift_matrix .shape [2 ]):
356381 lower_i , lower_j = np .tril_indices_from (shift_matrix [:, :, k ], k = - 1 )
357382 upper_i , upper_j = np .triu_indices_from (shift_matrix [:, :, k ], k = 1 )
358383 shift_matrix [lower_i , lower_j , k ] = shift_matrix [upper_i , upper_j , k ] * - 1
359384
360- return shift_matrix
385+ return shift_matrix , xcorr_matrix
361386
362387
363388def shift_array_fill_zeros (array : np .ndarray , shift : int ) -> np .ndarray :
0 commit comments