Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 69 additions & 38 deletions src/trodes_to_nwb/spike_gadgets_raw_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# see https://github.com/NeuralEnsemble/python-neo/pull/1303

import functools
from typing import Optional
from typing import List, Optional
from xml.etree import ElementTree

import numpy as np
Expand Down Expand Up @@ -71,7 +71,10 @@ def _source_name(self) -> str:

@staticmethod
def _produce_ephys_channel_ids(
n_total_channels: int, n_channels_per_chip: int
n_total_channels: int,
n_channels_recorded: int,
n_channels_per_chip: int,
hw_channels_recorded: List[str] = None,
) -> list[int]:
"""Computes the hardware channel IDs for ephys data.
Expand All @@ -85,9 +88,15 @@ def _produce_ephys_channel_ids(
Parameters
----------
n_total_channels : int
Total number of ephys channels in the hardware configuration.
n_channels_recorded : int
Total number of ephys channels recorded.
n_channels_per_chip : int
Number of channels per headstage chip/amplifier.
hw_channels_recorded : list of str, optional
List of hardware channel IDs that were actually recorded. If `None`, all channels are assumed
to be recorded. This is used to filter the returned list if `n_total_channels`
is not equal to `n_channels_recorded`.
Returns
-------
Expand Down Expand Up @@ -123,7 +132,19 @@ def _produce_ephys_channel_ids(
for i in range(int(n_total_channels / n_channels_per_chip))
]
)
return [item for sublist in x for item in sublist]

channel_names = [item for sublist in x for item in sublist]

if n_total_channels == n_channels_recorded:
# case where all channels are recorded, no censoring required
return channel_names

if not hw_channels_recorded or len(hw_channels_recorded) != n_channels_recorded:
raise ValueError(
"If n_total_channels != n_channels_recorded, "
"hw_channels_recorded must be provided to censor the returned list."
)
return [x for x in channel_names if str(x) in hw_channels_recorded]

def _parse_header(self):
"""
Expand Down Expand Up @@ -167,10 +188,14 @@ def _parse_header(self):
# dt = datetime.datetime.fromtimestamp(int(self.system_time_at_creation) / 1000.0)

self._sampling_rate = float(hconf.attrib["samplingRate"])
num_ephy_channels = int(hconf.attrib["numChannels"])
num_chip_channels = int(
hconf.attrib["numChannels"]
) # number of channels the hardware supports
num_ephy_channels = num_chip_channels # number of channels recorded
# check for agreement with number of channels in xml
sconf_channels = np.sum([len(x) for x in sconf])
if sconf_channels < num_ephy_channels:
# Case: not every channel was saved to recording
num_ephy_channels = sconf_channels
if sconf_channels > num_ephy_channels:
raise ValueError(
Expand Down Expand Up @@ -383,46 +408,52 @@ def _parse_header(self):
)
self._mask_channels_bytes[stream_id] = []

channel_ids = self._produce_ephys_channel_ids(
num_ephy_channels, num_chan_per_chip
)

chan_ind = 0
# get list of all hardware channels recorded
hw_channels_recorded = []
for trode in sconf:
for schan in trode:
chan_id = str(channel_ids[chan_ind])
name = "chan" + chan_id
hw_channels_recorded.append(schan.attrib["hwChan"])

# TODO LATER : handle gain correctly according the file version
units = ""
gain = 1.0
offset = 0.0
signal_channels.append(
(
name,
chan_id,
self._sampling_rate,
"int16",
units,
gain,
offset,
stream_id,
"",
)
)
channel_ids = self._produce_ephys_channel_ids(
num_chip_channels,
num_ephy_channels,
num_chan_per_chip,
hw_channels_recorded,
)

