From 873bf69ed2e0f9da74a1df62fdeb76cd1ef64e55 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 26 Jun 2025 10:33:49 -0700 Subject: [PATCH 1/6] restrict indices by recorded hwChan --- src/trodes_to_nwb/spike_gadgets_raw_io.py | 103 ++++++++++++++-------- 1 file changed, 67 insertions(+), 36 deletions(-) diff --git a/src/trodes_to_nwb/spike_gadgets_raw_io.py b/src/trodes_to_nwb/spike_gadgets_raw_io.py index e90df38..eb040ab 100644 --- a/src/trodes_to_nwb/spike_gadgets_raw_io.py +++ b/src/trodes_to_nwb/spike_gadgets_raw_io.py @@ -69,9 +69,14 @@ def _source_name(self) -> str: """ return self.filename + from typing import List + @staticmethod def _produce_ephys_channel_ids( - n_total_channels: int, n_channels_per_chip: int + n_total_channels: int, + n_channels_recorded: int, + n_channels_per_chip: int, + hw_channels_recorded: List[str] = None, ) -> list[int]: """Computes the hardware channel IDs for ephys data. @@ -123,7 +128,19 @@ def _produce_ephys_channel_ids( for i in range(int(n_total_channels / n_channels_per_chip)) ] ) - return [item for sublist in x for item in sublist] + + channel_names = [item for sublist in x for item in sublist] + + if n_total_channels == n_channels_recorded: + # case where all channels are recorded, no censoring requuired + return channel_names + + if hw_channels_recorded is None: + raise ValueError( + "If n_total_channels != n_channels_recorded, " + "hw_channels_recorded must be provided to censor the returned list." + ) + return [x for x in channel_names if str(x) in hw_channels_recorded] def _parse_header(self): """ @@ -167,10 +184,16 @@ def _parse_header(self): # dt = datetime.datetime.fromtimestamp(int(self.system_time_at_creation) / 1000.0) self._sampling_rate = float(hconf.attrib["samplingRate"]) - num_ephy_channels = int(hconf.attrib["numChannels"]) + num_chip_channels = int( + hconf.attrib["numChannels"] + ) # number of channels the hardware supports + num_ephy_channels = int( + hconf.attrib["numChannels"] + ) # number of channels recorder # check for agreement with number of channels in xml sconf_channels = np.sum([len(x) for x in sconf]) if sconf_channels < num_ephy_channels: + # Case: not every channel was saved to recording num_ephy_channels = sconf_channels if sconf_channels > num_ephy_channels: raise ValueError( @@ -383,46 +406,54 @@ def _parse_header(self): ) self._mask_channels_bytes[stream_id] = [] + # get list of all hardware channels recorded + hw_channels_recorded = [] + for trode in sconf: + for schan in trode: + hw_channels_recorded.append(schan.attrib["hwChan"]) + channel_ids = self._produce_ephys_channel_ids( - num_ephy_channels, num_chan_per_chip + num_chip_channels, + num_ephy_channels, + num_chan_per_chip, + hw_channels_recorded, ) chan_ind = 0 - for trode in sconf: - for schan in trode: - chan_id = str(channel_ids[chan_ind]) - name = "chan" + chan_id - - # TODO LATER : handle gain correctly according the file version - units = "" - gain = 1.0 - offset = 0.0 - signal_channels.append( - ( - name, - chan_id, - self._sampling_rate, - "int16", - units, - gain, - offset, - stream_id, - "", - ) + for chan_ind in range(len(channel_ids)): + chan_id = str(channel_ids[chan_ind]) + name = "chan" + chan_id + + # TODO LATER : handle gain correctly according the file version + units = "" + gain = 1.0 + offset = 0.0 + signal_channels.append( + ( + name, + chan_id, + self._sampling_rate, + "int16", + units, + gain, + offset, + stream_id, + "", ) + ) - chan_mask = np.zeros(packet_size, dtype="bool") - num_bytes_offset = ( - packet_size - - (EPHYS_SAMPLE_SIZE_BYTES * num_ephy_channels) - + (EPHYS_SAMPLE_SIZE_BYTES * chan_ind) - ) - chan_mask[ - num_bytes_offset : num_bytes_offset + EPHYS_SAMPLE_SIZE_BYTES - ] = True - self._mask_channels_bytes[stream_id].append(chan_mask) + chan_mask = np.zeros(packet_size, dtype="bool") + num_bytes_offset = ( + packet_size + - (EPHYS_SAMPLE_SIZE_BYTES * num_ephy_channels) + + (EPHYS_SAMPLE_SIZE_BYTES * chan_ind) + ) + chan_mask[ + num_bytes_offset : num_bytes_offset + EPHYS_SAMPLE_SIZE_BYTES + ] = True + self._mask_channels_bytes[stream_id].append(chan_mask) - chan_ind += 1 + # chan_ind += 1 # make mask as array (used in _get_analogsignal_chunk(...)) self._mask_streams = {} From 9c32090637b08f0abe2cdcf96ecb8d202233278d Mon Sep 17 00:00:00 2001 From: Samuel Bray Date: Thu, 26 Jun 2025 10:40:32 -0700 Subject: [PATCH 2/6] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/trodes_to_nwb/spike_gadgets_raw_io.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/trodes_to_nwb/spike_gadgets_raw_io.py b/src/trodes_to_nwb/spike_gadgets_raw_io.py index eb040ab..54cb2f5 100644 --- a/src/trodes_to_nwb/spike_gadgets_raw_io.py +++ b/src/trodes_to_nwb/spike_gadgets_raw_io.py @@ -132,7 +132,7 @@ def _produce_ephys_channel_ids( channel_names = [item for sublist in x for item in sublist] if n_total_channels == n_channels_recorded: - # case where all channels are recorded, no censoring requuired + # case where all channels are recorded, no censoring required return channel_names if hw_channels_recorded is None: @@ -453,7 +453,6 @@ def _parse_header(self): ] = True self._mask_channels_bytes[stream_id].append(chan_mask) - # chan_ind += 1 # make mask as array (used in _get_analogsignal_chunk(...)) self._mask_streams = {} From ff45fad1f0fa107ed5f7c6e9357317ea11cb3d97 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 26 Jun 2025 10:48:49 -0700 Subject: [PATCH 3/6] formatting --- src/trodes_to_nwb/spike_gadgets_raw_io.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/trodes_to_nwb/spike_gadgets_raw_io.py b/src/trodes_to_nwb/spike_gadgets_raw_io.py index 54cb2f5..22fb43b 100644 --- a/src/trodes_to_nwb/spike_gadgets_raw_io.py +++ b/src/trodes_to_nwb/spike_gadgets_raw_io.py @@ -9,7 +9,7 @@ # see https://github.com/NeuralEnsemble/python-neo/pull/1303 import functools -from typing import Optional +from typing import List, Optional from xml.etree import ElementTree import numpy as np @@ -69,8 +69,6 @@ def _source_name(self) -> str: """ return self.filename - from typing import List - @staticmethod def _produce_ephys_channel_ids( n_total_channels: int, @@ -90,9 +88,15 @@ def _produce_ephys_channel_ids( Parameters ---------- n_total_channels : int + Total number of ephys channels in the hardware configuration. + n_channels_recorded : int Total number of ephys channels recorded. n_channels_per_chip : int Number of channels per headstage chip/amplifier. + hw_channels_recorded : list of str, optional + List of hardware channel IDs that were actually recorded. If `None`, all channels are assumed + to be recorded. This is used to filter the returned list if `n_total_channels` + is not equal to `n_channels_recorded`. Returns ------- @@ -453,7 +457,6 @@ def _parse_header(self): ] = True self._mask_channels_bytes[stream_id].append(chan_mask) - # make mask as array (used in _get_analogsignal_chunk(...)) self._mask_streams = {} for stream_id, l in self._mask_channels_bytes.items(): From 000f3ff10f1e4d124f5622970c47c65c94fdec12 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 26 Jun 2025 11:14:13 -0700 Subject: [PATCH 4/6] update test --- .../tests/test_spikegadgets_io.py | 50 +++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/src/trodes_to_nwb/tests/test_spikegadgets_io.py b/src/trodes_to_nwb/tests/test_spikegadgets_io.py index 1cfb34f..854593e 100644 --- a/src/trodes_to_nwb/tests/test_spikegadgets_io.py +++ b/src/trodes_to_nwb/tests/test_spikegadgets_io.py @@ -377,7 +377,9 @@ def test_produce_ephys_channel_ids(): full_expected_1.extend( [k + i * n_per_chip_1 for i in range(n_total_1 // n_per_chip_1)] ) - result_1 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_1, n_per_chip_1) + result_1 = SpikeGadgetsRawIO._produce_ephys_channel_ids( + n_total_1, n_total_1, n_per_chip_1 + ) assert result_1 == full_expected_1 assert len(result_1) == n_total_1 @@ -389,7 +391,9 @@ def test_produce_ephys_channel_ids(): full_expected_2.extend( [k + i * n_per_chip_2 for i in range(n_total_2 // n_per_chip_2)] ) - result_2 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_2, n_per_chip_2) + result_2 = SpikeGadgetsRawIO._produce_ephys_channel_ids( + n_total_2, n_total_2, n_per_chip_2 + ) assert result_2 == full_expected_2 assert len(result_2) == n_total_2 @@ -402,7 +406,9 @@ def test_produce_ephys_channel_ids(): [k + i * n_per_chip_3 for i in range(n_total_3 // n_per_chip_3)] ) - result_3 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_3, n_per_chip_3) + result_3 = SpikeGadgetsRawIO._produce_ephys_channel_ids( + n_total_3, n_total_3, n_per_chip_3 + ) assert result_3 == full_expected_3 assert len(result_3) == n_total_3 @@ -410,22 +416,48 @@ def test_produce_ephys_channel_ids(): n_total_4 = 32 n_per_chip_4 = 32 expected_4 = list(range(32)) # Should just be 0, 1, 2, ..., 31 - result_4 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_4, n_per_chip_4) + result_4 = SpikeGadgetsRawIO._produce_ephys_channel_ids( + n_total_4, n_total_4, n_per_chip_4 + ) assert result_4 == expected_4 assert len(result_4) == n_total_4 + # case 5: Not all channels recorded + n_total_5 = 128 + n_recorded_5 = 127 + n_per_chip_5 = 32 + missing_hw_channel = 2 + + full_expected_5 = [] + for k in range(n_per_chip_5): + full_expected_5.extend( + [k + i * n_per_chip_5 for i in range(n_total_5 // n_per_chip_5)] + ) + full_expected_5 = [x for x in full_expected_5 if x != missing_hw_channel] + hw_channels_recorded_5 = [ + str(x) for x in np.arange(n_total_5) if x != missing_hw_channel + ] + result_5 = SpikeGadgetsRawIO._produce_ephys_channel_ids( + n_total_5, + n_recorded_5, + n_per_chip_5, + hw_channels_recorded=hw_channels_recorded_5, + ) + assert result_5 == full_expected_5 + assert len(result_5) == n_recorded_5 + # --- Edge Cases --- - result_5 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 32) + result_5 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 32) assert result_5 == [] - result_6 = SpikeGadgetsRawIO._produce_ephys_channel_ids(128, 0) + result_6 = SpikeGadgetsRawIO._produce_ephys_channel_ids(128, 128, 0) assert result_6 == [] - result_7 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0) + result_7 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 0) assert result_7 == [] # --- Error Cases --- with pytest.raises(ValueError) as excinfo: - SpikeGadgetsRawIO._produce_ephys_channel_ids(127, 32) + SpikeGadgetsRawIO._produce_ephys_channel_ids(127, 127, 32) assert "multiple of channels per chip" in str(excinfo.value) with pytest.raises(ValueError) as excinfo: - SpikeGadgetsRawIO._produce_ephys_channel_ids(65, 16) + SpikeGadgetsRawIO._produce_ephys_channel_ids(65, 65, 16) assert "multiple of channels per chip" in str(excinfo.value) From 50f9b8c711febc0a3e496937312e1f93eed1deb7 Mon Sep 17 00:00:00 2001 From: Samuel Bray Date: Thu, 26 Jun 2025 13:23:47 -0700 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: Ryan Ly --- src/trodes_to_nwb/spike_gadgets_raw_io.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/trodes_to_nwb/spike_gadgets_raw_io.py b/src/trodes_to_nwb/spike_gadgets_raw_io.py index 22fb43b..63ab85a 100644 --- a/src/trodes_to_nwb/spike_gadgets_raw_io.py +++ b/src/trodes_to_nwb/spike_gadgets_raw_io.py @@ -139,7 +139,7 @@ def _produce_ephys_channel_ids( # case where all channels are recorded, no censoring required return channel_names - if hw_channels_recorded is None: + if not hw_channels_recorded or len(hw_channels_recorded) != n_channels_recorded: raise ValueError( "If n_total_channels != n_channels_recorded, " "hw_channels_recorded must be provided to censor the returned list." @@ -191,9 +191,7 @@ def _parse_header(self): num_chip_channels = int( hconf.attrib["numChannels"] ) # number of channels the hardware supports - num_ephy_channels = int( - hconf.attrib["numChannels"] - ) # number of channels recorder + num_ephy_channels = num_chip_channels # number of channels recorded # check for agreement with number of channels in xml sconf_channels = np.sum([len(x) for x in sconf]) if sconf_channels < num_ephy_channels: From e0bb9e65626f2a19643487f466587363f6b877aa Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 26 Jun 2025 13:29:30 -0700 Subject: [PATCH 6/6] update test from review comments --- .../tests/test_spikegadgets_io.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/trodes_to_nwb/tests/test_spikegadgets_io.py b/src/trodes_to_nwb/tests/test_spikegadgets_io.py index 854593e..7f62469 100644 --- a/src/trodes_to_nwb/tests/test_spikegadgets_io.py +++ b/src/trodes_to_nwb/tests/test_spikegadgets_io.py @@ -447,12 +447,12 @@ def test_produce_ephys_channel_ids(): assert len(result_5) == n_recorded_5 # --- Edge Cases --- - result_5 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 32) - assert result_5 == [] - result_6 = SpikeGadgetsRawIO._produce_ephys_channel_ids(128, 128, 0) + result_6 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 32) assert result_6 == [] - result_7 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 0) + result_7 = SpikeGadgetsRawIO._produce_ephys_channel_ids(128, 128, 0) assert result_7 == [] + result_8 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 0) + assert result_8 == [] # --- Error Cases --- with pytest.raises(ValueError) as excinfo: @@ -461,3 +461,13 @@ def test_produce_ephys_channel_ids(): with pytest.raises(ValueError) as excinfo: SpikeGadgetsRawIO._produce_ephys_channel_ids(65, 65, 16) assert "multiple of channels per chip" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + SpikeGadgetsRawIO._produce_ephys_channel_ids( + 64, + 63, + 16, + ) + assert "hw_channels_recorded must be provided" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + SpikeGadgetsRawIO._produce_ephys_channel_ids(64, 63, 16, ["1", "2", "3"]) + assert "hw_channels_recorded must be provided" in str(excinfo.value)