|
34 | 34 | DEFAULT_CHUNK_MAX_CHANNEL_DIM = 32 |
35 | 35 |
|
36 | 36 |
|
class TimestampDataChunkIterator(GenericDataChunkIterator):
    """Data chunk iterator for timestamps from SpikeGadgets rec files.

    Presents the timestamps of several sequentially recorded rec files as one
    virtual 1-D float64 array, reading only the requested slice from disk so
    the full timestamp vector never has to be held in memory.
    """

    def __init__(
        self,
        neo_io_list: list[SpikeGadgetsRawIO],
        use_systime: bool = True,
        **kwargs,
    ):
        """
        Parameters
        ----------
        neo_io_list : list[SpikeGadgetsRawIO]
            list of neo IO objects for the rec files
        use_systime : bool, optional
            whether to use system time (True) or trodes timestamps (False), by default True
        kwargs : dict
            additional arguments to pass to GenericDataChunkIterator
        """
        self.neo_io_list = neo_io_list
        self.use_systime = use_systime
        # Number of samples per file. NOTE(review): assumes block/segment/
        # stream index 0; RecFileDataChunkIterator uses configurable indices —
        # confirm the two agree for multi-stream recordings.
        self.n_time = [
            neo_io.get_signal_size(block_index=0, seg_index=0, stream_index=0)
            for neo_io in self.neo_io_list
        ]
        super().__init__(**kwargs)

    def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
        """Return the timestamps for the 1-D (time,) slice in *selection*.

        The global time axis is the concatenation of all files, so a single
        request may span one or more file boundaries; each file contributes
        one chunk.
        """
        # GenericDataChunkIterator only requests contiguous (step-less) slices.
        assert selection[0].step is None

        # Global sample index at which each file starts: [0, n0, n0+n1, ...].
        file_start_ind = np.append(np.zeros(1), np.cumsum(self.n_time))

        start = int(selection[0].start)
        stop = min(int(selection[0].stop), int(self._get_maxshape()[0]))
        if start >= stop:
            # Empty selection: previously this fell through to
            # np.concatenate([]) which raises ValueError.
            return np.empty((0,), dtype=np.float64)

        timestamps = []
        i = start
        while i < stop:
            # Index of the file that contains global sample i:
            # argmin over the boolean mask finds the first start > i.
            io_stream = int(np.argmin(i >= file_start_ind)) - 1
            neo_io = self.neo_io_list[io_stream]
            # Translate to file-local indices, clipped to this file's end
            # (half-open [i_start_local, i_stop_local) — no +1 over-read).
            i_start_local = int(i - file_start_ind[io_stream])
            i_stop_local = int(
                min(stop - file_start_ind[io_stream], self.n_time[io_stream])
            )

            if self.use_systime and neo_io.sysClock_byte:
                chunk_timestamps = neo_io.get_regressed_systime(
                    i_start_local, i_stop_local
                )
            else:
                chunk_timestamps = neo_io.get_systime_from_trodes_timestamps(
                    i_start_local, i_stop_local
                )
            timestamps.append(chunk_timestamps)

            # Advance to the first sample not yet read (start of the next
            # file when the request was clipped at a boundary).
            i = int(file_start_ind[io_stream]) + i_stop_local

        return np.concatenate(timestamps)

    def _get_maxshape(self) -> Tuple[int]:
        """Total number of samples across all files, as a 1-tuple."""
        return (int(np.sum(self.n_time)),)

    def _get_dtype(self) -> np.dtype:
        """Timestamps are emitted as float64 seconds."""
        return np.dtype("float64")
