99# see https://github.com/NeuralEnsemble/python-neo/pull/1303
1010
1111import functools
12- from typing import Optional
12+ from typing import List , Optional
1313from xml .etree import ElementTree
1414
1515import numpy as np
@@ -71,7 +71,10 @@ def _source_name(self) -> str:
7171
7272 @staticmethod
7373 def _produce_ephys_channel_ids (
74- n_total_channels : int , n_channels_per_chip : int
74+ n_total_channels : int ,
75+ n_channels_recorded : int ,
76+ n_channels_per_chip : int ,
77+ hw_channels_recorded : List [str ] = None ,
7578 ) -> list [int ]:
7679 """Computes the hardware channel IDs for ephys data.
7780
@@ -85,9 +88,15 @@ def _produce_ephys_channel_ids(
8588 Parameters
8689 ----------
8790 n_total_channels : int
91+ Total number of ephys channels in the hardware configuration.
92+ n_channels_recorded : int
8893 Total number of ephys channels recorded.
8994 n_channels_per_chip : int
9095 Number of channels per headstage chip/amplifier.
96+ hw_channels_recorded : list of str, optional
97+ List of hardware channel IDs that were actually recorded. If `None`, all channels are assumed
98+ to be recorded. This is used to filter the returned list if `n_total_channels`
99+ is not equal to `n_channels_recorded`.
91100
92101 Returns
93102 -------
@@ -123,7 +132,19 @@ def _produce_ephys_channel_ids(
123132 for i in range (int (n_total_channels / n_channels_per_chip ))
124133 ]
125134 )
126- return [item for sublist in x for item in sublist ]
135+
136+ channel_names = [item for sublist in x for item in sublist ]
137+
138+ if n_total_channels == n_channels_recorded :
139+ # case where all channels are recorded, no censoring required
140+ return channel_names
141+
142+ if not hw_channels_recorded or len (hw_channels_recorded ) != n_channels_recorded :
143+ raise ValueError (
144+ "If n_total_channels != n_channels_recorded, "
145+ "hw_channels_recorded must be provided to censor the returned list."
146+ )
147+ return [x for x in channel_names if str (x ) in hw_channels_recorded ]
127148
128149 def _parse_header (self ):
129150 """
@@ -167,10 +188,14 @@ def _parse_header(self):
167188 # dt = datetime.datetime.fromtimestamp(int(self.system_time_at_creation) / 1000.0)
168189
169190 self ._sampling_rate = float (hconf .attrib ["samplingRate" ])
170- num_ephy_channels = int (hconf .attrib ["numChannels" ])
191+ num_chip_channels = int (
192+ hconf .attrib ["numChannels" ]
193+ ) # number of channels the hardware supports
194+ num_ephy_channels = num_chip_channels # number of channels recorded
171195 # check for agreement with number of channels in xml
172196 sconf_channels = np .sum ([len (x ) for x in sconf ])
173197 if sconf_channels < num_ephy_channels :
198+ # Case: not every channel was saved to recording
174199 num_ephy_channels = sconf_channels
175200 if sconf_channels > num_ephy_channels :
176201 raise ValueError (
@@ -383,46 +408,52 @@ def _parse_header(self):
383408 )
384409 self ._mask_channels_bytes [stream_id ] = []
385410
386- channel_ids = self ._produce_ephys_channel_ids (
387- num_ephy_channels , num_chan_per_chip
388- )
389-
390- chan_ind = 0
411+ # get list of all hardware channels recorded
412+ hw_channels_recorded = []
391413 for trode in sconf :
392414 for schan in trode :
393- chan_id = str (channel_ids [chan_ind ])
394- name = "chan" + chan_id
415+ hw_channels_recorded .append (schan .attrib ["hwChan" ])
395416
396- # TODO LATER : handle gain correctly according the file version
397- units = ""
398- gain = 1.0
399- offset = 0.0
400- signal_channels .append (
401- (
402- name ,
403- chan_id ,
404- self ._sampling_rate ,
405- "int16" ,
406- units ,
407- gain ,
408- offset ,
409- stream_id ,
410- "" ,
411- )
412- )
417+ channel_ids = self ._produce_ephys_channel_ids (
418+ num_chip_channels ,
419+ num_ephy_channels ,
420+ num_chan_per_chip ,
421+ hw_channels_recorded ,
422+ )
413423
414- chan_mask = np .zeros (packet_size , dtype = "bool" )
415- num_bytes_offset = (
416- packet_size
417- - (EPHYS_SAMPLE_SIZE_BYTES * num_ephy_channels )
418- + (EPHYS_SAMPLE_SIZE_BYTES * chan_ind )
424+ chan_ind = 0
425+ for chan_ind in range (len (channel_ids )):
426+ chan_id = str (channel_ids [chan_ind ])
427+ name = "chan" + chan_id
428+
429+ # TODO LATER : handle gain correctly according the file version
430+ units = ""
431+ gain = 1.0
432+ offset = 0.0
433+ signal_channels .append (
434+ (
435+ name ,
436+ chan_id ,
437+ self ._sampling_rate ,
438+ "int16" ,
439+ units ,
440+ gain ,
441+ offset ,
442+ stream_id ,
443+ "" ,
419444 )
420- chan_mask [
421- num_bytes_offset : num_bytes_offset + EPHYS_SAMPLE_SIZE_BYTES
422- ] = True
423- self ._mask_channels_bytes [stream_id ].append (chan_mask )
445+ )
424446
425- chan_ind += 1
447+ chan_mask = np .zeros (packet_size , dtype = "bool" )
448+ num_bytes_offset = (
449+ packet_size
450+ - (EPHYS_SAMPLE_SIZE_BYTES * num_ephy_channels )
451+ + (EPHYS_SAMPLE_SIZE_BYTES * chan_ind )
452+ )
453+ chan_mask [
454+ num_bytes_offset : num_bytes_offset + EPHYS_SAMPLE_SIZE_BYTES
455+ ] = True
456+ self ._mask_channels_bytes [stream_id ].append (chan_mask )
426457
427458 # make mask as array (used in _get_analogsignal_chunk(...))
428459 self ._mask_streams = {}
0 commit comments