77from spikeinterface .sortingcomponents .peak_detection import detect_peaks
88from spikeinterface .sortingcomponents .peak_localization import localize_peaks
99
10- # TODO: test templates_array_moved are the same with
11- # no shift, with both seed and no seed
12-
13- # rescale units per session
14-
1510
1611class TestSessionDisplacementGenerator :
1712 """
@@ -97,14 +92,14 @@ def test_x_y_rigid_shifts_are_properly_set(self, options):
9792 for unit_idx in range (num_units ):
9893
9994 start_pos = self ._get_peak_chan_loc_in_um (
100- extra_outputs ["template_array_moved " ][0 ][unit_idx ],
95+ extra_outputs ["templates_array_moved " ][0 ][unit_idx ],
10196 options ["y_bin_um" ],
10297 )
10398
10499 for rec_idx in range (1 , options ["num_recs" ]):
105100
106101 new_pos = self ._get_peak_chan_loc_in_um (
107- extra_outputs ["template_array_moved " ][rec_idx ][unit_idx ], options ["y_bin_um" ]
102+ extra_outputs ["templates_array_moved " ][rec_idx ][unit_idx ], options ["y_bin_um" ]
108103 )
109104
110105 y_shift = recording_shifts [rec_idx ][1 ]
@@ -120,7 +115,7 @@ def test_x_y_rigid_shifts_are_properly_set(self, options):
120115 for rec_idx in range (options ["num_recs" ]):
121116 assert np .array_equal (
122117 output_recordings [rec_idx ].templates ,
123- extra_outputs ["template_array_moved " ][rec_idx ],
118+ extra_outputs ["templates_array_moved " ][rec_idx ],
124119 )
125120
126121 def _get_peak_chan_loc_in_um (self , template_array , y_bin_um ):
@@ -275,6 +270,56 @@ def test_displacement_with_peak_detection(self, options):
275270
276271 assert np .isclose (new_pos , first_pos + y_shift , rtol = 0 , atol = options ["y_bin_um" ])
277272
273+ def test_amplitude_scalings (self , options ):
274+
275+ options ["kwargs" ]["recording_durations" ] = (10 , 10 )
276+ options ["kwargs" ]["recording_shifts" ] = ((0 , 0 ), (0 , 0 ))
277+ options ["kwargs" ]["num_units" ] == 5 ,
278+
279+ recording_amplitude_scalings = {
280+ "method" : "by_passed_order" ,
281+ "scalings" : (np .ones (5 ), np .array ([0.1 , 0.2 , 0.3 , 0.4 , 0.5 ])),
282+ }
283+
284+ _ , output_sortings , extra_outputs = generate_session_displacement_recordings (
285+ ** options ["kwargs" ],
286+ recording_amplitude_scalings = recording_amplitude_scalings ,
287+ )
288+ breakpoint ()
289+ first , second = extra_outputs ["templates_array_moved" ] # TODO: own function
290+ first_min = np .min (np .min (first , axis = 2 ), axis = 1 )
291+ second_min = np .min (np .min (second , axis = 2 ), axis = 1 )
292+ scales = second_min / first_min
293+
294+ assert np .allclose (scales , shifts )
295+
296+ # TODO: scale based on recording output
297+ # check scaled by amplitude.
298+
299+ breakpoint ()
300+
301+ def test_metadata (self , options ):
302+ """
303+ Check that metadata required to be set of generated recordings is present
304+ on all output recordings.
305+ """
306+ output_recordings , output_sortings , extra_outputs = generate_session_displacement_recordings (
307+ ** options ["kwargs" ], generate_noise_kwargs = dict (noise_levels = (1.0 , 2.0 ), spatial_decay = 1.0 )
308+ )
309+ num_chans = output_recordings [0 ].get_num_channels ()
310+
311+ for i in range (len (output_recordings )):
312+ assert output_recordings [i ].name == "InterSessionDisplacementRecording"
313+ assert output_recordings [i ]._annotations ["is_filtered" ] is True
314+ assert output_recordings [i ].has_probe ()
315+ assert np .array_equal (output_recordings [i ].get_channel_gains (), np .ones (num_chans ))
316+ assert np .array_equal (output_recordings [i ].get_channel_offsets (), np .zeros (num_chans ))
317+
318+ assert np .array_equal (
319+ output_sortings [i ].get_property ("gt_unit_locations" ), extra_outputs ["unit_locations" ][i ]
320+ )
321+ assert output_sortings [i ].name == "InterSessionDisplacementSorting"
322+
278323 def test_same_as_generate_ground_truth_recording (self ):
279324 """
280325 It is expected that inter-session displacement randomly
@@ -302,7 +347,7 @@ def test_same_as_generate_ground_truth_recording(self):
302347 no_shift_recording , _ = generate_session_displacement_recordings (
303348 num_units = num_units ,
304349 recording_durations = [duration ],
305- recording_shifts = ((0 , 0 )),
350+ recording_shifts = ((0 , 0 ), ),
306351 sampling_frequency = sampling_frequency ,
307352 probe_name = probe_name ,
308353 generate_probe_kwargs = generate_probe_kwargs ,
0 commit comments