|
3 | 3 | from pathlib import Path |
4 | 4 |
|
5 | 5 | import numpy as np |
| 6 | +import torch |
6 | 7 |
|
7 | 8 | from kilosort import io |
8 | 9 |
|
@@ -231,3 +232,42 @@ def test_tmax_only(torch_device, data_directory): |
231 | 232 | finally: |
232 | 233 | # Delete memmap file and re-raise exception |
233 | 234 | path.unlink() |
| 235 | + |
| 236 | + |
def test_file_group(torch_device, data_directory, bfile):
    """Check that multi-file loading matches the single-file baseline.

    Exercises both ways of building a concatenated binary recording:
    pre-sliced ``file_objects`` handed to ``io.BinaryFileGroup``, and a
    list of ``filenames`` passed straight to ``io.BinaryFiltered``.
    Batches read through either path must match the corresponding batch
    of the original single-file ``bfile`` fixture.
    """
    binary_path = data_directory / 'ZFM-02370_mini.imec0.ap.short.bin'
    sample_rate = bfile.fs
    num_channels = bfile.n_chan_bin

    # --- file_objects option ---
    # Present the one 45-second recording as three 15-second segments:
    # one read-only memmap per segment, sliced to its time window.
    cuts = [0, int(15*sample_rate), int(30*sample_rate), None]
    segments = []
    for start, stop in zip(cuts[:-1], cuts[1:]):
        mm = np.memmap(binary_path, dtype='int16', shape=bfile.shape, mode='r')
        segments.append(mm[start:stop, :])
    group = io.BinaryFileGroup(file_objects=segments)
    bfile2 = io.BinaryFiltered(
        filename='test', n_chan_bin=num_channels, fs=sample_rate,
        chan_map=bfile.chan_map, device=torch_device, file_object=group,
        dtype='int16'
    )

    # Compare the first batch, a batch that straddles a segment boundary,
    # and the final batch (assumes a 45 s dataset with a 2 s batch size).
    for batch_idx in (0, 7, 22):
        expected = bfile.padded_batch_to_torch(batch_idx, skip_preproc=True)
        observed = bfile2.padded_batch_to_torch(batch_idx)
        assert torch.allclose(expected, observed)

    # --- filenames option ---
    # List the same file three times, so the concatenated recording is
    # the original data repeated back-to-back.
    repeated_files = [binary_path] * 3
    bfile3 = io.BinaryFiltered(
        filename=repeated_files, n_chan_bin=num_channels, fs=sample_rate,
        chan_map=bfile.chan_map, device=torch_device, dtype='int16'
    )

    # first vs. first, last vs. last, and the last batch of the original
    # vs. the last batch of the triple-length concatenation.
    for single_idx, concat_idx in ((0, 0), (21, 21), (22, 67)):
        expected = bfile.padded_batch_to_torch(single_idx, skip_preproc=True)
        observed = bfile3.padded_batch_to_torch(concat_idx)
        assert torch.allclose(expected, observed)
0 commit comments