Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/trodes_to_nwb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,8 @@ def _create_nwb(
stream_id="ECU_analog" if behavior_only else "trodes",
behavior_only=behavior_only,
)
rec_dci_timestamps = (
rec_dci.timestamps
) # pass these when creating other non-interpolated rec iterators to save time
# Defer timestamp loading until needed - this saves memory for large recordings
rec_dci_timestamps = None

rec_header = read_header(rec_filepaths[0])
reconfig_header = rec_header
Expand Down Expand Up @@ -307,7 +306,6 @@ def _create_nwb(
add_analog_data(
nwb_file,
rec_filepaths,
timestamps=rec_dci_timestamps,
behavior_only=behavior_only,
)
logger.info("ADDING SAMPLE COUNTS")
Expand All @@ -332,6 +330,9 @@ def _create_nwb(
session_df,
)
else:
# For non-PTP position tracking, we need timestamps
if rec_dci_timestamps is None:
rec_dci_timestamps = rec_dci.get_timestamps()
add_position(
nwb_file,
metadata,
Expand Down
33 changes: 27 additions & 6 deletions src/trodes_to_nwb/convert_analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
def add_analog_data(
nwbfile: NWBFile,
rec_file_path: list[str],
timestamps: np.ndarray = None,
behavior_only: bool = False,
**kwargs,
) -> None:
Expand Down Expand Up @@ -52,7 +51,6 @@ def add_analog_data(
nwb_hw_channel_order=analog_channel_ids,
stream_id="ECU_analog",
is_analog=True,
timestamps=timestamps,
behavior_only=behavior_only,
)

Expand All @@ -76,17 +74,40 @@ def add_analog_data(
name="analog", description="Contains all analog data"
)
analog_events = pynwb.behavior.BehavioralEvents(name="analog")
analog_events.add_timeseries(
pynwb.TimeSeries(

# Check if we can use sampling rate instead of individual timestamps
sampling_rate = rec_dci.get_sampling_rate()
if sampling_rate is not None:
# Use sampling rate for regular timestamps - much more memory efficient
analog_time_series = pynwb.TimeSeries(
name="analog",
description=__merge_row_description(
analog_channel_ids
), # NOTE: matches rec_to_nwb system
data=data_data_io,
timestamps=rec_dci.timestamps,
rate=sampling_rate,
unit="-1",
)
)
else:
# Use chunked timestamps for irregular timestamps
from trodes_to_nwb import convert_ephys
timestamps_chunked = rec_dci.get_timestamps_chunked()
timestamps_data_io = H5DataIO(
timestamps_chunked,
chunks=(convert_ephys.DEFAULT_CHUNK_TIME_DIM,),
)

analog_time_series = pynwb.TimeSeries(
name="analog",
description=__merge_row_description(
analog_channel_ids
), # NOTE: matches rec_to_nwb system
data=data_data_io,
timestamps=timestamps_data_io,
unit="-1",
)

analog_events.add_timeseries(analog_time_series)
# add it to the nwb file
nwbfile.processing["analog"].add(analog_events)

Expand Down
215 changes: 190 additions & 25 deletions src/trodes_to_nwb/convert_ephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,83 @@
DEFAULT_CHUNK_MAX_CHANNEL_DIM = 32


class TimestampDataChunkIterator(GenericDataChunkIterator):
    """Data chunk iterator that streams per-sample timestamps from SpikeGadgets rec files.

    Yields float64 system-time timestamps (seconds) chunk by chunk, so the full
    timestamp vector never needs to be held in memory. Multiple rec files are
    presented as one concatenated time axis.
    """

    def __init__(
        self,
        neo_io_list: list[SpikeGadgetsRawIO],
        use_systime: bool = True,
        **kwargs,
    ):
        """
        Parameters
        ----------
        neo_io_list : list[SpikeGadgetsRawIO]
            list of neo IO objects for the rec files
        use_systime : bool, optional
            whether to use system time (True) or trodes timestamps (False), by default True
        kwargs : dict
            additional arguments to pass to GenericDataChunkIterator
        """
        self.neo_io_list = neo_io_list
        self.use_systime = use_systime
        # number of samples contributed by each file, in order
        self.n_time = [
            neo_io.get_signal_size(block_index=0, seg_index=0, stream_index=0)
            for neo_io in self.neo_io_list
        ]
        super().__init__(**kwargs)

    def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
        """Return timestamps for the requested (time,) selection.

        Walks the files that the [start, stop) range spans, reading each file's
        portion with local (per-file) indices and concatenating the pieces.
        """
        # selection is (time,); only contiguous slices are supported
        assert selection[0].step is None

        start = selection[0].start
        # clamp the exclusive stop to the total number of samples
        stop = min(selection[0].stop, self._get_maxshape()[0])

        # global sample index at which each file starts (len(files) + 1 entries)
        file_start_ind = np.append(np.zeros(1), np.cumsum(self.n_time))

        timestamps = []
        i = start
        while i < stop:
            # find the file containing global index i:
            # first position where i < file_start_ind, minus one
            io_stream = np.argmin(i >= file_start_ind) - 1
            neo_io = self.neo_io_list[io_stream]
            # local (per-file) half-open range [i_start_local, i_stop_local)
            i_start_local = int(i - file_start_ind[io_stream])
            i_stop_local = int(
                min(stop - file_start_ind[io_stream], self.n_time[io_stream])
            )

            # NOTE(review): assumes get_regressed_systime /
            # get_systime_from_trodes_timestamps take an exclusive stop index --
            # confirm against SpikeGadgetsRawIO.
            if self.use_systime and neo_io.sysClock_byte:
                chunk_timestamps = neo_io.get_regressed_systime(
                    i_start_local, i_stop_local
                )
            else:
                chunk_timestamps = neo_io.get_systime_from_trodes_timestamps(
                    i_start_local, i_stop_local
                )
            timestamps.append(chunk_timestamps)

            # advance by exactly the number of samples just read
            i += i_stop_local - i_start_local

        if not timestamps:
            # empty selection (e.g. start == stop); avoid np.concatenate([]) error
            return np.empty((0,), dtype=np.float64)
        return np.concatenate(timestamps)

    def _get_maxshape(self) -> Tuple[int]:
        # one-dimensional: total samples across all files
        return (np.sum(self.n_time),)

    def _get_dtype(self) -> np.dtype:
        # timestamps are seconds as float64
        return np.dtype("float64")


class RecFileDataChunkIterator(GenericDataChunkIterator):
"""Data chunk iterator for SpikeGadgets rec files."""

Expand Down Expand Up @@ -202,40 +279,109 @@ def __init__(
self.neo_io.pop(iterator_loc)
self.neo_io[iterator_loc:iterator_loc] = sub_iterators
logger.info(f"# iterators: {len(self.neo_io)}")
# NOTE: this will read all the timestamps from the rec file, which can be slow
# Check if timestamps are regular to potentially use sampling rate instead
self._check_timestamp_regularity()

# Only load timestamps into memory if they're needed and irregular
if timestamps is not None:
self.timestamps = timestamps
self._has_loaded_timestamps = True
else:
# Defer timestamp loading - we'll create a chunked iterator later if needed
self.timestamps = None
self._has_loaded_timestamps = False

self.n_time = [
neo_io.get_signal_size(
block_index=self.block_index,
seg_index=self.seg_index,
stream_index=self.stream_index,
)
for neo_io in self.neo_io
]

super().__init__(**kwargs)

elif self.neo_io[0].sysClock_byte: # use this if have sysClock
def _check_timestamp_regularity(self):
    """Estimate whether timestamps are regularly spaced by sampling an initial chunk.

    Sets two attributes used by ``get_sampling_rate``:
    - ``self._timestamps_regular``: True only when a single file's sampled
      intervals are uniform (95th percentile of relative error < 1%).
    - ``self._sampling_rate``: the IO's sampling rate, or None on failure.

    Any failure (missing attributes, IO errors) conservatively falls back to
    "irregular", which makes callers write explicit timestamps.
    """
    try:
        # Sample the first 1000 timestamps (or fewer if the file is smaller).
        # Kept inside the try: if self.n_time is not yet set when this is
        # called from __init__, we fall through to the safe fallback instead
        # of raising.
        sample_size = min(1000, self.n_time[0])

        if self.neo_io[0].sysClock_byte:
            sample_timestamps = self.neo_io[0].get_regressed_systime(0, sample_size)
        else:
            sample_timestamps = self.neo_io[0].get_systime_from_trodes_timestamps(0, sample_size)

        # Compare observed intervals against the nominal sampling period
        diffs = np.diff(sample_timestamps)
        expected_dt = 1.0 / self.neo_io[0]._sampling_rate
        relative_error = np.abs(diffs - expected_dt) / expected_dt

        # Consider regular if 95% of intervals are within 1% of expected
        self._timestamps_regular = np.percentile(relative_error, 95) < 0.01
        self._sampling_rate = self.neo_io[0]._sampling_rate

        if len(self.neo_io) > 1:
            # Multiple files may have gaps between them; be conservative
            self._timestamps_regular = False

    except Exception:
        # If we can't check, assume irregular
        self._timestamps_regular = False
        self._sampling_rate = None

def get_timestamps_chunked(self) -> TimestampDataChunkIterator:
    """Return a chunked iterator over timestamps, avoiding a full in-memory load.

    Uses system-clock regression when the first IO object exposes a truthy
    ``sysClock_byte``; otherwise falls back to Trodes-timestamp conversion.
    """
    return TimestampDataChunkIterator(
        self.neo_io,
        use_systime=getattr(self.neo_io[0], 'sysClock_byte', False),
    )

def get_timestamps(self) -> np.ndarray:
    """Return the full timestamp array, loading it from disk on first use."""
    if self.timestamps is None:
        return self.load_all_timestamps()
    return self.timestamps

def get_sampling_rate(self) -> float:
    """Return the sampling rate when timestamps are regular; otherwise None."""
    if getattr(self, '_timestamps_regular', False):
        return self._sampling_rate
    return None

def load_all_timestamps(self):
    """Load all timestamps into memory (fallback for irregular timestamps).

    Returns
    -------
    np.ndarray
        Concatenated float timestamps (seconds) across all rec files. Cached
        on ``self.timestamps``; subsequent calls return the cached array.

    Warns when timestamps are not strictly increasing, since downstream
    analysis tools generally assume monotonic time.
    """
    # Return the cached array if a previous call (or the constructor's
    # `timestamps` argument) already populated it.
    if self._has_loaded_timestamps:
        return self.timestamps

    logger = logging.getLogger("convert")
    logger.info("Loading all timestamps into memory...")

    if self.neo_io[0].sysClock_byte:  # use this if have sysClock
        self.timestamps = np.concatenate(
            [neo_io.get_regressed_systime(0, None) for neo_io in self.neo_io]
        )
    else:  # use this to convert Trodes timestamps into systime based on sampling rate
        self.timestamps = np.concatenate(
            [
                neo_io.get_systime_from_trodes_timestamps(0, None)
                for neo_io in self.neo_io
            ]
        )

    self._has_loaded_timestamps = True
    logger.info("Reading timestamps COMPLETE")

    # strictly-increasing check; `> 0` is required -- np.all(np.diff(...))
    # alone would only test that no two consecutive timestamps are equal
    is_timestamps_sequential = np.all(np.diff(self.timestamps) > 0)
    if not is_timestamps_sequential:
        warn(
            "Timestamps are not sequential. This may cause problems with some software or data analysis."
        )

    return self.timestamps

def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
# selection is (time, channel)
Expand Down Expand Up @@ -389,14 +535,33 @@ def add_raw_ephys(
),
)

# do we want to pull the timestamps from the rec file? or is there another source?
eseries = ElectricalSeries(
name="e-series",
data=data_data_io,
timestamps=rec_dci.timestamps,
electrodes=electrode_table_region, # TODO
conversion=VOLTS_PER_MICROVOLT,
offset=0.0, # TODO
)
# Check if we can use sampling rate instead of individual timestamps
sampling_rate = rec_dci.get_sampling_rate()
if sampling_rate is not None:
# Use sampling rate for regular timestamps - much more memory efficient
eseries = ElectricalSeries(
name="e-series",
data=data_data_io,
rate=sampling_rate,
electrodes=electrode_table_region,
conversion=VOLTS_PER_MICROVOLT,
offset=0.0,
)
else:
# Use chunked timestamps for irregular timestamps
timestamps_chunked = rec_dci.get_timestamps_chunked()
timestamps_data_io = H5DataIO(
timestamps_chunked,
chunks=(DEFAULT_CHUNK_TIME_DIM,),
)

eseries = ElectricalSeries(
name="e-series",
data=data_data_io,
timestamps=timestamps_data_io,
electrodes=electrode_table_region,
conversion=VOLTS_PER_MICROVOLT,
offset=0.0,
)

nwbfile.add_acquisition(eseries)
2 changes: 1 addition & 1 deletion src/trodes_to_nwb/convert_intervals.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def add_sample_count(
)

# get the systime information
systime = np.array(rec_dci.timestamps) * NANOSECONDS_PER_SECOND
systime = np.array(rec_dci.get_timestamps()) * NANOSECONDS_PER_SECOND
# get the sample count information
trodes_sample = np.concatenate(
[neo_io.get_analogsignal_timestamps(0, None) for neo_io in rec_dci.neo_io]
Expand Down
5 changes: 3 additions & 2 deletions src/trodes_to_nwb/tests/test_behavior_only_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def test_behavior_only_rec_file():
assert "trodes" not in stream_names, "unexpected trodes stream in iterator"

# check data accesses
assert rec_dci.timestamps.size == 433012
assert rec_dci.timestamps[-1] == 1751195974.5656028, "unexpected last timestamp"
timestamps = rec_dci.get_timestamps() # Use new method to get timestamps
assert timestamps.size == 433012
assert timestamps[-1] == 1751195974.5656028, "unexpected last timestamp"
assert set(neo_io.multiplexed_channel_xml.keys()) == set(
[
"Headstage_AccelX",
Expand Down
Loading