Skip to content

Commit 6145f12

Browse files
authored
Merge pull request #122 from LorenFrankLab/missing_channel_indexing
hwChan ordering with unrecorded channels
2 parents 73e14e6 + e0bb9e6 commit 6145f12

File tree

2 files changed

+121
-48
lines changed

2 files changed

+121
-48
lines changed

src/trodes_to_nwb/spike_gadgets_raw_io.py

Lines changed: 69 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# see https://github.com/NeuralEnsemble/python-neo/pull/1303
1010

1111
import functools
12-
from typing import Optional
12+
from typing import List, Optional
1313
from xml.etree import ElementTree
1414

1515
import numpy as np
@@ -71,7 +71,10 @@ def _source_name(self) -> str:
7171

7272
@staticmethod
7373
def _produce_ephys_channel_ids(
74-
n_total_channels: int, n_channels_per_chip: int
74+
n_total_channels: int,
75+
n_channels_recorded: int,
76+
n_channels_per_chip: int,
77+
hw_channels_recorded: List[str] = None,
7578
) -> list[int]:
7679
"""Computes the hardware channel IDs for ephys data.
7780
@@ -85,9 +88,15 @@ def _produce_ephys_channel_ids(
8588
Parameters
8689
----------
8790
n_total_channels : int
91+
Total number of ephys channels in the hardware configuration.
92+
n_channels_recorded : int
8893
Total number of ephys channels recorded.
8994
n_channels_per_chip : int
9095
Number of channels per headstage chip/amplifier.
96+
hw_channels_recorded : list of str, optional
97+
List of hardware channel IDs that were actually recorded. If `None`, all channels are assumed
98+
to be recorded. This is used to filter the returned list if `n_total_channels`
99+
is not equal to `n_channels_recorded`.
91100
92101
Returns
93102
-------
@@ -123,7 +132,19 @@ def _produce_ephys_channel_ids(
123132
for i in range(int(n_total_channels / n_channels_per_chip))
124133
]
125134
)
126-
return [item for sublist in x for item in sublist]
135+
136+
channel_names = [item for sublist in x for item in sublist]
137+
138+
if n_total_channels == n_channels_recorded:
139+
# case where all channels are recorded, no censoring required
140+
return channel_names
141+
142+
if not hw_channels_recorded or len(hw_channels_recorded) != n_channels_recorded:
143+
raise ValueError(
144+
"If n_total_channels != n_channels_recorded, "
145+
"hw_channels_recorded must be provided to censor the returned list."
146+
)
147+
return [x for x in channel_names if str(x) in hw_channels_recorded]
127148

128149
def _parse_header(self):
129150
"""
@@ -167,10 +188,14 @@ def _parse_header(self):
167188
# dt = datetime.datetime.fromtimestamp(int(self.system_time_at_creation) / 1000.0)
168189

169190
self._sampling_rate = float(hconf.attrib["samplingRate"])
170-
num_ephy_channels = int(hconf.attrib["numChannels"])
191+
num_chip_channels = int(
192+
hconf.attrib["numChannels"]
193+
) # number of channels the hardware supports
194+
num_ephy_channels = num_chip_channels # number of channels recorded
171195
# check for agreement with number of channels in xml
172196
sconf_channels = np.sum([len(x) for x in sconf])
173197
if sconf_channels < num_ephy_channels:
198+
# Case: not every channel was saved to recording
174199
num_ephy_channels = sconf_channels
175200
if sconf_channels > num_ephy_channels:
176201
raise ValueError(
@@ -383,46 +408,52 @@ def _parse_header(self):
383408
)
384409
self._mask_channels_bytes[stream_id] = []
385410

386-
channel_ids = self._produce_ephys_channel_ids(
387-
num_ephy_channels, num_chan_per_chip
388-
)
389-
390-
chan_ind = 0
411+
# get list of all hardware channels recorded
412+
hw_channels_recorded = []
391413
for trode in sconf:
392414
for schan in trode:
393-
chan_id = str(channel_ids[chan_ind])
394-
name = "chan" + chan_id
415+
hw_channels_recorded.append(schan.attrib["hwChan"])
395416

396-
# TODO LATER : handle gain correctly according the file version
397-
units = ""
398-
gain = 1.0
399-
offset = 0.0
400-
signal_channels.append(
401-
(
402-
name,
403-
chan_id,
404-
self._sampling_rate,
405-
"int16",
406-
units,
407-
gain,
408-
offset,
409-
stream_id,
410-
"",
411-
)
412-
)
417+
channel_ids = self._produce_ephys_channel_ids(
418+
num_chip_channels,
419+
num_ephy_channels,
420+
num_chan_per_chip,
421+
hw_channels_recorded,
422+
)
413423

