diff --git a/src/trodes_to_nwb/convert.py b/src/trodes_to_nwb/convert.py
index e4c7cee..5b08a03 100644
--- a/src/trodes_to_nwb/convert.py
+++ b/src/trodes_to_nwb/convert.py
@@ -236,9 +236,8 @@ def _create_nwb(
         stream_id="ECU_analog" if behavior_only else "trodes",
         behavior_only=behavior_only,
     )
-    rec_dci_timestamps = (
-        rec_dci.timestamps
-    )  # pass these when creating other non-interpolated rec iterators to save time
+    # Defer timestamp loading until needed - this saves memory for large recordings
+    rec_dci_timestamps = None
     rec_header = read_header(rec_filepaths[0])
     reconfig_header = rec_header
 
@@ -307,7 +306,6 @@ def _create_nwb(
         add_analog_data(
             nwb_file,
             rec_filepaths,
-            timestamps=rec_dci_timestamps,
             behavior_only=behavior_only,
         )
         logger.info("ADDING SAMPLE COUNTS")
@@ -332,6 +330,9 @@ def _create_nwb(
                 session_df,
             )
         else:
+            # For non-PTP position tracking, we need timestamps
+            if rec_dci_timestamps is None:
+                rec_dci_timestamps = rec_dci.get_timestamps()
             add_position(
                 nwb_file,
                 metadata,
diff --git a/src/trodes_to_nwb/convert_analog.py b/src/trodes_to_nwb/convert_analog.py
index d1c5693..73deb44 100644
--- a/src/trodes_to_nwb/convert_analog.py
+++ b/src/trodes_to_nwb/convert_analog.py
@@ -17,7 +17,6 @@
 def add_analog_data(
     nwbfile: NWBFile,
     rec_file_path: list[str],
-    timestamps: np.ndarray = None,
     behavior_only: bool = False,
     **kwargs,
 ) -> None:
@@ -52,7 +51,6 @@ def add_analog_data(
         nwb_hw_channel_order=analog_channel_ids,
         stream_id="ECU_analog",
         is_analog=True,
-        timestamps=timestamps,
         behavior_only=behavior_only,
     )
 
@@ -76,17 +74,42 @@ def add_analog_data(
             name="analog", description="Contains all analog data"
         )
         analog_events = pynwb.behavior.BehavioralEvents(name="analog")
-        analog_events.add_timeseries(
-            pynwb.TimeSeries(
+
+        # Check if we can use sampling rate instead of individual timestamps
+        sampling_rate = rec_dci.get_sampling_rate()
+        if sampling_rate is not None:
+            # Regular timestamps: store rate + starting_time instead of a full
+            # timestamps vector - much more memory efficient. starting_time is
+            # required to keep the series aligned with the session's systime.
+            analog_time_series = pynwb.TimeSeries(
                 name="analog",
                 description=__merge_row_description(
                     analog_channel_ids
                 ),  # NOTE: matches rec_to_nwb system
                 data=data_data_io,
-                timestamps=rec_dci.timestamps,
+                starting_time=rec_dci.get_starting_time(),
+                rate=sampling_rate,
                 unit="-1",
             )
-        )
+        else:
+            # Use chunked timestamps for irregular timestamps
+            from trodes_to_nwb import convert_ephys
+
+            timestamps_chunked = rec_dci.get_timestamps_chunked()
+            timestamps_data_io = H5DataIO(
+                timestamps_chunked,
+                chunks=(convert_ephys.DEFAULT_CHUNK_TIME_DIM,),
+            )
+
+            analog_time_series = pynwb.TimeSeries(
+                name="analog",
+                description=__merge_row_description(
+                    analog_channel_ids
+                ),  # NOTE: matches rec_to_nwb system
+                data=data_data_io,
+                timestamps=timestamps_data_io,
+                unit="-1",
+            )
+
+        analog_events.add_timeseries(analog_time_series)
 
         # add it to the nwb file
         nwbfile.processing["analog"].add(analog_events)
diff --git a/src/trodes_to_nwb/convert_ephys.py b/src/trodes_to_nwb/convert_ephys.py
index 639452a..a2487ab 100644
--- a/src/trodes_to_nwb/convert_ephys.py
+++ b/src/trodes_to_nwb/convert_ephys.py
@@ -34,6 +34,79 @@ DEFAULT_CHUNK_MAX_CHANNEL_DIM = 32
 
 
+class TimestampDataChunkIterator(GenericDataChunkIterator):
+    """Iterate over rec-file timestamps in chunks instead of loading them all.
+
+    Mirrors RecFileDataChunkIterator's multi-file logic but yields the
+    per-sample system times, so timestamps can be written to HDF5 lazily.
+    """
+
+    def __init__(
+        self,
+        neo_io_list: list[SpikeGadgetsRawIO],
+        use_systime: bool = True,
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        neo_io_list : list[SpikeGadgetsRawIO]
+            list of neo IO objects for the rec files
+        use_systime : bool, optional
+            whether to use system time (True) or trodes timestamps (False), by default True
+        kwargs : dict
+            additional arguments to pass to GenericDataChunkIterator
+        """
+        self.neo_io_list = neo_io_list
+        self.use_systime = use_systime
+        # number of samples contributed by each rec file
+        self.n_time = [
+            neo_io.get_signal_size(block_index=0, seg_index=0, stream_index=0)
+            for neo_io in self.neo_io_list
+        ]
+        super().__init__(**kwargs)
+
+    def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
+        # selection is (time,)
+        assert selection[0].step is None
+
+        # global index at which each file starts
+        file_start_ind = np.append(np.zeros(1), np.cumsum(self.n_time))
+        start = int(selection[0].start)
+        # clamp so a chunk at the end of the data cannot overrun the total length
+        stop = int(min(selection[0].stop, self._get_maxshape()[0]))
+
+        timestamps = []
+        i = start
+        while i < stop:
+            # find the file where this piece of the selection begins
+            io_stream = np.argmin(i >= file_start_ind) - 1
+            i_start_local = int(i - file_start_ind[io_stream])
+            # stop at the end of this file or of the selection, whichever is first
+            i_stop_local = int(
+                min(stop - file_start_ind[io_stream], self.n_time[io_stream])
+            )
+            neo_io = self.neo_io_list[io_stream]
+            if self.use_systime and neo_io.sysClock_byte:
+                chunk_timestamps = neo_io.get_regressed_systime(
+                    i_start_local, i_stop_local
+                )
+            else:
+                chunk_timestamps = neo_io.get_systime_from_trodes_timestamps(
+                    i_start_local, i_stop_local
+                )
+            timestamps.append(chunk_timestamps)
+            # advance to the first global index not yet covered
+            i = int(file_start_ind[io_stream]) + i_stop_local
+
+        return np.concatenate(timestamps)
+
+    def _get_maxshape(self) -> Tuple[int]:
+        return (np.sum(self.n_time),)
+
+    def _get_dtype(self) -> np.dtype:
+        return np.dtype("float64")
+
+
 class RecFileDataChunkIterator(GenericDataChunkIterator):
     """Data chunk iterator for SpikeGadgets rec files."""
 
@@ -202,15 +275,103 @@ class RecFileDataChunkIterator(GenericDataChunkIterator):
             self.neo_io.pop(iterator_loc)
             self.neo_io[iterator_loc:iterator_loc] = sub_iterators
         logger.info(f"# iterators: {len(self.neo_io)}")
-        # NOTE: this will read all the timestamps from the rec file, which can be slow
+        # number of samples per file; must be computed before the regularity
+        # check below, which reads self.n_time
+        self.n_time = [
+            neo_io.get_signal_size(
+                block_index=self.block_index,
+                seg_index=self.seg_index,
+                stream_index=self.stream_index,
+            )
+            for neo_io in self.neo_io
+        ]
+
+        # Check if timestamps are regular to potentially use sampling rate instead
+        self._check_timestamp_regularity()
+
+        # Only load timestamps into memory if they're needed and irregular
         if timestamps is not None:
             self.timestamps = timestamps
+            self._has_loaded_timestamps = True
+        else:
+            # Defer timestamp loading - create a chunked iterator later if needed
+            self.timestamps = None
+            self._has_loaded_timestamps = False
+
+        super().__init__(**kwargs)
 
-        elif self.neo_io[0].sysClock_byte:  # use this if have sysClock
+    def _check_timestamp_regularity(self):
+        """Sample the first timestamps to decide if a fixed rate can be used.
+
+        Sets self._timestamps_regular, self._sampling_rate, and
+        self._start_time (systime of the first sample, needed as
+        starting_time when writing rate-based series).
+        """
+        # Sample a small portion of timestamps to check regularity
+        sample_size = min(1000, self.n_time[0])  # first 1000 or all if smaller
+
+        try:
+            if self.neo_io[0].sysClock_byte:
+                sample_timestamps = self.neo_io[0].get_regressed_systime(0, sample_size)
+            else:
+                sample_timestamps = self.neo_io[0].get_systime_from_trodes_timestamps(
+                    0, sample_size
+                )
+
+            # Check if timestamps are evenly spaced
+            diffs = np.diff(sample_timestamps)
+            expected_dt = 1.0 / self.neo_io[0]._sampling_rate
+            relative_error = np.abs(diffs - expected_dt) / expected_dt
+
+            # Consider regular if 95% of intervals are within 1% of expected
+            self._timestamps_regular = np.percentile(relative_error, 95) < 0.01
+            self._sampling_rate = self.neo_io[0]._sampling_rate
+            self._start_time = float(sample_timestamps[0])
+
+            if len(self.neo_io) > 1:
+                # For multiple files be conservative: gaps between files would
+                # silently break a fixed-rate representation
+                self._timestamps_regular = False
+
+        except Exception:
+            # If we can't check, assume irregular
+            self._timestamps_regular = False
+            self._sampling_rate = None
+            self._start_time = None
+
+    def get_timestamps_chunked(self) -> TimestampDataChunkIterator:
+        """Get a chunked iterator for timestamps."""
+        use_systime = bool(getattr(self.neo_io[0], "sysClock_byte", False))
+        return TimestampDataChunkIterator(
+            self.neo_io,
+            use_systime=use_systime,
+        )
+
+    def get_timestamps(self) -> np.ndarray:
+        """Get timestamps, loading them if necessary."""
+        if self.timestamps is not None:
+            return self.timestamps
+        return self.load_all_timestamps()
+
+    def get_sampling_rate(self):
+        """Return the sampling rate (Hz) if timestamps are regular, else None."""
+        if getattr(self, "_timestamps_regular", False):
+            return self._sampling_rate
+        return None
+
+    def get_starting_time(self):
+        """Return the systime of the first sample, or None if unknown."""
+        return getattr(self, "_start_time", None)
+
+    def load_all_timestamps(self):
+        """Load all timestamps into memory (fallback for irregular timestamps)."""
+        if self._has_loaded_timestamps:
+            return self.timestamps
+
+        logger = logging.getLogger("convert")
+        logger.info("Loading all timestamps into memory...")
+
+        if self.neo_io[0].sysClock_byte:  # use this if have sysClock
             self.timestamps = np.concatenate(
                 [neo_io.get_regressed_systime(0, None) for neo_io in self.neo_io]
             )
         else:  # use this to convert Trodes timestamps into systime based on sampling rate
             self.timestamps = np.concatenate(
                 [
@@ -218,24 +378,15 @@ class RecFileDataChunkIterator(GenericDataChunkIterator):
                     neo_io.get_systime_from_trodes_timestamps(0, None)
                     for neo_io in self.neo_io
                 ]
             )
-
+
+        self._has_loaded_timestamps = True
         logger.info("Reading timestamps COMPLETE")
-        is_timestamps_sequential = np.all(np.diff(self.timestamps))
+
+        # strictly increasing, not merely nonzero diffs
+        is_timestamps_sequential = np.all(np.diff(self.timestamps) > 0)
         if not is_timestamps_sequential:
             warn(
                 "Timestamps are not sequential. This may cause problems with some software or data analysis."
             )
-
-        self.n_time = [
-            neo_io.get_signal_size(
-                block_index=self.block_index,
-                seg_index=self.seg_index,
-                stream_index=self.stream_index,
-            )
-            for neo_io in self.neo_io
-        ]
-
-        super().__init__(**kwargs)
+
+        return self.timestamps
 
     def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
         # selection is (time, channel)
@@ -389,14 +540,36 @@ def add_raw_ephys(
             ),
         )
 
-    # do we want to pull the timestamps from the rec file? or is there another source?
-    eseries = ElectricalSeries(
-        name="e-series",
-        data=data_data_io,
-        timestamps=rec_dci.timestamps,
-        electrodes=electrode_table_region,  # TODO
-        conversion=VOLTS_PER_MICROVOLT,
-        offset=0.0,  # TODO
-    )
+    # Check if we can use sampling rate instead of individual timestamps
+    sampling_rate = rec_dci.get_sampling_rate()
+    if sampling_rate is not None:
+        # Regular timestamps: store rate + starting_time instead of a full
+        # timestamps vector - much more memory efficient. starting_time is
+        # required: without it pynwb defaults to 0.0 and the series loses
+        # alignment with the session's absolute systime.
+        eseries = ElectricalSeries(
+            name="e-series",
+            data=data_data_io,
+            starting_time=rec_dci.get_starting_time(),
+            rate=sampling_rate,
+            electrodes=electrode_table_region,
+            conversion=VOLTS_PER_MICROVOLT,
+            offset=0.0,
+        )
+    else:
+        # Use chunked timestamps for irregular timestamps
+        timestamps_chunked = rec_dci.get_timestamps_chunked()
+        timestamps_data_io = H5DataIO(
+            timestamps_chunked,
+            chunks=(DEFAULT_CHUNK_TIME_DIM,),
+        )
+
+        eseries = ElectricalSeries(
+            name="e-series",
+            data=data_data_io,
+            timestamps=timestamps_data_io,
+            electrodes=electrode_table_region,
+            conversion=VOLTS_PER_MICROVOLT,
+            offset=0.0,
+        )
 
     nwbfile.add_acquisition(eseries)
diff --git a/src/trodes_to_nwb/convert_intervals.py b/src/trodes_to_nwb/convert_intervals.py
index 26cb295..ae9d27c 100644
--- a/src/trodes_to_nwb/convert_intervals.py
+++ b/src/trodes_to_nwb/convert_intervals.py
@@ -97,7 +97,7 @@ def add_sample_count(
     )
 
     # get the systime information
-    systime = np.array(rec_dci.timestamps) * NANOSECONDS_PER_SECOND
+    systime = np.array(rec_dci.get_timestamps()) * NANOSECONDS_PER_SECOND
     # get the sample count information
     trodes_sample = np.concatenate(
         [neo_io.get_analogsignal_timestamps(0, None) for neo_io in rec_dci.neo_io]
diff --git a/src/trodes_to_nwb/tests/test_behavior_only_rec.py b/src/trodes_to_nwb/tests/test_behavior_only_rec.py
index 174a32b..e41947e 100644
--- a/src/trodes_to_nwb/tests/test_behavior_only_rec.py
+++ b/src/trodes_to_nwb/tests/test_behavior_only_rec.py
@@ -44,8 +44,9 @@ def test_behavior_only_rec_file():
     assert "trodes" not in stream_names, "unexpected trodes stream in iterator"
 
     # check data accesses
-    assert rec_dci.timestamps.size == 433012
-    assert rec_dci.timestamps[-1] == 1751195974.5656028, "unexpected last timestamp"
+    timestamps = rec_dci.get_timestamps()  # Use new method to get timestamps
+    assert timestamps.size == 433012
+    assert timestamps[-1] == 1751195974.5656028, "unexpected last timestamp"
     assert set(neo_io.multiplexed_channel_xml.keys()) == set(
         [
             "Headstage_AccelX",