Skip to content

Commit 9516284

Browse files
author
Jacob Pennington
committed
Added tests for BinaryFileGroup
1 parent e651523 commit 9516284

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

tests/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,5 +221,3 @@ def bfile(saved_ops, torch_device, data_directory):
221221
bfile = io.bfile_from_ops(ops, filename=filename, device=torch_device)
222222

223223
return bfile
224-
225-
### End

tests/test_io.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44

55
import numpy as np
6+
import torch
67

78
from kilosort import io
89

@@ -231,3 +232,42 @@ def test_tmax_only(torch_device, data_directory):
231232
finally:
232233
# Delete memmap file and re-raise exception
233234
path.unlink()
235+
236+
237+
def test_file_group(torch_device, data_directory, bfile):
238+
file = data_directory / 'ZFM-02370_mini.imec0.ap.short.bin'
239+
fs = bfile.fs
240+
n_chans = bfile.n_chan_bin
241+
242+
# Test with file_objects option
243+
# Load as three 15-second files instead of one 45-second file.
244+
objs = [np.memmap(file, dtype='int16', shape=bfile.shape, mode='r')
245+
for _ in range(3)]
246+
objs[0] = objs[0][:int(15*fs),:]
247+
objs[1] = objs[1][int(15*fs):int(30*fs),:]
248+
objs[2] = objs[2][int(30*fs):,:]
249+
bfg = io.BinaryFileGroup(file_objects=objs)
250+
bfile2 = io.BinaryFiltered(
251+
filename='test', n_chan_bin=n_chans, fs=fs, chan_map=bfile.chan_map,
252+
device=torch_device, file_object=bfg, dtype='int16'
253+
)
254+
255+
# First batch, overlapping batch, and last batch
256+
# (assumes 45s test dataset with 2s batch size)
257+
for i in [0, 7, 22]:
258+
b1 = bfile.padded_batch_to_torch(i, skip_preproc=True)
259+
b2 = bfile2.padded_batch_to_torch(i)
260+
assert torch.allclose(b1, b2)
261+
262+
# Test with filenames option
263+
files = [file]*3 # Load the same data three times
264+
bfile3 = io.BinaryFiltered(
265+
filename=files, n_chan_bin=n_chans, fs=fs, chan_map=bfile.chan_map,
266+
device=torch_device, dtype='int16'
267+
)
268+
269+
# First and first, last and last, last of original and last of concat
270+
for i,j in [(0,0), (21,21), (22,67)]:
271+
b1 = bfile.padded_batch_to_torch(i, skip_preproc=True)
272+
b2 = bfile3.padded_batch_to_torch(j)
273+
assert torch.allclose(b1, b2)

0 commit comments

Comments
 (0)