Skip to content

Commit c45f230

Browse files
Copilot authored and samuelbray32 committed
Implement timestamp memory optimization using DataChunkIterator
Co-authored-by: samuelbray32 <[email protected]>
1 parent 064a9d6 commit c45f230

File tree

5 files changed

+226
-38
lines changed

5 files changed

+226
-38
lines changed

src/trodes_to_nwb/convert.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,8 @@ def _create_nwb(
236236
stream_id="ECU_analog" if behavior_only else "trodes",
237237
behavior_only=behavior_only,
238238
)
239-
rec_dci_timestamps = (
240-
rec_dci.timestamps
241-
) # pass these when creating other non-interpolated rec iterators to save time
239+
# Defer timestamp loading until needed - this saves memory for large recordings
240+
rec_dci_timestamps = None
242241

243242
rec_header = read_header(rec_filepaths[0])
244243
reconfig_header = rec_header
@@ -307,7 +306,6 @@ def _create_nwb(
307306
add_analog_data(
308307
nwb_file,
309308
rec_filepaths,
310-
timestamps=rec_dci_timestamps,
311309
behavior_only=behavior_only,
312310
)
313311
logger.info("ADDING SAMPLE COUNTS")
@@ -332,6 +330,9 @@ def _create_nwb(
332330
session_df,
333331
)
334332
else:
333+
# For non-PTP position tracking, we need timestamps
334+
if rec_dci_timestamps is None:
335+
rec_dci_timestamps = rec_dci.get_timestamps()
335336
add_position(
336337
nwb_file,
337338
metadata,

src/trodes_to_nwb/convert_analog.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
def add_analog_data(
1818
nwbfile: NWBFile,
1919
rec_file_path: list[str],
20-
timestamps: np.ndarray = None,
2120
behavior_only: bool = False,
2221
**kwargs,
2322
) -> None:
@@ -52,7 +51,6 @@ def add_analog_data(
5251
nwb_hw_channel_order=analog_channel_ids,
5352
stream_id="ECU_analog",
5453
is_analog=True,
55-
timestamps=timestamps,
5654
behavior_only=behavior_only,
5755
)
5856

@@ -76,17 +74,40 @@ def add_analog_data(
7674
name="analog", description="Contains all analog data"
7775
)
7876
analog_events = pynwb.behavior.BehavioralEvents(name="analog")
79-
analog_events.add_timeseries(
80-
pynwb.TimeSeries(
77+
78+
# Check if we can use sampling rate instead of individual timestamps
79+
sampling_rate = rec_dci.get_sampling_rate()
80+
if sampling_rate is not None:
81+
# Use sampling rate for regular timestamps - much more memory efficient
82+
analog_time_series = pynwb.TimeSeries(
8183
name="analog",
8284
description=__merge_row_description(
8385
analog_channel_ids
8486
), # NOTE: matches rec_to_nwb system
8587
data=data_data_io,
86-
timestamps=rec_dci.timestamps,
88+
rate=sampling_rate,
8789
unit="-1",
8890
)
89-
)
91+
else:
92+
# Use chunked timestamps for irregular timestamps
93+
from trodes_to_nwb import convert_ephys
94+
timestamps_chunked = rec_dci.get_timestamps_chunked()
95+
timestamps_data_io = H5DataIO(
96+
timestamps_chunked,
97+
chunks=(convert_ephys.DEFAULT_CHUNK_TIME_DIM,),
98+
)
99+
100+
analog_time_series = pynwb.TimeSeries(
101+
name="analog",
102+
description=__merge_row_description(
103+
analog_channel_ids
104+
), # NOTE: matches rec_to_nwb system
105+
data=data_data_io,
106+
timestamps=timestamps_data_io,
107+
unit="-1",
108+
)
109+
110+
analog_events.add_timeseries(analog_time_series)
90111
# add it to the nwb file
91112
nwbfile.processing["analog"].add(analog_events)
92113

src/trodes_to_nwb/convert_ephys.py

Lines changed: 190 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,83 @@
3434
DEFAULT_CHUNK_MAX_CHANNEL_DIM = 32
3535

3636

37+
class TimestampDataChunkIterator(GenericDataChunkIterator):
38+
"""Data chunk iterator for timestamps from SpikeGadgets rec files."""
39+
40+
def __init__(
41+
self,
42+
neo_io_list: list[SpikeGadgetsRawIO],
43+
use_systime: bool = True,
44+
**kwargs,
45+
):
46+
"""
47+
Parameters
48+
----------
49+
neo_io_list : list[SpikeGadgetsRawIO]
50+
list of neo IO objects for the rec files
51+
use_systime : bool, optional
52+
whether to use system time (True) or trodes timestamps (False), by default True
53+
kwargs : dict
54+
additional arguments to pass to GenericDataChunkIterator
55+
"""
56+
self.neo_io_list = neo_io_list
57+
self.use_systime = use_systime
58+
self.n_time = [
59+
neo_io.get_signal_size(
60+
block_index=0, seg_index=0, stream_index=0
61+
)
62+
for neo_io in self.neo_io_list
63+
]
64+
super().__init__(**kwargs)
65+
66+
def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
67+
# selection is (time,)
68+
assert selection[0].step is None
69+
70+
# what global index each file starts at
71+
file_start_ind = np.append(np.zeros(1), np.cumsum(self.n_time))
72+
# the time indexes we want
73+
time_index = np.arange(selection[0].start, selection[0].stop)[
74+
:: selection[0].step
75+
]
76+
timestamps = []
77+
i = time_index[0]
78+
while i < min(time_index[-1], self._get_maxshape()[0]):
79+
# find the stream where this piece of slice begins
80+
io_stream = np.argmin(i >= file_start_ind) - 1
81+
# get the timestamps from that stream
82+
i_start_local = int(i - file_start_ind[io_stream])
83+
i_stop_local = int(
84+
min(
85+
time_index[-1] - file_start_ind[io_stream],
86+
self.n_time[io_stream],
87+
)
88+
) + 1
89+
90+
if self.use_systime and self.neo_io_list[io_stream].sysClock_byte:
91+
chunk_timestamps = self.neo_io_list[io_stream].get_regressed_systime(
92+
i_start_local, i_stop_local
93+
)
94+
else:
95+
chunk_timestamps = self.neo_io_list[io_stream].get_systime_from_trodes_timestamps(
96+
i_start_local, i_stop_local
97+
)
98+
timestamps.append(chunk_timestamps)
99+
100+
i += min(
101+
self.n_time[io_stream] - (i - file_start_ind[io_stream]),
102+
time_index[-1] - i,
103+
)
104+
105+
return np.concatenate(timestamps)
106+
107+
def _get_maxshape(self) -> Tuple[int]:
108+
return (np.sum(self.n_time),)
109+
110+
def _get_dtype(self) -> np.dtype:
111+
return np.dtype("float64")
112+
113+
37114
class RecFileDataChunkIterator(GenericDataChunkIterator):
38115
"""Data chunk iterator for SpikeGadgets rec files."""
39116

@@ -202,40 +279,109 @@ def __init__(
202279
self.neo_io.pop(iterator_loc)
203280
self.neo_io[iterator_loc:iterator_loc] = sub_iterators
204281
logger.info(f"# iterators: {len(self.neo_io)}")
205-
# NOTE: this will read all the timestamps from the rec file, which can be slow
282+
# Check if timestamps are regular to potentially use sampling rate instead
283+
self._check_timestamp_regularity()
284+
285+
# Only load timestamps into memory if they're needed and irregular
206286
if timestamps is not None:
207287
self.timestamps = timestamps
288+
self._has_loaded_timestamps = True
289+
else:
290+
# Defer timestamp loading - we'll create a chunked iterator later if needed
291+
self.timestamps = None
292+
self._has_loaded_timestamps = False
293+
294+
self.n_time = [
295+
neo_io.get_signal_size(
296+
block_index=self.block_index,
297+
seg_index=self.seg_index,
298+
stream_index=self.stream_index,
299+
)
300+
for neo_io in self.neo_io
301+
]
302+
303+
super().__init__(**kwargs)
208304

209-
elif self.neo_io[0].sysClock_byte: # use this if have sysClock
305+
def _check_timestamp_regularity(self):
306+
"""Check if timestamps are regular by sampling a few chunks."""
307+
# Sample a small portion of timestamps to check regularity
308+
sample_size = min(1000, self.n_time[0]) # Sample first 1000 or all if smaller
309+
310+
try:
311+
if self.neo_io[0].sysClock_byte:
312+
sample_timestamps = self.neo_io[0].get_regressed_systime(0, sample_size)
313+
else:
314+
sample_timestamps = self.neo_io[0].get_systime_from_trodes_timestamps(0, sample_size)
315+
316+
# Check if timestamps are evenly spaced
317+
diffs = np.diff(sample_timestamps)
318+
expected_dt = 1.0 / self.neo_io[0]._sampling_rate
319+
relative_error = np.abs(diffs - expected_dt) / expected_dt
320+
321+
# Consider regular if 95% of intervals are within 1% of expected
322+
self._timestamps_regular = np.percentile(relative_error, 95) < 0.01
323+
self._sampling_rate = self.neo_io[0]._sampling_rate
324+
325+
if len(self.neo_io) > 1:
326+
# For multiple files, we need to be more conservative
327+
self._timestamps_regular = False
328+
329+
except Exception:
330+
# If we can't check, assume irregular
331+
self._timestamps_regular = False
332+
self._sampling_rate = None
333+
334+
def get_timestamps_chunked(self) -> TimestampDataChunkIterator:
335+
"""Get a chunked iterator for timestamps."""
336+
use_systime = self.neo_io[0].sysClock_byte if hasattr(self.neo_io[0], 'sysClock_byte') else False
337+
return TimestampDataChunkIterator(
338+
self.neo_io,
339+
use_systime=use_systime,
340+
)
341+
342+
def get_timestamps(self) -> np.ndarray:
343+
"""Get timestamps, loading them if necessary."""
344+
if self.timestamps is not None:
345+
return self.timestamps
346+
else:
347+
return self.load_all_timestamps()
348+
349+
def get_sampling_rate(self) -> float:
350+
"""Get the sampling rate if timestamps are regular."""
351+
if hasattr(self, '_timestamps_regular') and self._timestamps_regular:
352+
return self._sampling_rate
353+
return None
354+
355+
def load_all_timestamps(self):
356+
"""Load all timestamps into memory (fallback for irregular timestamps)."""
357+
if self._has_loaded_timestamps:
358+
return self.timestamps
359+
360+
logger = logging.getLogger("convert")
361+
logger.info("Loading all timestamps into memory...")
362+
363+
if self.neo_io[0].sysClock_byte: # use this if have sysClock
210364
self.timestamps = np.concatenate(
211365
[neo_io.get_regressed_systime(0, None) for neo_io in self.neo_io]
212366
)
213-
214367
else: # use this to convert Trodes timestamps into systime based on sampling rate
215368
self.timestamps = np.concatenate(
216369
[
217370
neo_io.get_systime_from_trodes_timestamps(0, None)
218371
for neo_io in self.neo_io
219372
]
220373
)
221-
374+
375+
self._has_loaded_timestamps = True
222376
logger.info("Reading timestamps COMPLETE")
223-
is_timestamps_sequential = np.all(np.diff(self.timestamps))
377+
378+
is_timestamps_sequential = np.all(np.diff(self.timestamps) > 0)
224379
if not is_timestamps_sequential:
225380
warn(
226381
"Timestamps are not sequential. This may cause problems with some software or data analysis."
227382
)
228-
229-
self.n_time = [
230-
neo_io.get_signal_size(
231-
block_index=self.block_index,
232-
seg_index=self.seg_index,
233-
stream_index=self.stream_index,
234-
)
235-
for neo_io in self.neo_io
236-
]
237-
238-
super().__init__(**kwargs)
383+
384+
return self.timestamps
239385

240386
def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
241387
# selection is (time, channel)
@@ -389,14 +535,33 @@ def add_raw_ephys(
389535
),
390536
)
391537

392-
# do we want to pull the timestamps from the rec file? or is there another source?
393-
eseries = ElectricalSeries(
394-
name="e-series",
395-
data=data_data_io,
396-
timestamps=rec_dci.timestamps,
397-
electrodes=electrode_table_region, # TODO
398-
conversion=VOLTS_PER_MICROVOLT,
399-
offset=0.0, # TODO
400-
)
538+
# Check if we can use sampling rate instead of individual timestamps
539+
sampling_rate = rec_dci.get_sampling_rate()
540+
if sampling_rate is not None:
541+
# Use sampling rate for regular timestamps - much more memory efficient
542+
eseries = ElectricalSeries(
543+
name="e-series",
544+
data=data_data_io,
545+
rate=sampling_rate,
546+
electrodes=electrode_table_region,
547+
conversion=VOLTS_PER_MICROVOLT,
548+
offset=0.0,
549+
)
550+
else:
551+
# Use chunked timestamps for irregular timestamps
552+
timestamps_chunked = rec_dci.get_timestamps_chunked()
553+
timestamps_data_io = H5DataIO(
554+
timestamps_chunked,
555+
chunks=(DEFAULT_CHUNK_TIME_DIM,),
556+
)
557+
558+
eseries = ElectricalSeries(
559+
name="e-series",
560+
data=data_data_io,
561+
timestamps=timestamps_data_io,
562+
electrodes=electrode_table_region,
563+
conversion=VOLTS_PER_MICROVOLT,
564+
offset=0.0,
565+
)
401566

402567
nwbfile.add_acquisition(eseries)

src/trodes_to_nwb/convert_intervals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def add_sample_count(
9797
)
9898

9999
# get the systime information
100-
systime = np.array(rec_dci.timestamps) * NANOSECONDS_PER_SECOND
100+
systime = np.array(rec_dci.get_timestamps()) * NANOSECONDS_PER_SECOND
101101
# get the sample count information
102102
trodes_sample = np.concatenate(
103103
[neo_io.get_analogsignal_timestamps(0, None) for neo_io in rec_dci.neo_io]

src/trodes_to_nwb/tests/test_behavior_only_rec.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ def test_behavior_only_rec_file():
4444
assert "trodes" not in stream_names, "unexpected trodes stream in iterator"
4545

4646
# check data accesses
47-
assert rec_dci.timestamps.size == 433012
48-
assert rec_dci.timestamps[-1] == 1751195974.5656028, "unexpected last timestamp"
47+
timestamps = rec_dci.get_timestamps() # Use new method to get timestamps
48+
assert timestamps.size == 433012
49+
assert timestamps[-1] == 1751195974.5656028, "unexpected last timestamp"
4950
assert set(neo_io.multiplexed_channel_xml.keys()) == set(
5051
[
5152
"Headstage_AccelX",

0 commit comments

Comments
 (0)