|
1 | 1 | """Module for handling the conversion of ECU analog and headstage sensor data streams from Trodes .rec files to NWB format.""" |
2 | 2 |
|
| 3 | +import re |
3 | 4 | from xml.etree import ElementTree |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import pynwb |
7 | 8 | from hdmf.backends.hdf5 import H5DataIO |
8 | | -from pynwb import NWBFile |
| 9 | +from pynwb import NWBFile, TimeSeries |
9 | 10 |
|
10 | 11 | from trodes_to_nwb import convert_rec_header |
11 | 12 | from trodes_to_nwb.convert_ephys import RecFileDataChunkIterator |
12 | 13 |
|
# Default HDF5 chunk shape: samples along the time axis and the maximum number
# of channels per chunk. A (16384, 32) chunk of int16 samples is exactly 1 MB.
DEFAULT_CHUNK_TIME_DIM = 16384
DEFAULT_CHUNK_MAX_CHANNEL_DIM = 32


# Lookup table keyed by sensor type. Each entry holds:
#   pattern        - regex matched against channel names (anchored at the start
#                    only, because it is applied with ``re.match``)
#   scaling_factor - multiplier applied to the raw samples
#   unit           - unit label written to the TimeSeries
#   description    - prefix of the TimeSeries description
SENSOR_TYPE_CONFIG = {
    "accelerometer": {
        "pattern": r"Headstage_Accel[XYZ]",
        "scaling_factor": 0.000061,  # Convert to g units
        "unit": "g",
        "description": "Headstage accelerometer data",
    },
    "gyroscope": {
        "pattern": r"Headstage_Gyro[XYZ]",
        "scaling_factor": 0.061,  # Convert to degrees/second
        "unit": "d/s",
        "description": "Headstage gyroscope data",
    },
    "magnetometer": {
        "pattern": r"Headstage_Mag[XYZ]",
        "scaling_factor": 1.0,  # No scaling specified in issue
        "unit": "unspecified",
        "description": "Headstage magnetometer data",
    },
    "analog_input": {
        "pattern": r"(ECU_Ain\d+|Controller_Ain\d+)",
        "scaling_factor": 1.0,
        "unit": "unspecified",
        "description": "Analog input channel",
    },
}
| 44 | + |
| 45 | + |
def _categorize_sensor_channels(
    channel_names: list[str],
    sensor_config: dict = None,
) -> dict[str, list[str]]:
    """Group channel names by sensor type using each type's regex pattern.

    Parameters
    ----------
    channel_names : list[str]
        Channel names to categorize.
    sensor_config : dict, optional
        Mapping of sensor type to a dict holding at least a ``"pattern"``
        regex. Defaults to the module-level ``SENSOR_TYPE_CONFIG``.

    Returns
    -------
    dict[str, list[str]]
        Sensor type mapped to the matching channel names, in input order.
        Sensor types with no match are omitted. Names matched by no pattern
        are collected under ``"other"``, which is omitted when empty.
    """
    if sensor_config is None:
        sensor_config = SENSOR_TYPE_CONFIG

    categorized = {}
    claimed = set()  # set membership keeps the leftover pass linear

    for sensor_type, config in sensor_config.items():
        # re.match anchors at the start of the name only, so a name that
        # merely begins with the pattern is still accepted.
        matcher = re.compile(config["pattern"])
        matching_channels = [name for name in channel_names if matcher.match(name)]
        if matching_channels:
            categorized[sensor_type] = matching_channels
            claimed.update(matching_channels)

    uncategorized = [name for name in channel_names if name not in claimed]
    if uncategorized:
        categorized["other"] = uncategorized

    return categorized
| 74 | + |
| 75 | + |
def _create_sensor_timeseries(
    sensor_type: str,
    channel_names: list[str],
    data: np.ndarray,
    timestamps: np.ndarray,
    metadata: dict = None
) -> TimeSeries:
    """Build the TimeSeries that holds one sensor type's channels.

    Parameters
    ----------
    sensor_type : str
        Sensor type; used as the TimeSeries name and as the lookup key into
        ``SENSOR_TYPE_CONFIG``. Unknown types fall back to a scaling factor
        of 1.0 and the unit ``'unspecified'``.
    channel_names : list[str]
        Channel names for this sensor type; listed in the description.
    data : np.ndarray
        Raw sensor samples; multiplied by the type's scaling factor.
    timestamps : np.ndarray
        Timestamps passed through to the TimeSeries.
    metadata : dict, optional
        If it has a ``'sensor_units'`` mapping containing ``sensor_type``,
        that entry replaces the default unit. No other key is read.

    Returns
    -------
    TimeSeries
        TimeSeries named ``sensor_type`` holding the scaled data.
    """
    if sensor_type in SENSOR_TYPE_CONFIG:
        settings = SENSOR_TYPE_CONFIG[sensor_type]
    else:
        # Generic fallback for types without a configured entry.
        settings = {
            'scaling_factor': 1.0,
            'unit': 'unspecified',
            'description': f'{sensor_type} data',
        }

    converted = data * settings['scaling_factor']

    joined_names = ', '.join(channel_names)
    summary = f"{settings['description']}: {joined_names}"

    # The configured unit applies unless the caller supplied an override.
    unit_label = settings['unit']
    if metadata and 'sensor_units' in metadata:
        unit_overrides = metadata['sensor_units']
        if sensor_type in unit_overrides:
            unit_label = unit_overrides[sensor_type]

    series = TimeSeries(
        name=sensor_type,
        description=summary,
        data=converted,
        unit=unit_label,
        timestamps=timestamps,
    )
    return series