414-
chan_mask = np.zeros(packet_size, dtype="bool")
415-
num_bytes_offset = (
416-
packet_size
417-
- (EPHYS_SAMPLE_SIZE_BYTES * num_ephy_channels)
418-
+ (EPHYS_SAMPLE_SIZE_BYTES * chan_ind)
424+
chan_ind = 0
425+
for chan_ind in range(len(channel_ids)):
426+
chan_id = str(channel_ids[chan_ind])
427+
name = "chan" + chan_id
428+
429+
# TODO LATER : handle gain correctly according the file version
430+
units = ""
431+
gain = 1.0
432+
offset = 0.0
433+
signal_channels.append(
434+
(
435+
name,
436+
chan_id,
437+
self._sampling_rate,
438+
"int16",
439+
units,
440+
gain,
441+
offset,
442+
stream_id,
443+
"",
419444
)
420-
chan_mask[
421-
num_bytes_offset : num_bytes_offset + EPHYS_SAMPLE_SIZE_BYTES
422-
] = True
423-
self._mask_channels_bytes[stream_id].append(chan_mask)
445+
)
424446

425-
chan_ind += 1
447+
chan_mask = np.zeros(packet_size, dtype="bool")
448+
num_bytes_offset = (
449+
packet_size
450+
- (EPHYS_SAMPLE_SIZE_BYTES * num_ephy_channels)
451+
+ (EPHYS_SAMPLE_SIZE_BYTES * chan_ind)
452+
)
453+
chan_mask[
454+
num_bytes_offset : num_bytes_offset + EPHYS_SAMPLE_SIZE_BYTES
455+
] = True
456+
self._mask_channels_bytes[stream_id].append(chan_mask)
426457

427458
# make mask as array (used in _get_analogsignal_chunk(...))
428459
self._mask_streams = {}

src/trodes_to_nwb/tests/test_spikegadgets_io.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,9 @@ def test_produce_ephys_channel_ids():
377377
full_expected_1.extend(
378378
[k + i * n_per_chip_1 for i in range(n_total_1 // n_per_chip_1)]
379379
)
380-
result_1 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_1, n_per_chip_1)
380+
result_1 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
381+
n_total_1, n_total_1, n_per_chip_1
382+
)
381383
assert result_1 == full_expected_1
382384
assert len(result_1) == n_total_1
383385

@@ -389,7 +391,9 @@ def test_produce_ephys_channel_ids():
389391
full_expected_2.extend(
390392
[k + i * n_per_chip_2 for i in range(n_total_2 // n_per_chip_2)]
391393
)
392-
result_2 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_2, n_per_chip_2)
394+
result_2 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
395+
n_total_2, n_total_2, n_per_chip_2
396+
)
393397
assert result_2 == full_expected_2
394398
assert len(result_2) == n_total_2
395399

@@ -402,30 +406,68 @@ def test_produce_ephys_channel_ids():
402406
[k + i * n_per_chip_3 for i in range(n_total_3 // n_per_chip_3)]
403407
)
404408

405-
result_3 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_3, n_per_chip_3)
409+
result_3 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
410+
n_total_3, n_total_3, n_per_chip_3
411+
)
406412
assert result_3 == full_expected_3
407413
assert len(result_3) == n_total_3
408414

409415
# Case 4: Single chip - 32 channels, 32 per chip (1 chip)
410416
n_total_4 = 32
411417
n_per_chip_4 = 32
412418
expected_4 = list(range(32)) # Should just be 0, 1, 2, ..., 31
413-
result_4 = SpikeGadgetsRawIO._produce_ephys_channel_ids(n_total_4, n_per_chip_4)
419+
result_4 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
420+
n_total_4, n_total_4, n_per_chip_4
421+
)
414422
assert result_4 == expected_4
415423
assert len(result_4) == n_total_4
416424

425+
# case 5: Not all channels recorded
426+
n_total_5 = 128
427+
n_recorded_5 = 127
428+
n_per_chip_5 = 32
429+
missing_hw_channel = 2
430+
431+
full_expected_5 = []
432+
for k in range(n_per_chip_5):
433+
full_expected_5.extend(
434+
[k + i * n_per_chip_5 for i in range(n_total_5 // n_per_chip_5)]
435+
)
436+
full_expected_5 = [x for x in full_expected_5 if x != missing_hw_channel]
437+
hw_channels_recorded_5 = [
438+
str(x) for x in np.arange(n_total_5) if x != missing_hw_channel
439+
]
440+
result_5 = SpikeGadgetsRawIO._produce_ephys_channel_ids(
441+
n_total_5,
442+
n_recorded_5,
443+
n_per_chip_5,
444+
hw_channels_recorded=hw_channels_recorded_5,
445+
)
446+
assert result_5 == full_expected_5
447+
assert len(result_5) == n_recorded_5
448+
417449
# --- Edge Cases ---
418-
result_5 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 32)
419-
assert result_5 == []
420-
result_6 = SpikeGadgetsRawIO._produce_ephys_channel_ids(128, 0)
450+
result_6 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 32)
421451
assert result_6 == []
422-
result_7 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0)
452+
result_7 = SpikeGadgetsRawIO._produce_ephys_channel_ids(128, 128, 0)
423453
assert result_7 == []
454+
result_8 = SpikeGadgetsRawIO._produce_ephys_channel_ids(0, 0, 0)
455+
assert result_8 == []
424456

425457
# --- Error Cases ---
426458
with pytest.raises(ValueError) as excinfo:
427-
SpikeGadgetsRawIO._produce_ephys_channel_ids(127, 32)
459+
SpikeGadgetsRawIO._produce_ephys_channel_ids(127, 127, 32)
428460
assert "multiple of channels per chip" in str(excinfo.value)
429461
with pytest.raises(ValueError) as excinfo:
430-
SpikeGadgetsRawIO._produce_ephys_channel_ids(65, 16)
462+
SpikeGadgetsRawIO._produce_ephys_channel_ids(65, 65, 16)
431463
assert "multiple of channels per chip" in str(excinfo.value)
464+
with pytest.raises(ValueError) as excinfo:
465+
SpikeGadgetsRawIO._produce_ephys_channel_ids(
466+
64,
467+
63,
468+
16,
469+
)
470+
assert "hw_channels_recorded must be provided" in str(excinfo.value)
471+
with pytest.raises(ValueError) as excinfo:
472+
SpikeGadgetsRawIO._produce_ephys_channel_ids(64, 63, 16, ["1", "2", "3"])
473+
assert "hw_channels_recorded must be provided" in str(excinfo.value)

0 commit comments

Comments
 (0)