| 113 | + |
37 | 114 | class RecFileDataChunkIterator(GenericDataChunkIterator): |
38 | 115 | """Data chunk iterator for SpikeGadgets rec files.""" |
39 | 116 |
|
@@ -202,40 +279,109 @@ def __init__( |
202 | 279 | self.neo_io.pop(iterator_loc) |
203 | 280 | self.neo_io[iterator_loc:iterator_loc] = sub_iterators |
204 | 281 | logger.info(f"# iterators: {len(self.neo_io)}") |
205 | | - # NOTE: this will read all the timestamps from the rec file, which can be slow |
| 282 | + # Check if timestamps are regular to potentially use sampling rate instead |
| 283 | + self._check_timestamp_regularity() |
| 284 | + |
| 285 | + # Only load timestamps into memory if they're needed and irregular |
206 | 286 | if timestamps is not None: |
207 | 287 | self.timestamps = timestamps |
| 288 | + self._has_loaded_timestamps = True |
| 289 | + else: |
| 290 | + # Defer timestamp loading - we'll create a chunked iterator later if needed |
| 291 | + self.timestamps = None |
| 292 | + self._has_loaded_timestamps = False |
| 293 | + |
| 294 | + self.n_time = [ |
| 295 | + neo_io.get_signal_size( |
| 296 | + block_index=self.block_index, |
| 297 | + seg_index=self.seg_index, |
| 298 | + stream_index=self.stream_index, |
| 299 | + ) |
| 300 | + for neo_io in self.neo_io |
| 301 | + ] |
| 302 | + |
| 303 | + super().__init__(**kwargs) |
208 | 304 |
|
209 | | - elif self.neo_io[0].sysClock_byte: # use this if have sysClock |
| 305 | + def _check_timestamp_regularity(self): |
| 306 | + """Check if timestamps are regular by sampling a few chunks.""" |
| 307 | + # Sample a small portion of timestamps to check regularity |
| 308 | + sample_size = min(1000, self.n_time[0]) # Sample first 1000 or all if smaller |
| 309 | + |
| 310 | + try: |
| 311 | + if self.neo_io[0].sysClock_byte: |
| 312 | + sample_timestamps = self.neo_io[0].get_regressed_systime(0, sample_size) |
| 313 | + else: |
| 314 | + sample_timestamps = self.neo_io[0].get_systime_from_trodes_timestamps(0, sample_size) |
| 315 | + |
| 316 | + # Check if timestamps are evenly spaced |
| 317 | + diffs = np.diff(sample_timestamps) |
| 318 | + expected_dt = 1.0 / self.neo_io[0]._sampling_rate |
| 319 | + relative_error = np.abs(diffs - expected_dt) / expected_dt |
| 320 | + |
| 321 | + # Consider regular if 95% of intervals are within 1% of expected |
| 322 | + self._timestamps_regular = np.percentile(relative_error, 95) < 0.01 |
| 323 | + self._sampling_rate = self.neo_io[0]._sampling_rate |
| 324 | + |
| 325 | + if len(self.neo_io) > 1: |
| 326 | + # For multiple files, we need to be more conservative |
| 327 | + self._timestamps_regular = False |
| 328 | + |
| 329 | + except Exception: |
| 330 | + # If we can't check, assume irregular |
| 331 | + self._timestamps_regular = False |
| 332 | + self._sampling_rate = None |
| 333 | + |
| 334 | + def get_timestamps_chunked(self) -> TimestampDataChunkIterator: |
| 335 | + """Get a chunked iterator for timestamps.""" |
| 336 | + use_systime = self.neo_io[0].sysClock_byte if hasattr(self.neo_io[0], 'sysClock_byte') else False |
| 337 | + return TimestampDataChunkIterator( |
| 338 | + self.neo_io, |
| 339 | + use_systime=use_systime, |
| 340 | + ) |
| 341 | + |
| 342 | + def get_timestamps(self) -> np.ndarray: |
| 343 | + """Get timestamps, loading them if necessary.""" |
| 344 | + if self.timestamps is not None: |
| 345 | + return self.timestamps |
| 346 | + else: |
| 347 | + return self.load_all_timestamps() |
| 348 | + |
| 349 | + def get_sampling_rate(self) -> float: |
| 350 | + """Get the sampling rate if timestamps are regular.""" |
| 351 | + if hasattr(self, '_timestamps_regular') and self._timestamps_regular: |
| 352 | + return self._sampling_rate |
| 353 | + return None |
| 354 | + |
| 355 | + def load_all_timestamps(self): |
| 356 | + """Load all timestamps into memory (fallback for irregular timestamps).""" |
| 357 | + if self._has_loaded_timestamps: |
| 358 | + return self.timestamps |
| 359 | + |
| 360 | + logger = logging.getLogger("convert") |
| 361 | + logger.info("Loading all timestamps into memory...") |
| 362 | + |
| 363 | + if self.neo_io[0].sysClock_byte: # use this if have sysClock |
210 | 364 | self.timestamps = np.concatenate( |
211 | 365 | [neo_io.get_regressed_systime(0, None) for neo_io in self.neo_io] |
212 | 366 | ) |
213 | | - |
214 | 367 | else: # use this to convert Trodes timestamps into systime based on sampling rate |
215 | 368 | self.timestamps = np.concatenate( |
216 | 369 | [ |
217 | 370 | neo_io.get_systime_from_trodes_timestamps(0, None) |
218 | 371 | for neo_io in self.neo_io |
219 | 372 | ] |
220 | 373 | ) |
221 | | - |
| 374 | + |
| 375 | + self._has_loaded_timestamps = True |
222 | 376 | logger.info("Reading timestamps COMPLETE") |
223 | | - is_timestamps_sequential = np.all(np.diff(self.timestamps)) |
| 377 | + |
| 378 | + is_timestamps_sequential = np.all(np.diff(self.timestamps) > 0) |
224 | 379 | if not is_timestamps_sequential: |
225 | 380 | warn( |
226 | 381 | "Timestamps are not sequential. This may cause problems with some software or data analysis." |
227 | 382 | ) |
228 | | - |
229 | | - self.n_time = [ |
230 | | - neo_io.get_signal_size( |
231 | | - block_index=self.block_index, |
232 | | - seg_index=self.seg_index, |
233 | | - stream_index=self.stream_index, |
234 | | - ) |
235 | | - for neo_io in self.neo_io |
236 | | - ] |
237 | | - |
238 | | - super().__init__(**kwargs) |
| 383 | + |
| 384 | + return self.timestamps |
239 | 385 |
|
240 | 386 | def _get_data(self, selection: Tuple[slice]) -> np.ndarray: |
241 | 387 | # selection is (time, channel) |
@@ -389,14 +535,33 @@ def add_raw_ephys( |
389 | 535 | ), |
390 | 536 | ) |
391 | 537 |
|
392 | | - # do we want to pull the timestamps from the rec file? or is there another source? |
393 | | - eseries = ElectricalSeries( |
394 | | - name="e-series", |
395 | | - data=data_data_io, |
396 | | - timestamps=rec_dci.timestamps, |
397 | | - electrodes=electrode_table_region, # TODO |
398 | | - conversion=VOLTS_PER_MICROVOLT, |
399 | | - offset=0.0, # TODO |
400 | | - ) |
| 538 | + # Check if we can use sampling rate instead of individual timestamps |
| 539 | + sampling_rate = rec_dci.get_sampling_rate() |
| 540 | + if sampling_rate is not None: |
| 541 | + # Use sampling rate for regular timestamps - much more memory efficient |
| 542 | + eseries = ElectricalSeries( |
| 543 | + name="e-series", |
| 544 | + data=data_data_io, |
| 545 | + rate=sampling_rate, |
| 546 | + electrodes=electrode_table_region, |
| 547 | + conversion=VOLTS_PER_MICROVOLT, |
| 548 | + offset=0.0, |
| 549 | + ) |
| 550 | + else: |
| 551 | + # Use chunked timestamps for irregular timestamps |
| 552 | + timestamps_chunked = rec_dci.get_timestamps_chunked() |
| 553 | + timestamps_data_io = H5DataIO( |
| 554 | + timestamps_chunked, |
| 555 | + chunks=(DEFAULT_CHUNK_TIME_DIM,), |
| 556 | + ) |
| 557 | + |
| 558 | + eseries = ElectricalSeries( |
| 559 | + name="e-series", |
| 560 | + data=data_data_io, |
| 561 | + timestamps=timestamps_data_io, |
| 562 | + electrodes=electrode_table_region, |
| 563 | + conversion=VOLTS_PER_MICROVOLT, |
| 564 | + offset=0.0, |
| 565 | + ) |
401 | 566 |
|
402 | 567 | nwbfile.add_acquisition(eseries) |
0 commit comments