| 127 | + |
16 | 128 |
|
def _add_sensor_group(
    nwbfile: NWBFile,
    channel_ids: list[str],
    data: np.ndarray,
    timestamps: np.ndarray,
    metadata: dict = None,
    name_prefix: str = "",
) -> None:
    """Split ``data`` by sensor type and add one TimeSeries per type to acquisition.

    Column ``i`` of ``data`` is taken to belong to ``channel_ids[i]``. Each
    TimeSeries is named ``name_prefix`` followed by the sensor type.
    """
    # First-occurrence column per channel id, which is what list.index returns.
    column_of = {}
    for column, channel_id in enumerate(channel_ids):
        column_of.setdefault(channel_id, column)

    for sensor_type, channel_names in _categorize_sensor_channels(channel_ids).items():
        columns = [column_of[name] for name in channel_names]
        nwbfile.add_acquisition(
            _create_sensor_timeseries(
                sensor_type=f"{name_prefix}{sensor_type}",
                channel_names=channel_names,
                data=data[:, columns],
                timestamps=timestamps,
                metadata=metadata,
            )
        )


def add_analog_data(
    nwbfile: NWBFile,
    rec_file_path: list[str],
    timestamps: np.ndarray = None,
    behavior_only: bool = False,
    metadata: dict = None,
    **kwargs,
) -> None:
    """Adds analog streams to the nwb file as separate TimeSeries objects for each sensor type.

    ECU analog channels are added to acquisition with names prefixed by
    ``ecu_``; headstage (multiplexed) sensor channels are added under their
    sensor type name.

    Parameters
    ----------
    nwbfile : NWBFile
        nwb file being assembled
    rec_file_path : list[str]
        ordered list of file paths to all recfiles with session's data
    timestamps : np.ndarray, optional
        timestamps for the data; forwarded to RecFileDataChunkIterator
    behavior_only : bool, optional
        forwarded to RecFileDataChunkIterator
    metadata : dict, optional
        metadata dictionary; forwarded to the TimeSeries builder, which reads
        custom units from it
    **kwargs
        accepted for call compatibility; not used here
    """
    # Get the ids of the analog channels from the first rec file header
    root = convert_rec_header.read_header(rec_file_path[0])
    hconf = root.find("HardwareConfiguration")
    ecu_conf = None
    for conf in hconf:
        if conf.attrib["name"] == "ECU":
            ecu_conf = conf
            break

    # NOTE(review): if the header has no "ECU" device, ecu_conf stays None and
    # the iteration below raises TypeError - confirm whether that can occur.
    ecu_analog_channel_ids = [
        channel.attrib["id"]
        for channel in ecu_conf
        if channel.attrib["dataType"] == "analog"
    ]

    # Make the data chunk iterator for ECU analog data
    rec_dci = RecFileDataChunkIterator(
        rec_file_path,
        nwb_hw_channel_order=ecu_analog_channel_ids,
        stream_id="ECU_analog",
        is_analog=True,
        timestamps=timestamps,
        behavior_only=behavior_only,
    )

    # Get headstage sensor channel IDs from multiplexed channels
    headstage_channel_ids = list(rec_dci.neo_io[0].multiplexed_channel_xml.keys())

    # Process ECU analog channels if any exist
    if ecu_analog_channel_ids:
        # Only the leading columns are requested, i.e. without headstage data.
        # NOTE(review): a full time slice presumably loads the whole stream
        # into memory rather than streaming it in chunks - confirm.
        ecu_data = rec_dci._get_data(
            (slice(None), slice(0, len(ecu_analog_channel_ids)))
        )
        _add_sensor_group(
            nwbfile,
            ecu_analog_channel_ids,
            ecu_data,
            rec_dci.timestamps,
            metadata=metadata,
            name_prefix="ecu_",
        )

    # Process headstage sensor channels if any exist
    if headstage_channel_ids:
        # NOTE(review): data comes from neo_io[0] only while timestamps come
        # from rec_dci; confirm the lengths agree when rec_file_path lists
        # more than one file.
        headstage_data = rec_dci.neo_io[0].get_analogsignal_multiplexed(
            headstage_channel_ids
        )
        _add_sensor_group(
            nwbfile,
            headstage_channel_ids,
            headstage_data,
            rec_dci.timestamps,
            metadata=metadata,
        )
92 | 228 |
|
93 | 229 |
|
94 | 230 | def __merge_row_description(row_ids: list[str]) -> str: |
|
0 commit comments