chan_mask = np.zeros(packet_size, dtype="bool")
num_bytes_offset = (
packet_size
- (EPHYS_SAMPLE_SIZE_BYTES * num_ephy_channels)
+ (EPHYS_SAMPLE_SIZE_BYTES * chan_ind)
chan_ind = 0
for chan_ind in range(len(channel_ids)):
chan_id = str(channel_ids[chan_ind])
name = "chan" + chan_id

# TODO LATER : handle gain correctly according the file version
units = ""
gain = 1.0
offset = 0.0
signal_channels.append(
(
name,
chan_id,
self._sampling_rate,
"int16",
units,
gain,
offset,
stream_id,
"",
)
chan_mask[
num_bytes_offset : num_bytes_offset + EPHYS_SAMPLE_SIZE_BYTES
] = True
self._mask_channels_bytes[stream_id].append(chan_mask)
)

chan_ind += 1
chan_mask = np.zeros(packet_size, dtype="bool")
num_bytes_offset = (
packet_size
- (EPHYS_SAMPLE_SIZE_BYTES * num_ephy_channels)
+ (EPHYS_SAMPLE_SIZE_BYTES * chan_ind)
)
chan_mask[
num_bytes_offset : num_bytes_offset + EPHYS_SAMPLE_SIZE_BYTES
] = True
self._mask_channels_bytes[stream_id].append(chan_mask)

# make mask as array (used in _get_analogsignal_chunk(...))
self._mask_streams = {}
Expand Down
62 changes: 52 additions & 10 deletions src/trodes_to_nwb/tests/test_spikegadgets_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,9 @@ def test_produce_ephys_channel_ids():
full_expected_1.extend(
[k + i * n_per_chip_1 for i in range(n_total_1 // n_per_chip_1)]
)
result_1 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_1, n_per_chip_1)
result_1 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
n_total_1, n_total_1, n_per_chip_1
)
assert result_1 == full_expected_1
assert len(result_1) == n_total_1

Expand All @@ -389,7 +391,9 @@ def test_produce_ephys_channel_ids():
full_expected_2.extend(
[k + i * n_per_chip_2 for i in range(n_total_2 // n_per_chip_2)]
)
result_2 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_2, n_per_chip_2)
result_2 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
n_total_2, n_total_2, n_per_chip_2
)
assert result_2 == full_expected_2
assert len(result_2) == n_total_2

Expand All @@ -402,30 +406,68 @@ def test_produce_ephys_channel_ids():
[k + i * n_per_chip_3 for i in range(n_total_3 // n_per_chip_3)]
)

result_3 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_3, n_per_chip_3)
result_3 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
n_total_3, n_total_3, n_per_chip_3
)
assert result_3 == full_expected_3
assert len(result_3) == n_total_3

# Case 4: Single chip - 32 channels, 32 per chip (1 chip)
n_total_4 = 32
n_per_chip_4 = 32
expected_4 = list(range(32)) # Should just be 0, 1, 2, ..., 31
result_4 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_4, n_per_chip_4)
result_4 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
n_total_4, n_total_4, n_per_chip_4
)
assert result_4 == expected_4
assert len(result_4) == n_total_4

# case 5: Not all channels recorded
n_total_5 = 128
n_recorded_5 = 127
n_per_chip_5 = 32
missing_hw_channel = 2

full_expected_5 = []
for k in range(n_per_chip_5):
full_expected_5.extend(
[k + i * n_per_chip_5 for i in range(n_total_5 // n_per_chip_5)]
)
full_expected_5 = [x for x in full_expected_5 if x != missing_hw_channel]
hw_channels_recorded_5 = [
str(x) for x in np.arange(n_total_5) if x != missing_hw_channel
]
result_5 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
n_total_5,
n_recorded_5,
n_per_chip_5,
hw_channels_recorded=hw_channels_recorded_5,
)
assert result_5 == full_expected_5
assert len(result_5) == n_recorded_5

# --- Edge Cases ---
result_5 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 32)
assert result_5 == []
result_6 = SpikeGadgetsRawIO._produce_ephys_channel_ids(128, 0)
result_6 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 32)
assert result_6 == []
result_7 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0)
result_7 = SpikeGadgetsRawIO._produce_ephys_channel_ids(128, 128, 0)
assert result_7 == []
result_8 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 0)
assert result_8 == []

# --- Error Cases ---
with pytest.raises(ValueError) as excinfo:
SpikeGadgetsRawIO._produce_ephys_channel_ids(127, 32)
SpikeGadgetsRawIO._produce_ephys_channel_ids(127, 127, 32)
assert "multiple of channels per chip" in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
SpikeGadgetsRawIO._produce_ephys_channel_ids(65, 16)
SpikeGadgetsRawIO._produce_ephys_channel_ids(65, 65, 16)
assert "multiple of channels per chip" in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
SpikeGadgetsRawIO._produce_ephys_channel_ids(
64,
63,
16,
)
assert "hw_channels_recorded must be provided" in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
SpikeGadgetsRawIO._produce_ephys_channel_ids(64, 63, 16, ["1", "2", "3"])
assert "hw_channels_recorded must be provided" in str(excinfo.value)
Loading