Skip to content

Commit 12a9b6e

Browse files
authored
Hotfix: Fix broadcasting error when writing ephys data to nwbfile (#179)
1 parent 90b6782 commit 12a9b6e

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

src/jdb_to_nwb/convert_raw_ephys.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
)
2222
from pynwb import NWBFile
2323
from pynwb.ecephys import ElectricalSeries
24-
from hdmf.data_utils import DataChunkIterator
2524
from spikeinterface.extractors import OpenEphysBinaryRecordingExtractor
2625

2726
from .utils import get_logger_directory
@@ -71,6 +70,26 @@
7170
},
7271
}
7372

73+
class MicrovoltsSpikeInterfaceRecordingDataChunkIterator(SpikeInterfaceRecordingDataChunkIterator):
74+
def __init__(self, iterator: SpikeInterfaceRecordingDataChunkIterator, conversion_factor_uv):
75+
self.iterator = iterator
76+
self.conversion_factor_uv = conversion_factor_uv
77+
super().__init__(recording=iterator.recording)
78+
79+
def _get_default_chunk_shape(self, chunk_mb: float = 10.0) -> tuple[int, int]:
80+
return self.iterator._get_default_chunk_shape(chunk_mb)
81+
82+
def _get_data(self, selection: tuple[slice]):
83+
data = self.iterator._get_data(selection)
84+
return (data * self.conversion_factor_uv).astype("int16")
85+
86+
def _get_dtype(self):
87+
return np.dtype("int16")
88+
89+
def _get_maxshape(self):
90+
return self.iterator._get_maxshape()
91+
92+
7493
def find_open_ephys_paths(open_ephys_folder_path, experiment_number=1) -> dict:
7594
"""
7695
Given the Open Ephys folder path, find the relevant settings.xml file and all associated continuous.dat files.
@@ -1167,25 +1186,16 @@ def add_raw_ephys(
11671186
)
11681187

11691188
# Convert to uV without loading the whole thing at once
1170-
def traces_in_microvolts_iterator(traces_as_iterator, conversion_factor_uv):
1171-
for chunk in traces_as_iterator:
1172-
yield (chunk * conversion_factor_uv).astype("int16")
1173-
1174-
# Wrap iterator in DataChunkIterator for H5DataIO
1175-
data_iterator = DataChunkIterator(
1176-
traces_in_microvolts_iterator(traces_as_iterator, channel_conversion_factor_uv),
1177-
buffer_size=1, # number of chunks to keep in memory
1178-
maxshape=(num_samples, num_channels),
1179-
dtype=np.dtype("int16"),
1180-
)
1189+
uv_traces_as_iterator = MicrovoltsSpikeInterfaceRecordingDataChunkIterator(traces_as_iterator,
1190+
channel_conversion_factor_uv)
11811191

11821192
# A chunk of shape (81920, 64) and dtype int16 (2 bytes) is ~10 MB, which is the recommended chunk size
11831193
# by the NWB team.
11841194
# We could also add compression here. zstd/blosc-zstd are recommended by the NWB team, but
11851195
# they require the hdf5plugin library to be installed. gzip is available by default.
11861196
# Use gzip for now, but consider zstd/blosc-zstd in the future.
11871197
data_data_io = H5DataIO(
1188-
data_iterator,
1198+
data=uv_traces_as_iterator,
11891199
chunks=(min(num_samples, 81920), min(num_channels, 64)),
11901200
compression="gzip",
11911201
)

0 commit comments

Comments
 (0)