diff --git a/doc/api.rst b/doc/api.rst index 2aa09767a9..9732d70b18 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -73,7 +73,7 @@ Low-level .. automodule:: spikeinterface.core :noindex: - .. autoclass:: ChunkRecordingExecutor + .. autoclass:: ChunkExecutor Back-compatibility with ``WaveformExtractor`` (version > 0.100.0) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 2b7180117b..e4651c0ef4 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -222,7 +222,7 @@ def __init__( for segment_index in range(multisortingcomparison._num_segments): sorting_segment = AgreementSortingSegment(multisortingcomparison._spiketrains[segment_index]) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) self._kwargs = dict( sampling_frequency=sampling_frequency, diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 44d805377f..c8d9b6e94c 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -95,12 +95,12 @@ get_best_job_kwargs, ensure_n_jobs, ensure_chunk_size, - ChunkRecordingExecutor, + ChunkExecutor, split_job_kwargs, fix_job_kwargs, ) +from .chunkable_tools import write_binary, write_memory from .recording_tools import ( - write_binary_recording, write_to_h5_dataset_format, get_random_data_chunks, get_channel_distances, diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 6fe49b8606..44ecea18d4 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1,4 +1,5 @@ from __future__ import annotations +from abc import ABC, abstractmethod from pathlib import Path import shutil from typing import Any, Iterable, List, Optional, Sequence, Union @@ -61,6 +62,9 @@ def __init__(self, main_ids: Sequence) -> None: self._main_ids.dtype.kind in "uiSU" ), f"Main IDs can only be integers 
(signed/unsigned) or strings, not {self._main_ids.dtype}" + # segments + self._segments: List[BaseSegment] = [] + # dict at object level self._annotations = {} @@ -116,9 +120,16 @@ def name(self, value): # we remove the annotation if it exists _ = self._annotations.pop("name", None) + @property + def segments(self) -> list: + return self._segments + + def add_segment(self, segment: BaseSegment) -> None: + self._segments.append(segment) + segment.set_parent_extractor(self) + def get_num_segments(self) -> int: - # This is implemented in BaseRecording or BaseSorting - raise NotImplementedError + return len(self._segments) def get_parent(self) -> Optional[BaseExtractor]: """Returns parent object if it exists, otherwise None""" @@ -210,13 +221,6 @@ def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> No else: raise ValueError(f"{annotation_key} is already an annotation key. Use 'overwrite=True' to overwrite it") - def get_preferred_mp_context(self): - """ - Get the preferred context for multiprocessing. - If None, the context is set by the multiprocessing package. - """ - return self._preferred_mp_context - def get_annotation(self, key: str, copy: bool = True) -> Any: """ Get a annotation. @@ -1165,3 +1169,42 @@ def parent_extractor(self) -> Union[BaseExtractor, None]: def set_parent_extractor(self, parent_extractor: BaseExtractor) -> None: self._parent_extractor = weakref.ref(parent_extractor) + + +class ChunkableMixin(ABC): + """ + Abstract mixin class for chunkable objects. + Provides methods to handle chunked data access, that can be used for parallelization. + + The Mixin is abstract since all methods need to be implemented in the child class in order + for it to function properly. 
+ """ + + _preferred_mp_context = None + + @abstractmethod + def get_sampling_frequency(self) -> float: + raise NotImplementedError + + @abstractmethod + def get_num_samples(self, segment_index: int | None = None) -> int: + raise NotImplementedError + + @abstractmethod + def get_sample_size_in_bytes(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]: + raise NotImplementedError + + @abstractmethod + def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: + raise NotImplementedError + + def get_preferred_mp_context(self): + """ + Get the preferred context for multiprocessing. + If None, the context is set by the multiprocessing package. + """ + return self._preferred_mp_context diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index bbb6e5b9f3..fbfa70716a 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -7,14 +7,13 @@ import numpy as np from probeinterface import read_probeinterface, write_probeinterface -from .base import BaseSegment +from .base import BaseSegment, ChunkableMixin from .baserecordingsnippets import BaseRecordingSnippets from .core_tools import convert_bytes_to_str, convert_seconds_to_str from .job_tools import split_job_kwargs -from .recording_tools import write_binary_recording -class BaseRecording(BaseRecordingSnippets): +class BaseRecording(BaseRecordingSnippets, ChunkableMixin): """ Abstract class representing several a multichannel timeseries (or block of raw ephys traces). 
Internally handle list of RecordingSegment @@ -44,8 +43,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype): self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype ) - self._recording_segments: list[BaseRecordingSegment] = [] - # initialize main annotation and properties self.annotate(is_filtered=False) @@ -161,28 +158,19 @@ def _repr_html_(self, display_name=True): html_repr = html_header + html_segments + html_channel_ids + html_extra return html_repr - def get_num_segments(self) -> int: + def get_sample_size_in_bytes(self): """ - Returns the number of segments. + Returns the size of a single sample across all channels in bytes. Returns ------- int - Number of segments in the recording + The size of a single sample in bytes """ - return len(self._recording_segments) - - def add_recording_segment(self, recording_segment): - """Adds a recording segment. - - Parameters - ---------- - recording_segment : BaseRecordingSegment - The recording segment to add - """ - # todo: check channel count and sampling frequency - self._recording_segments.append(recording_segment) - recording_segment.set_parent_extractor(self) + num_channels = self.get_num_channels() + dtype_size_bytes = self.get_dtype().itemsize + sample_size = num_channels * dtype_size_bytes + return sample_size def get_num_samples(self, segment_index: int | None = None) -> int: """ @@ -201,7 +189,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int: The number of samples """ segment_index = self._check_segment_index(segment_index) - return int(self._recording_segments[segment_index].get_num_samples()) + return int(self.segments[segment_index].get_num_samples()) get_num_frames = get_num_samples @@ -333,7 +321,7 @@ def get_traces( """ segment_index = self._check_segment_index(segment_index) channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] 
start_frame = int(start_frame) if start_frame is not None else 0 num_samples = rs.get_num_samples() end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples @@ -391,7 +379,7 @@ def get_time_info(self, segment_index=None) -> dict: """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] time_kwargs = rs.get_times_kwargs() return time_kwargs @@ -415,7 +403,7 @@ def get_times(self, segment_index=None) -> np.ndarray: The 1d times array """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] times = rs.get_times() return times @@ -433,7 +421,7 @@ def get_start_time(self, segment_index=None) -> float: The start time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.get_start_time() def get_end_time(self, segment_index=None) -> float: @@ -450,7 +438,7 @@ def get_end_time(self, segment_index=None) -> float: The stop time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.get_end_time() def has_time_vector(self, segment_index: Optional[int] = None): @@ -467,7 +455,7 @@ def has_time_vector(self, segment_index: Optional[int] = None): True if the recording has time vectors, False otherwise """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] d = rs.get_times_kwargs() return d["time_vector"] is not None @@ -484,7 +472,7 @@ def set_times(self, times, segment_index=None, with_warning=True): If True, a warning is printed """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] 
assert times.ndim == 1, "Time must have ndim=1" assert rs.get_num_samples() == times.shape[0], "times have wrong shape" @@ -507,7 +495,7 @@ def reset_times(self): segment's sampling frequency is set to the recording's sampling frequency. """ for segment_index in range(self.get_num_segments()): - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] if self.has_time_vector(segment_index): rs.time_vector = None rs.t_start = None @@ -535,7 +523,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N segments_to_shift = (segment_index,) for segment_index in segments_to_shift: - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] if self.has_time_vector(segment_index=segment_index): rs.time_vector += shift @@ -548,19 +536,28 @@ def sample_index_to_time(self, sample_ind, segment_index=None): Transform sample index into time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.sample_index_to_time(sample_ind) def time_to_sample_index(self, time_s, segment_index=None): segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.time_to_sample_index(time_s) + def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: + """ + General retrieval function for chunkable objects + """ + return self.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame, **kwargs) + + def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]: + return (self.get_num_samples(segment_index=segment_index), self.get_num_channels()) + def _get_t_starts(self): # handle t_starts t_starts = [] has_time_vectors = [] - for rs in self._recording_segments: + for rs in self.segments: d = rs.get_times_kwargs() 
t_starts.append(d["t_start"]) @@ -570,7 +567,7 @@ def _get_t_starts(self): def _get_time_vectors(self): time_vectors = [] - for rs in self._recording_segments: + for rs in self.segments: d = rs.get_times_kwargs() time_vectors.append(d["time_vector"]) if all(time_vector is None for time_vector in time_vectors): @@ -581,12 +578,14 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": + from .chunkable_tools import write_binary + folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() t_starts = self._get_t_starts() - write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) + write_binary(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) from .binaryrecordingextractor import BinaryRecordingExtractor @@ -658,7 +657,7 @@ def _extra_metadata_from_folder(self, folder): self.set_probegroup(probegroup, in_place=True) # load time vector if any - for segment_index, rs in enumerate(self._recording_segments): + for segment_index, rs in enumerate(self.segments): time_file = folder / f"times_cached_seg{segment_index}.npy" if time_file.is_file(): time_vector = np.load(time_file) @@ -671,7 +670,7 @@ def _extra_metadata_to_folder(self, folder): write_probeinterface(folder / "probe.json", probegroup) # save time vector if any - for segment_index, rs in enumerate(self._recording_segments): + for segment_index, rs in enumerate(self.segments): d = rs.get_times_kwargs() time_vector = d["time_vector"] if time_vector is not None: diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 98159fb646..6c8101e0ab 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -20,7 +20,6 @@ class BaseSorting(BaseExtractor): def __init__(self, 
sampling_frequency: float, unit_ids: list): BaseExtractor.__init__(self, unit_ids) self._sampling_frequency = float(sampling_frequency) - self._sorting_segments: list[BaseSortingSegment] = [] # this weak link is to handle times from a recording object self._recording = None self._sorting_info = None @@ -73,16 +72,9 @@ def get_unit_ids(self) -> list: def get_num_units(self) -> int: return len(self.get_unit_ids()) - def add_sorting_segment(self, sorting_segment): - self._sorting_segments.append(sorting_segment) - sorting_segment.set_parent_extractor(self) - def get_sampling_frequency(self) -> float: return self._sampling_frequency - def get_num_segments(self) -> int: - return len(self._sorting_segments) - def get_num_samples(self, segment_index=None) -> int: """Returns the number of samples of the associated recording for a segment. @@ -182,7 +174,7 @@ def get_unit_spike_train( if segment_index not in self._cached_spike_trains: self._cached_spike_trains[segment_index] = {} if unit_id not in self._cached_spike_trains[segment_index]: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] spike_frames = segment.get_unit_spike_train(unit_id=unit_id, start_frame=None, end_frame=None).astype( "int64", copy=False ) @@ -196,7 +188,7 @@ def get_unit_spike_train( end = np.searchsorted(spike_frames, end_frame) spike_frames = spike_frames[:end] else: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] spike_frames = segment.get_unit_spike_train( unit_id=unit_id, start_frame=start_frame, end_frame=end_frame ).astype("int64") @@ -240,7 +232,7 @@ def get_unit_spike_train_in_seconds( Spike times in seconds """ segment_index = self._check_segment_index(segment_index) - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] # If sorting has a registered recording, get the frames and get the times from the recording # Note that this take into account the segment start time of 
the recording @@ -674,7 +666,7 @@ def time_to_sample_index(self, time, segment_index=0): if self.has_recording(): sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index) else: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] t_start = segment._t_start if segment._t_start is not None else 0 sample_index = round((time - t_start) * self.get_sampling_frequency()) @@ -690,7 +682,7 @@ def sample_index_to_time( if self.has_recording(): return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) else: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] t_start = segment._t_start if segment._t_start is not None else 0 return (sample_index / self.get_sampling_frequency()) + t_start diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 15c304e846..3c8e62f56d 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -7,7 +7,6 @@ from .baserecording import BaseRecording, BaseRecordingSegment from .core_tools import define_function_from_class -from .recording_tools import write_binary_recording from .job_tools import _shared_job_kwargs_doc @@ -93,7 +92,7 @@ def __init__( rec_segment = BinaryRecordingSegment( file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) if is_filtered is not None: self.annotate(is_filtered=is_filtered) @@ -133,7 +132,9 @@ def write_recording(recording, file_paths, dtype=None, **job_kwargs): Type of the saved data {} """ - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) + from .chunkable_tools import write_binary + + write_binary(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) def is_binary_compatible(self) -> bool: return True @@ 
-154,8 +155,8 @@ def __del__(self): Closes any open file handles in the recording segments. """ # Close all recording segments - if hasattr(self, "_recording_segments"): - for segment in self._recording_segments: + if hasattr(self, "segments"): + for segment in self.segments: # This will trigger the __del__ method of the BinaryRecordingSegment # which will close the file handle del segment diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 996772fd33..6ede4e9e66 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -128,9 +128,9 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record ch_id += 1 for i_seg in range(num_segments): - parent_segments = [rec._recording_segments[i_seg] for rec in recording_list] + parent_segments = [rec.segments[i_seg] for rec in recording_list] sub_segment = ChannelsAggregationRecordingSegment(channel_map, parent_segments) - self.add_recording_segment(sub_segment) + self.add_segment(sub_segment) self._kwargs = {"recording_list": recording_list, "renamed_channel_ids": renamed_channel_ids} diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 8a4f29e86c..245d45acb5 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -56,9 +56,9 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) self._parent_channel_indices = parent_recording.ids_to_indices(self._channel_ids) # link recording segment - for parent_segment in parent_recording._recording_segments: + for parent_segment in parent_recording.segments: sub_segment = ChannelSliceRecordingSegment(parent_segment, self._parent_channel_indices) - self.add_recording_segment(sub_segment) + self.add_segment(sub_segment) # copy annotation and properties parent_recording.copy_metadata(self, 
only_main=False, ids=self._channel_ids) diff --git a/src/spikeinterface/core/chunkable_tools.py b/src/spikeinterface/core/chunkable_tools.py new file mode 100644 index 0000000000..3371af43d6 --- /dev/null +++ b/src/spikeinterface/core/chunkable_tools.py @@ -0,0 +1,365 @@ +from __future__ import annotations +from pathlib import Path +import warnings + + +import numpy as np + +from .core_tools import add_suffix, make_shared_array +from .job_tools import ( + chunk_duration_to_chunk_size, + ensure_n_jobs, + fix_job_kwargs, + ChunkExecutor, + _shared_job_kwargs_doc, +) + + +def write_binary( + chunkable: "ChunkableMixin", + file_paths: list[Path | str] | Path | str, + dtype: np.typing.DTypeLike = None, + add_file_extension: bool = True, + byte_offset: int = 0, + verbose: bool = False, + **job_kwargs, +): + """ + Save the data of a chunkable object to binary format. + + Note : + time_axis is always 0 (contrary to previous version. + to get time_axis=1 (which is a bad idea) use `write_binary_file_handle()` + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to be saved to binary file + file_path : str or list[str] + The path to the file. + dtype : dtype or None, default: None + Type of the saved data + add_file_extension, bool, default: True + If True, and the file path does not end in "raw", "bin", or "dat" then "raw" is added as an extension. + byte_offset : int, default: 0 + Offset in bytes for the binary file (e.g. to write a header). This is useful in case you want to append data + to an existing file where you wrote a header or other data before. 
+ verbose : bool + This is the verbosity of the ChunkExecutor + {} + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths + num_segments = chunkable.get_num_segments() + if len(file_path_list) != num_segments: + raise ValueError("'file_paths' must be a list of the same size as the number of segments in the chunkable") + + file_path_list = [Path(file_path) for file_path in file_path_list] + if add_file_extension: + file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list] + + dtype = dtype if dtype is not None else chunkable.get_dtype() + + sample_size_bytes = chunkable.get_sample_size_in_bytes() + + file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)} + for segment_index, file_path in file_path_dict.items(): + num_samples = chunkable.get_num_samples(segment_index=segment_index) + data_size_bytes = sample_size_bytes * num_samples + file_size_bytes = data_size_bytes + byte_offset + + # Create an empty file with file_size_bytes + with open(file_path, "wb+") as file: + # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) + file.seek(file_size_bytes - 1) + file.write(b"\0") + + assert Path(file_path).is_file() + + # use executor (loop or workers) + func = _write_binary_chunk + init_func = _init_binary_worker + init_args = (chunkable, file_path_dict, dtype, byte_offset) + executor = ChunkExecutor( + chunkable, func, init_func, init_args, job_name="write_binary", verbose=verbose, **job_kwargs + ) + executor.run() + + +# used by write_binary + ChunkExecutor +def _init_binary_worker(chunkable, file_path_dict, dtype, byte_offset): + # create a local dict per worker + worker_ctx = {} + worker_ctx["chunkable"] = chunkable + worker_ctx["byte_offset"] = byte_offset + worker_ctx["dtype"] = np.dtype(dtype) + + file_dict = {segment_index: open(file_path, "rb+") for segment_index, 
file_path in file_path_dict.items()} + worker_ctx["file_dict"] = file_dict + + return worker_ctx + + +# used by write_binary + ChunkExecutor +def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + chunkable = worker_ctx["chunkable"] + dtype = worker_ctx["dtype"] + byte_offset = worker_ctx["byte_offset"] + file = worker_ctx["file_dict"][segment_index] + + sample_size_bytes = chunkable.get_sample_size_in_bytes() + + # Calculate byte offsets for the start frames relative to the entire recording + start_byte = byte_offset + start_frame * sample_size_bytes + + traces = chunkable.get_data(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) + traces = traces.astype(dtype, order="c", copy=False) + + file.seek(start_byte) + file.write(traces.data) + # flush is important!! + file.flush() + + +write_binary.__doc__ = write_binary.__doc__.format(_shared_job_kwargs_doc) + + +# used by write_memory +def _init_memory_worker(chunkable, arrays, shm_names, shapes, dtype): + # create a local dict per worker + worker_ctx = {} + worker_ctx["chunkable"] = chunkable + worker_ctx["dtype"] = np.dtype(dtype) + + if arrays is None: + # create it from share memory name + from multiprocessing.shared_memory import SharedMemory + + arrays = [] + # keep shm alive + worker_ctx["shms"] = [] + for i in range(len(shm_names)): + shm = SharedMemory(shm_names[i]) + worker_ctx["shms"].append(shm) + arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf) + arrays.append(arr) + + worker_ctx["arrays"] = arrays + + return worker_ctx + + +# used by write_memory +def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + chunkable = worker_ctx["chunkable"] + dtype = worker_ctx["dtype"] + arr = worker_ctx["arrays"][segment_index] + + # apply function + traces = chunkable.get_data(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) + traces = 
traces.astype(dtype, copy=False) + arr[start_frame:end_frame, :] = traces + + +def write_memory(chunkable, dtype=None, verbose=False, buffer_type="auto", job_name="write_memory", **job_kwargs): + """ + Save the traces into numpy arrays (memory). + try to use the SharedMemory introduce in py3.8 if n_jobs > 1 + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to be saved to memory + dtype : dtype, default: None + Type of the saved data + verbose : bool, default: False + If True, output is verbose (when chunks are used) + buffer_type : "auto" | "numpy" | "sharedmem", + The type of buffer to use for storing the data. + job_name : str, default: "write_memory" + Name of the job + {} + + Returns + --------- + arrays : one array per segment + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + if dtype is None: + dtype = chunkable.get_dtype() + + # create sharedmmep + arrays = [] + shm_names = [] + shms = [] + shapes = [] + + n_jobs = ensure_n_jobs(chunkable, n_jobs=job_kwargs.get("n_jobs", 1)) + if buffer_type == "auto": + if n_jobs > 1: + buffer_type = "sharedmem" + else: + buffer_type = "numpy" + + for segment_index in range(chunkable.get_num_segments()): + shape = chunkable.get_shape(segment_index=segment_index) + shapes.append(shape) + if buffer_type == "sharedmem": + arr, shm = make_shared_array(shape, dtype) + shm_names.append(shm.name) + shms.append(shm) + else: + arr = np.zeros(shape, dtype=dtype) + shms.append(None) + arrays.append(arr) + + # use executor (loop or workers) + func = _write_memory_chunk + init_func = _init_memory_worker + if n_jobs > 1: + init_args = (chunkable, None, shm_names, shapes, dtype) + else: + init_args = (chunkable, arrays, None, None, dtype) + + executor = ChunkExecutor(chunkable, func, init_func, init_args, verbose=verbose, job_name=job_name, **job_kwargs) + executor.run() + + return arrays, shms + + +write_memory.__doc__ = write_memory.__doc__.format(_shared_job_kwargs_doc) + + +def get_random_sample_slices( 
+ chunkable: "ChunkableMixin", + method="full_random", + num_chunks_per_segment=20, + chunk_duration="500ms", + chunk_size=None, + margin_frames=0, + seed=None, +): + """ + Get random slice of a chunkable object across segments. + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to get random chunks from + method : "full_random" + The method used to get random slices. + * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices + and they can overlap. + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames. This is used only if chunk_duration is None. + This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. + concatenated : bool, default: True + If True chunk are concatenated along time axis + seed : int, default: None + Random seed + margin_frames : int, default: 0 + Margin in number of frames to avoid edge effects + + Returns + ------- + chunk_list : np.array + Array of concatenate chunks per segment + + + """ + # TODO: if segment have differents length make another sampling that dependant on the length of the segment + # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY + # And randomize the number of chunk per segment weighted by segment duration + + if method == "full_random": + if chunk_size is None: + if chunk_duration is not None: + chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable) + else: + raise ValueError("get_random_sample_slices need chunk_size or chunk_duration") + + # check chunk size + num_segments = chunkable.get_num_segments() + for segment_index in range(num_segments): + chunk_size_limit = chunkable.get_num_samples(segment_index) - 2 * margin_frames + if chunk_size > chunk_size_limit: + chunk_size = chunk_size_limit - 
1 + warnings.warn( + f"chunk_size is greater than the number " + f"of samples for segment index {segment_index}. " + f"Using {chunk_size}." + ) + rng = np.random.default_rng(seed) + slices = [] + low = margin_frames + size = num_chunks_per_segment + for segment_index in range(num_segments): + num_frames = chunkable.get_num_samples(segment_index) + high = num_frames - chunk_size - margin_frames + # here we set endpoint to True, because the this represents the start of the + # chunk, and should be inclusive + random_starts = rng.integers(low=low, high=high, size=size, endpoint=True) + random_starts = np.sort(random_starts) + slices += [(segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts] + else: + raise ValueError(f"get_random_sample_slices : wrong method {method}") + + return slices + + +def get_chunks(chunkable: "ChunkableMixin", concatenated=True, get_data_kwargs=None, **random_slices_kwargs): + """ + Extract random chunks across segments. + + Internally, it uses `get_random_sample_slices()` and retrieves the traces chunk as a list + or a concatenated unique array. + + Please read `get_random_sample_slices()` for more details on parameters. + + # TODO: handle this in recording tools: + return * will be get_data_kwargs + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to get random chunks from + return_scaled : bool | None, default: None + DEPRECATED. Use return_in_uV instead. + return_in_uV : bool, default: False + If True and the chunkable has scaling (gain_to_uV and offset_to_uV properties), + traces are scaled to uV + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + concatenated : bool, default: True + If True chunk are concatenated along time axis + **random_slices_kwargs : dict + Options transmited to get_random_sample_slices(), please read documentation from this + function for more details. 
+ + Returns + ------- + chunk_list : np.array | list of np.array + Array of concatenate chunks per segment + """ + slices = get_random_sample_slices(chunkable, **random_slices_kwargs) + + chunk_list = [] + get_data_kwargs = get_data_kwargs if get_data_kwargs is not None else {} + for segment_index, start_frame, end_frame in slices: + traces_chunk = chunkable.get_data( + start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, **get_data_kwargs + ) + chunk_list.append(traces_chunk) + + if concatenated: + return np.concatenate(chunk_list, axis=0) + else: + return chunk_list diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index db29a4b3cd..7441615e9c 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -48,9 +48,9 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None): ) # link recording segment - parent_segment = parent_recording._recording_segments[0] + parent_segment = parent_recording.segments[0] sub_segment = FrameSliceRecordingSegment(parent_segment, start_frame=int(start_frame), end_frame=int(end_frame)) - self.add_recording_segment(sub_segment) + self.add_segment(sub_segment) # copy properties and annotations parent_recording.copy_metadata(self) diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 629db2f36d..f32e21be5e 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -76,9 +76,9 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike BaseSorting.__init__(self, sampling_frequency=parent_sorting.get_sampling_frequency(), unit_ids=unit_ids) # link sorting segment - parent_segment = parent_sorting._sorting_segments[0] + parent_segment = parent_sorting.segments[0] sub_segment = FrameSliceSortingSegment(parent_segment, start_frame, end_frame) - 
self.add_sorting_segment(sub_segment) + self.add_segment(sub_segment) # copy properties and annotations parent_sorting.copy_metadata(self) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index db6dff3f2d..c3c8e3b9c4 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -393,7 +393,7 @@ def __init__( # We need to add the sorting segments for segment_index in range(sorting.get_num_segments()): segment = SpikeVectorSortingSegment(self._cached_spike_vector, segment_index, unit_ids=self.unit_ids) - self.add_sorting_segment(segment) + self.add_segment(segment) if self.refractory_period_ms is not None: self.clean_refractory_period() @@ -1141,7 +1141,7 @@ def __init__( unit_ids=unit_ids, t_start=None, ) - self.add_sorting_segment(segment) + self.add_segment(segment) self._kwargs = { "num_units": num_units, @@ -1327,7 +1327,7 @@ def __init__( segments_seeds[i], strategy, ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = { "num_channels": num_channels, @@ -1994,9 +1994,7 @@ def __init__( amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None - parent_recording_segment = ( - None if parent_recording is None else parent_recording._recording_segments[segment_index] - ) + parent_recording_segment = None if parent_recording is None else parent_recording.segments[segment_index] recording_segment = InjectTemplatesRecordingSegment( self.sampling_frequency, self.dtype, @@ -2008,7 +2006,7 @@ def __init__( parent_recording_segment, num_samples[segment_index], ) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) # to discuss: maybe we could set json serializability to False always # because templates could be large! 
diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c63d8d590e..a7e4e25e2d 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -7,7 +7,6 @@ import platform import os import warnings -from spikeinterface.core.core_tools import convert_string_to_bytes, convert_bytes_to_str, convert_seconds_to_str import sys from tqdm.auto import tqdm @@ -17,6 +16,8 @@ import threading from threadpoolctl import threadpool_limits +from spikeinterface.core.core_tools import convert_string_to_bytes, convert_bytes_to_str, convert_seconds_to_str + _shared_job_kwargs_doc = """**job_kwargs : keyword arguments for parallel processing: * chunk_duration or chunk_size or chunk_memory or total_memory @@ -217,16 +218,16 @@ def divide_segment_into_chunks(num_frames, chunk_size): return chunks -def divide_recording_into_chunks(recording, chunk_size): - recording_slices = [] +def divide_extractor_into_chunks(recording, chunk_size): + slices = [] for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) chunks = divide_segment_into_chunks(num_frames, chunk_size) - recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) - return recording_slices + slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + return slices -def ensure_n_jobs(recording, n_jobs=1): +def ensure_n_jobs(extractor, n_jobs=1): if n_jobs == -1: n_jobs = os.cpu_count() elif n_jobs == 0: @@ -244,19 +245,19 @@ def ensure_n_jobs(recording, n_jobs=1): print(f"Python {sys.version} does not support parallel processing") n_jobs = 1 - if not recording.check_if_memory_serializable(): + if not extractor.check_if_memory_serializable(): if n_jobs != 1: raise RuntimeError( - "Recording is not serializable to memory and can't be processed in parallel. 
" + "Extractor is not serializable to memory and can't be processed in parallel. " "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1." ) return n_jobs -def chunk_duration_to_chunk_size(chunk_duration, recording): +def chunk_duration_to_chunk_size(chunk_duration, chunkable: "ChunkableMixin"): if isinstance(chunk_duration, float): - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + chunk_size = int(chunk_duration * chunkable.get_sampling_frequency()) elif isinstance(chunk_duration, str): if chunk_duration.endswith("ms"): chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 @@ -264,17 +265,23 @@ def chunk_duration_to_chunk_size(chunk_duration, recording): chunk_duration = float(chunk_duration.replace("s", "")) else: raise ValueError("chunk_duration must ends with s or ms") - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + chunk_size = int(chunk_duration * chunkable.get_sampling_frequency()) else: raise ValueError("chunk_duration must be str or float") return chunk_size def ensure_chunk_size( - recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs + chunkable: "ChunkableMixin", + total_memory=None, + chunk_size=None, + chunk_memory=None, + chunk_duration=None, + n_jobs=1, + **other_kwargs, ): """ - "chunk_size" is the traces.shape[0] for each worker. + "chunk_size" is the number of samples for each worker. Flexible chunk_size setter with 3 ways: * "chunk_size" : is the length in sample for each chunk independently of channel count and dtype. 
@@ -305,24 +312,20 @@ def ensure_chunk_size( assert total_memory is None # set by memory per worker size chunk_memory = convert_string_to_bytes(chunk_memory) - n_bytes = np.dtype(recording.get_dtype()).itemsize - num_channels = recording.get_num_channels() - chunk_size = int(chunk_memory / (num_channels * n_bytes)) + chunk_size = int(chunk_memory / chunkable.get_sample_size_in_bytes()) elif total_memory is not None: # clip by total memory size - n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs) + n_jobs = ensure_n_jobs(chunkable, n_jobs=n_jobs) total_memory = convert_string_to_bytes(total_memory) - n_bytes = np.dtype(recording.get_dtype()).itemsize - num_channels = recording.get_num_channels() - chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) + chunk_size = int(total_memory / (chunkable.get_sample_size_in_bytes() * n_jobs)) elif chunk_duration is not None: - chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment if n_jobs == 1: - num_segments = recording.get_num_segments() - samples_in_larger_segment = max([recording.get_num_samples(segment) for segment in range(num_segments)]) + num_segments = chunkable.get_num_segments() + samples_in_larger_segment = max([chunkable.get_num_samples(segment) for segment in range(num_segments)]) chunk_size = samples_in_larger_segment else: raise ValueError("For n_jobs >1 you must specify total_memory or chunk_size or chunk_memory") @@ -330,9 +333,9 @@ def ensure_chunk_size( return chunk_size -class ChunkRecordingExecutor: +class ChunkExecutor: """ - Core class for parallel processing to run a "function" over chunks on a recording. + Core class for parallel processing to run a "function" over chunks on a chunkable extractor. 
It supports running a function: * in loop with chunk processing (low RAM usage) @@ -344,8 +347,8 @@ class ChunkRecordingExecutor: Parameters ---------- - recording : RecordingExtractor - The recording to be processed + chunkable : ChunkableMixin + The chunkable object to be processed. func : function Function that runs on each chunk init_func : function @@ -393,7 +396,7 @@ class ChunkRecordingExecutor: def __init__( self, - recording, + chunkable: "ChunkableMixin", func, init_func, init_args, @@ -412,14 +415,15 @@ def __init__( max_threads_per_worker=1, need_worker_index=False, ): - self.recording = recording + self.chunkable = chunkable self.func = func self.init_func = init_func self.init_args = init_args if pool_engine == "process": if mp_context is None: - mp_context = recording.get_preferred_mp_context() + if hasattr(chunkable, "get_preferred_mp_context"): + mp_context = chunkable.get_preferred_mp_context() if mp_context is not None and platform.system() == "Windows": assert mp_context != "fork", "'fork' mp_context not supported on Windows!" 
elif mp_context == "fork" and platform.system() == "Darwin": @@ -433,9 +437,8 @@ def __init__( self.handle_returns = handle_returns self.gather_func = gather_func - self.n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs) - self.chunk_size = ensure_chunk_size( - recording, + self.n_jobs = ensure_n_jobs(self.chunkable, n_jobs=n_jobs) + self.chunk_size = self.ensure_chunk_size( total_memory=total_memory, chunk_size=chunk_size, chunk_memory=chunk_memory, @@ -450,9 +453,9 @@ def __init__( self.need_worker_index = need_worker_index if verbose: - chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize + chunk_memory = self.get_chunk_memory() total_memory = chunk_memory * self.n_jobs - chunk_duration = self.chunk_size / recording.get_sampling_frequency() + chunk_duration = self.chunk_size / chunkable.sampling_frequency chunk_memory_str = convert_bytes_to_str(chunk_memory) total_memory_str = convert_bytes_to_str(total_memory) chunk_duration_str = convert_seconds_to_str(chunk_duration) @@ -467,13 +470,24 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self, recording_slices=None): + def get_chunk_memory(self): + return self.chunk_size * self.chunkable.get_sample_size_in_bytes() + + def ensure_chunk_size( + self, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs + ): + return ensure_chunk_size( + self.chunkable, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs + ) + + def run(self, slices=None): """ Runs the defined jobs. 
""" - if recording_slices is None: - recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size) + if slices is None: + # TODO: rename + slices = divide_extractor_into_chunks(self.chunkable, self.chunk_size) if self.handle_returns: returns = [] @@ -482,15 +496,13 @@ def run(self, recording_slices=None): if self.n_jobs == 1: if self.progress_bar: - recording_slices = tqdm( - recording_slices, desc=f"{self.job_name} (no parallelization)", total=len(recording_slices) - ) + slices = tqdm(slices, desc=f"{self.job_name} (no parallelization)", total=len(slices)) worker_dict = self.init_func(*self.init_args) if self.need_worker_index: worker_dict["worker_index"] = 0 - for segment_index, frame_start, frame_stop in recording_slices: + for segment_index, frame_start, frame_stop in slices: res = self.func(segment_index, frame_start, frame_stop, worker_dict) if self.handle_returns: returns.append(res) @@ -498,7 +510,7 @@ def run(self, recording_slices=None): self.gather_func(res) else: - n_jobs = min(self.n_jobs, len(recording_slices)) + n_jobs = min(self.n_jobs, len(slices)) if self.pool_engine == "process": @@ -526,11 +538,11 @@ def run(self, recording_slices=None): array_pid, ), ) as executor: - results = executor.map(process_function_wrapper, recording_slices) + results = executor.map(process_function_wrapper, slices) if self.progress_bar: results = tqdm( - results, desc=f"{self.job_name} (workers: {n_jobs} processes)", total=len(recording_slices) + results, desc=f"{self.job_name} (workers: {n_jobs} processes)", total=len(slices) ) for res in results: @@ -549,7 +561,7 @@ def run(self, recording_slices=None): if self.progress_bar: # here the tqdm threading do not work (maybe collision) so we need to create a pbar # before thread spawning - pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(recording_slices)) + pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(slices)) if self.need_worker_index: lock = 
threading.Lock() @@ -570,8 +582,8 @@ def run(self, recording_slices=None): ), ) as executor: - recording_slices2 = [(thread_local_data,) + tuple(args) for args in recording_slices] - results = executor.map(thread_function_wrapper, recording_slices2) + slices2 = [(thread_local_data,) + tuple(args) for args in slices] + results = executor.map(thread_function_wrapper, slices2) for res in results: if self.progress_bar: diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1cec886d95..77d01e48d3 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -11,7 +11,7 @@ import numpy as np from spikeinterface.core import BaseRecording, get_chunk_with_margin -from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc +from spikeinterface.core.job_tools import ChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc from spikeinterface.core import get_channel_distances @@ -533,7 +533,7 @@ def run_node_pipeline( names=None, verbose=False, skip_after_n_peaks=None, - recording_slices=None, + slices=None, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -585,7 +585,7 @@ def run_node_pipeline( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. - recording_slices : None | list[tuple] + slices : None | list[tuple] Optionaly give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). If None (default), the function iterates over the entire duration of the recording. 
@@ -616,7 +616,7 @@ def run_node_pipeline( init_args = (recording, nodes, skip_after_n_peaks_per_worker) - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, _compute_peak_pipeline_chunk, _init_peak_pipeline, @@ -627,7 +627,7 @@ def run_node_pipeline( **job_kwargs, ) - processor.run(recording_slices=recording_slices) + processor.run(slices=slices) outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) return outs @@ -650,7 +650,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c nodes = worker_ctx["nodes"] skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] - recording_segment = recording._recording_segments[segment_index] + recording_segment = recording.segments[segment_index] retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever)) # get peak slices once for all retrievers peak_slice_by_retriever = {} diff --git a/src/spikeinterface/core/npzsortingextractor.py b/src/spikeinterface/core/npzsortingextractor.py index b8e7357e8c..45b2042792 100644 --- a/src/spikeinterface/core/npzsortingextractor.py +++ b/src/spikeinterface/core/npzsortingextractor.py @@ -30,7 +30,7 @@ def __init__(self, file_path): spike_indexes = npz[f"spike_indexes_seg{seg_index}"] spike_labels = npz[f"spike_labels_seg{seg_index}"] sorting_segment = NpzSortingSegment(spike_indexes, spike_labels) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) self._kwargs = {"file_path": str(Path(file_path).absolute())} diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 019759797b..12c0907483 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -14,7 +14,7 @@ ) from .basesorting import minimum_spike_dtype from .core_tools import make_shared_array -from .recording_tools import write_memory_recording +from .chunkable_tools import write_memory from 
multiprocessing.shared_memory import SharedMemory @@ -73,7 +73,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N else: t_start = t_starts[i] rec_segment = NumpyRecordingSegment(traces, sampling_frequency, t_start) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = { "traces_list": traces_list, @@ -83,7 +83,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N @staticmethod def from_recording(source_recording, **job_kwargs): - traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs) + traces_list, shms = write_memory(source_recording, dtype=None, **job_kwargs) t_starts = source_recording._get_t_starts() @@ -187,7 +187,7 @@ def __init__( t_start = t_starts[i] rec_segment = NumpyRecordingSegment(traces, sampling_frequency, t_start) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = { "shm_names": shm_names, @@ -201,7 +201,7 @@ def __init__( } def __del__(self): - self._recording_segments = [] + self.segments = [] for shm in self.shms: shm.close() if self.main_shm_owner: @@ -209,7 +209,7 @@ def __del__(self): @staticmethod def from_recording(source_recording, **job_kwargs): - traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs) + traces_list, shms = write_memory(source_recording, buffer_type="sharedmem", **job_kwargs) t_starts = source_recording._get_t_starts() @@ -267,7 +267,7 @@ def __init__(self, spikes, sampling_frequency, unit_ids): nseg = spikes[-1]["segment_index"] + 1 for segment_index in range(nseg): - self.add_sorting_segment(SpikeVectorSortingSegment(spikes, segment_index, unit_ids)) + self.add_segment(SpikeVectorSortingSegment(spikes, segment_index, unit_ids)) # important trick : the cache is already spikes vector self._cached_spike_vector = spikes @@ -519,7 +519,7 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, 
dtype=minimum_ nseg = self.shm_spikes[-1]["segment_index"] + 1 for segment_index in range(nseg): - self.add_sorting_segment(SpikeVectorSortingSegment(self.shm_spikes, segment_index, unit_ids)) + self.add_segment(SpikeVectorSortingSegment(self.shm_spikes, segment_index, unit_ids)) # important trick : the cache is already spikes vector self._cached_spike_vector = self.shm_spikes diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index 53ef736208..a3406388e4 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -190,7 +190,7 @@ def __init__(self, oldapi_recording_extractor): # add old recording as a recording segment recording_segment = OldToNewRecordingSegment(oldapi_recording_extractor) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) self.set_channel_locations(oldapi_recording_extractor.get_channel_locations()) # add old properties @@ -267,7 +267,7 @@ def __init__(self, oldapi_sorting_extractor): ) sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) self._serializability["memory"] = False self._serializability["json"] = False diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index fd95b11e6a..d808bc69d6 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -4,7 +4,6 @@ import warnings from pathlib import Path import os -import mmap import tqdm @@ -13,15 +12,19 @@ from .core_tools import add_suffix, make_shared_array from .job_tools import ( ensure_chunk_size, - ensure_n_jobs, divide_segment_into_chunks, fix_job_kwargs, - ChunkRecordingExecutor, + ChunkExecutor, _shared_job_kwargs_doc, - chunk_duration_to_chunk_size, split_job_kwargs, ) +from .chunkable_tools import get_random_sample_slices, get_chunks + +# for back-compatibility imports 
+from .chunkable_tools import write_binary as write_binary_recording +from .chunkable_tools import write_memory as write_memory_recording + def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): """ @@ -53,124 +56,11 @@ def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): return samples -# used by write_binary_recording + ChunkRecordingExecutor -def _init_binary_worker(recording, file_path_dict, dtype, byte_offest): - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["byte_offset"] = byte_offest - worker_ctx["dtype"] = np.dtype(dtype) - - file_dict = {segment_index: open(file_path, "rb+") for segment_index, file_path in file_path_dict.items()} - worker_ctx["file_dict"] = file_dict - - return worker_ctx - - -def write_binary_recording( - recording: "BaseRecording", - file_paths: list[Path | str] | Path | str, - dtype: np.typing.DTypeLike = None, - add_file_extension: bool = True, - byte_offset: int = 0, - verbose: bool = False, - **job_kwargs, -): - """ - Save the trace of a recording extractor in several binary .dat format. - - Note : - time_axis is always 0 (contrary to previous version. - to get time_axis=1 (which is a bad idea) use `write_binary_recording_file_handle()` - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object to be saved in .dat format - file_path : str or list[str] - The path to the file. - dtype : dtype or None, default: None - Type of the saved data - add_file_extension, bool, default: True - If True, and the file path does not end in "raw", "bin", or "dat" then "raw" is added as an extension. - byte_offset : int, default: 0 - Offset in bytes for the binary file (e.g. to write a header). This is useful in case you want to append data - to an existing file where you wrote a header or other data before. 
- verbose : bool - This is the verbosity of the ChunkRecordingExecutor - {} - """ - job_kwargs = fix_job_kwargs(job_kwargs) - - file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths - num_segments = recording.get_num_segments() - if len(file_path_list) != num_segments: - raise ValueError("'file_paths' must be a list of the same size as the number of segments in the recording") - - file_path_list = [Path(file_path) for file_path in file_path_list] - if add_file_extension: - file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list] - - dtype = dtype if dtype is not None else recording.get_dtype() - - dtype_size_bytes = np.dtype(dtype).itemsize - num_channels = recording.get_num_channels() - - file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)} - for segment_index, file_path in file_path_dict.items(): - num_frames = recording.get_num_frames(segment_index=segment_index) - data_size_bytes = dtype_size_bytes * num_frames * num_channels - file_size_bytes = data_size_bytes + byte_offset - - # Create an empty file with file_size_bytes - with open(file_path, "wb+") as file: - # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) - file.seek(file_size_bytes - 1) - file.write(b"\0") - - assert Path(file_path).is_file() - - # use executor (loop or workers) - func = _write_binary_chunk - init_func = _init_binary_worker - init_args = (recording, file_path_dict, dtype, byte_offset) - executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name="write_binary_recording", verbose=verbose, **job_kwargs - ) - executor.run() - - -# used by write_binary_recording + ChunkRecordingExecutor -def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - byte_offset = worker_ctx["byte_offset"] - file 
= worker_ctx["file_dict"][segment_index] - - num_channels = recording.get_num_channels() - dtype_size_bytes = np.dtype(dtype).itemsize - - # Calculate byte offsets for the start frames relative to the entire recording - start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes - - traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - traces = traces.astype(dtype, order="c", copy=False) - - file.seek(start_byte) - file.write(traces.data) - # flush is important!! - file.flush() - - -write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc) - - -def write_binary_recording_file_handle( +def write_binary_file_handle( recording, file_handle=None, time_axis=0, dtype=None, byte_offset=0, verbose=False, **job_kwargs ): """ - Old variant version of write_binary_recording with one file handle. + Old variant version of write_binary with one file handle. Can be useful in some case ??? Not used anymore at the moment. 
@@ -210,121 +100,6 @@ def write_binary_recording_file_handle( file_handle.write(traces.tobytes()) -# used by write_memory_recording -def _init_memory_worker(recording, arrays, shm_names, shapes, dtype): - # create a local dict per worker - worker_ctx = {} - if isinstance(recording, dict): - from spikeinterface.core import load - - worker_ctx["recording"] = load(recording) - else: - worker_ctx["recording"] = recording - - worker_ctx["dtype"] = np.dtype(dtype) - - if arrays is None: - # create it from share memory name - from multiprocessing.shared_memory import SharedMemory - - arrays = [] - # keep shm alive - worker_ctx["shms"] = [] - for i in range(len(shm_names)): - shm = SharedMemory(shm_names[i]) - worker_ctx["shms"].append(shm) - arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf) - arrays.append(arr) - - worker_ctx["arrays"] = arrays - - return worker_ctx - - -# used by write_memory_recording -def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - arr = worker_ctx["arrays"][segment_index] - - # apply function - traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - traces = traces.astype(dtype, copy=False) - arr[start_frame:end_frame, :] = traces - - -def write_memory_recording(recording, dtype=None, verbose=False, buffer_type="auto", **job_kwargs): - """ - Save the traces into numpy arrays (memory). 
- try to use the SharedMemory introduce in py3.8 if n_jobs > 1 - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object to be saved in .dat format - dtype : dtype, default: None - Type of the saved data - verbose : bool, default: False - If True, output is verbose (when chunks are used) - buffer_type : "auto" | "numpy" | "sharedmem" - {} - - Returns - --------- - arrays : one array per segment - """ - job_kwargs = fix_job_kwargs(job_kwargs) - - if dtype is None: - dtype = recording.get_dtype() - - # create sharedmmep - arrays = [] - shm_names = [] - shms = [] - shapes = [] - - n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) - if buffer_type == "auto": - if n_jobs > 1: - buffer_type = "sharedmem" - else: - buffer_type = "numpy" - - for segment_index in range(recording.get_num_segments()): - num_frames = recording.get_num_samples(segment_index) - num_channels = recording.get_num_channels() - shape = (num_frames, num_channels) - shapes.append(shape) - if buffer_type == "sharedmem": - arr, shm = make_shared_array(shape, dtype) - shm_names.append(shm.name) - shms.append(shm) - else: - arr = np.zeros(shape, dtype=dtype) - shms.append(None) - arrays.append(arr) - - # use executor (loop or workers) - func = _write_memory_chunk - init_func = _init_memory_worker - if n_jobs > 1: - init_args = (recording, None, shm_names, shapes, dtype) - else: - init_args = (recording, arrays, None, None, dtype) - - executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, verbose=verbose, job_name="write_memory_recording", **job_kwargs - ) - executor.run() - - return arrays, shms - - -write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc) - - def write_to_h5_dataset_format( recording, dataset_path, @@ -465,91 +240,6 @@ def write_to_h5_dataset_format( return save_path -def get_random_recording_slices( - recording, - method="full_random", - num_chunks_per_segment=20, - 
chunk_duration="500ms", - chunk_size=None, - margin_frames=0, - seed=None, -): - """ - Get random slice of a recording across segments. - - This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. - - Parameters - ---------- - recording : BaseRecording - The recording to get random chunks from - method : "full_random" - The method used to get random slices. - * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices - and they can overlap. - num_chunks_per_segment : int, default: 20 - Number of chunks per segment - chunk_duration : str | float | None, default "500ms" - The duration of each chunk in 's' or 'ms' - chunk_size : int | None - Size of a chunk in number of frames. This is used only if chunk_duration is None. - This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. - concatenated : bool, default: True - If True chunk are concatenated along time axis - seed : int, default: None - Random seed - margin_frames : int, default: 0 - Margin in number of frames to avoid edge effects - - Returns - ------- - chunk_list : np.array - Array of concatenate chunks per segment - - - """ - # TODO: if segment have differents length make another sampling that dependant on the length of the segment - # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY - # And randomize the number of chunk per segment weighted by segment duration - - if method == "full_random": - if chunk_size is None: - if chunk_duration is not None: - chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) - else: - raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") - - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = 
chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) - rng = np.random.default_rng(seed) - recording_slices = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - # here we set endpoint to True, because the this represents the start of the - # chunk, and should be inclusive - random_starts = rng.integers(low=low, high=high, size=size, endpoint=True) - random_starts = np.sort(random_starts) - recording_slices += [ - (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts - ] - else: - raise ValueError(f"get_random_recording_slices : wrong method {method}") - - return recording_slices - - def get_random_data_chunks( recording, return_scaled=None, return_in_uV=False, concatenated=True, **random_slices_kwargs ): @@ -593,22 +283,12 @@ def get_random_data_chunks( ) return_in_uV = return_scaled - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) - - chunk_list = [] - for segment_index, start_frame, end_frame in recording_slices: - traces_chunk = recording.get_traces( - start_frame=start_frame, - end_frame=end_frame, - segment_index=segment_index, - return_in_uV=return_in_uV, - ) - chunk_list.append(traces_chunk) - - if concatenated: - return np.concatenate(chunk_list, axis=0) - else: - return chunk_list + return get_chunks( + recording, + concatenated=concatenated, + get_data_kwargs=dict(return_in_uV=return_in_uV), + **random_slices_kwargs, + ) def get_channel_distances(recording): @@ -725,7 +405,7 @@ def get_noise_levels( force_recompute : bool If True, noise levels are recomputed even if they are already stored in the recording extractor random_slices_kwargs : dict - Options transmited to get_random_recording_slices(), please read documentation 
from this + Options transmitted to get_random_sample_slices(), please read documentation from this function for more details. {} @@ -760,7 +440,7 @@ def get_noise_levels( msg = ( "get_noise_levels(recording, num_chunks_per_segment=20) is deprecated\n" "Now, you need to use get_noise_levels(recording, random_slices_kwargs=dict(num_chunks_per_segment=20, chunk_size=1000))\n" - "Please read get_random_recording_slices() documentation for more options." + "Please read get_random_sample_slices() documentation for more options." ) # if the user use both the old and the new behavior then an error is raised assert len(random_slices_kwargs) == 0, msg @@ -769,7 +449,7 @@ def get_noise_levels( if "chunk_size" in job_kwargs: random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + slices = get_random_sample_slices(recording, **random_slices_kwargs) noise_levels_chunks = [] @@ -779,7 +459,7 @@ def append_noise_chunk(res): func = _noise_level_chunk init_func = _noise_level_chunk_init init_args = (recording, return_in_uV, method) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, @@ -789,7 +469,7 @@ def append_noise_chunk(res): gather_func=append_noise_chunk, **job_kwargs, ) - executor.run(recording_slices=recording_slices) + executor.run(slices=slices) noise_levels_chunks = np.stack(noise_levels_chunks) noise_levels = np.mean(noise_levels_chunks, axis=0) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index a5b7201352..45eec17eb2 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -64,9 +64,9 @@ def __init__(self, recording_list, sampling_frequency_max_diff=0): rec0.copy_metadata(self) for rec in recording_list: - for parent_segment in rec._recording_segments: + for parent_segment in rec.segments: rec_seg = ProxyAppendRecordingSegment(parent_segment) - 
self.add_recording_segment(rec_seg) + self.add_segment(rec_seg) self._kwargs = {"recording_list": recording_list, "sampling_frequency_max_diff": sampling_frequency_max_diff} @@ -83,7 +83,7 @@ def get_traces(self, *args, **kwargs): return self.parent_segment.get_traces(*args, **kwargs) -append_recordings = define_function_from_class(source_class=AppendSegmentRecording, name="append_segment_recording") +append_recordings = define_function_from_class(source_class=AppendSegmentRecording, name="append_segment_recording") class ConcatenateSegmentRecording(BaseRecording): @@ -120,7 +120,7 @@ def __init__(self, recording_list, ignore_times=True, sampling_frequency_max_dif parent_segments = [] for rec in recording_list: - for parent_segment in rec._recording_segments: + for parent_segment in rec.segments: time_kwargs = parent_segment.get_times_kwargs() if not ignore_times: assert time_kwargs["time_vector"] is None, ( @@ -135,7 +135,7 @@ def __init__(self, recording_list, ignore_times=True, sampling_frequency_max_dif rec_seg = ProxyConcatenateRecordingSegment( parent_segments, one_rec.get_sampling_frequency(), ignore_times=ignore_times ) - self.add_recording_segment(rec_seg) + self.add_segment(rec_seg) self._kwargs = { "recording_list": recording_list, @@ -241,8 +241,8 @@ def __init__(self, recording: BaseRecording, segment_indices: int | list[int]): ), f"'segment_index' must be between 0 and {num_segments - 1}" for segment_index in segment_indices: - rec_seg = recording._recording_segments[segment_index] - self.add_recording_segment(rec_seg) + rec_seg = recording.segments[segment_index] + self.add_segment(rec_seg) self._parent = recording self._kwargs = {"recording": recording, "segment_indices": segment_indices} @@ -303,9 +303,9 @@ def __init__(self, sorting_list, sampling_frequency_max_diff=0): sorting0.copy_metadata(self) for sorting in sorting_list: - for parent_segment in sorting._sorting_segments: + for parent_segment in sorting.segments: sorting_seg = 
ProxyAppendSortingSegment(parent_segment) - self.add_sorting_segment(sorting_seg) + self.add_segment(sorting_seg) self._kwargs = {"sorting_list": sorting_list, "sampling_frequency_max_diff": sampling_frequency_max_diff} @@ -385,7 +385,7 @@ def __init__(self, sorting_list, total_samples_list=None, ignore_times=True, sam parent_segments = [] parent_num_samples = [] for sorting_i, sorting in enumerate(sorting_list): - for segment_i, parent_segment in enumerate(sorting._sorting_segments): + for segment_i, parent_segment in enumerate(sorting.segments): # Check t_start is not assigned segment_t_start = parent_segment._t_start if not ignore_times: @@ -421,7 +421,7 @@ def __init__(self, sorting_list, total_samples_list=None, ignore_times=True, sam sorting_seg = ProxyConcatenateSortingSegment( parent_segments, parent_num_samples, one_sorting.get_sampling_frequency() ) - self.add_sorting_segment(sorting_seg) + self.add_segment(sorting_seg) # Assign concatenated recording if possible if all_has_recording: @@ -439,7 +439,7 @@ def __init__(self, sorting_list, total_samples_list=None, ignore_times=True, sam def get_num_samples(self, segment_index=None): """Overrides the BaseSorting method, which requires a recording.""" segment_index = self._check_segment_index(segment_index) - n_samples = self._sorting_segments[segment_index].get_num_samples() + n_samples = self.segments[segment_index].get_num_samples() if self.has_recording(): # Sanity check assert n_samples == self._recording.get_num_samples(segment_index) return n_samples @@ -555,7 +555,7 @@ def __init__(self, parent_sorting: BaseSorting, recording_or_recording_list=None num_samples = [0] for recording in recording_list: - for recording_segment in recording._recording_segments: + for recording_segment in recording.segments: num_samples.append(recording_segment.get_num_samples()) cumsum_num_samples = np.cumsum(num_samples) @@ -563,8 +563,8 @@ def __init__(self, parent_sorting: BaseSorting, recording_or_recording_list=None 
sliced_parent_sorting = parent_sorting.frame_slice( start_frame=cumsum_num_samples[idx], end_frame=cumsum_num_samples[idx + 1] ) - sliced_segment = sliced_parent_sorting._sorting_segments[0] - self.add_sorting_segment(sliced_segment) + sliced_segment = sliced_parent_sorting.segments[0] + self.add_segment(sliced_segment) self._parent = parent_sorting self._kwargs = {"parent_sorting": parent_sorting, "recording_or_recording_list": recording_list} @@ -598,8 +598,8 @@ def __init__(self, sorting: BaseSorting, segment_indices: int | list[int]): ), f"'segment_index' must be between 0 and {num_segments - 1}" for segment_index in segment_indices: - sort_seg = sorting._sorting_segments[segment_index] - self.add_sorting_segment(sort_seg) + sort_seg = sorting.segments[segment_index] + self.add_segment(sort_seg) self._kwargs = {"sorting": sorting, "segment_indices": [int(s) for s in segment_indices]} diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py index 026d6ea518..0da59f818e 100644 --- a/src/spikeinterface/core/sortingfolder.py +++ b/src/spikeinterface/core/sortingfolder.py @@ -40,7 +40,7 @@ def __init__(self, folder_path): self.spikes = np.load(folder_path / "spikes.npy") for segment_index in range(num_segments): - self.add_sorting_segment(SpikeVectorSortingSegment(self.spikes, segment_index, unit_ids)) + self.add_segment(SpikeVectorSortingSegment(self.spikes, segment_index, unit_ids)) # important trick : the cache is already spikes vector self._cached_spike_vector = self.spikes diff --git a/src/spikeinterface/core/tests/test_chunkable_tools.py b/src/spikeinterface/core/tests/test_chunkable_tools.py new file mode 100644 index 0000000000..3d686166e7 --- /dev/null +++ b/src/spikeinterface/core/tests/test_chunkable_tools.py @@ -0,0 +1,174 @@ +import numpy as np + +from spikeinterface.core import generate_recording + +from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor +from 
spikeinterface.core.generate import NoiseGeneratorRecording + + +from spikeinterface.core.chunkable_tools import ( + write_binary, + write_memory, + get_random_sample_slices, + get_chunks, +) + + +def test_write_binary(tmp_path): + # Test write_binary() with loop (n_jobs=1) + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + + durations = [10.0] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=1) + write_binary(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) + + +def test_write_binary_offset(tmp_path): + # Test write_binary() with loop (n_jobs=1) + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + + durations = [10.0] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=1) + byte_offset = 125 + write_binary(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, + file_offset=byte_offset, + ) + assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) + + +def test_write_binary_parallel(tmp_path): + # Test write_binary() 
with parallel processing (n_jobs=2) + + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + durations = [10.30, 3.5] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + dtype=dtype, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + for segment_index in range(recording.get_num_segments()): + binary_traces = recorder_binary.get_traces(segment_index=segment_index) + recording_traces = recording.get_traces(segment_index=segment_index) + assert np.allclose(binary_traces, recording_traces) + + +def test_write_binary_multiple_segment(tmp_path): + # Test write_binary() with multiple segments (n_jobs=2) + # Setup + sampling_frequency = 30_000 + num_channels = 10 + dtype = "float32" + + durations = [10.30, 3.5] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + + for segment_index in range(recording.get_num_segments()): + binary_traces = 
recorder_binary.get_traces(segment_index=segment_index) + recording_traces = recording.get_traces(segment_index=segment_index) + assert np.allclose(binary_traces, recording_traces) + + +def test_write_memory_recording(): + # 2 segments + recording = NoiseGeneratorRecording( + num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" + ) + recording = recording.save() + + # write with loop + traces_list, shms = write_memory(recording, dtype=None, verbose=True, n_jobs=1) + + traces_list, shms = write_memory( + recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True + ) + + # write parallel + traces_list, shms = write_memory(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") + # need to clean the buffer + del traces_list + for shm in shms: + shm.unlink() + + +def test_get_random_sample_slices(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + rec_slices = get_random_sample_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) + assert len(rec_slices) == 40 + for seg_ind, start, stop in rec_slices: + assert stop - start == 500 + assert seg_ind in (0, 1) + + +def test_get_chunks(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + chunks = get_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) + assert chunks.shape == (50000, 1) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index a4fdec0eff..f19be6329a 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -9,10 +9,10 @@ divide_segment_into_chunks, ensure_n_jobs, ensure_chunk_size, - ChunkRecordingExecutor, + ChunkExecutor, fix_job_kwargs, split_job_kwargs, - divide_recording_into_chunks, + divide_extractor_into_chunks, ) @@ -71,7 +71,7 @@ def 
test_ensure_chunk_size(): # Test edge case to define single chunk for n_jobs=1 chunk_size = ensure_chunk_size(recording, n_jobs=1, chunk_size=None) - chunks = divide_recording_into_chunks(recording, chunk_size) + chunks = divide_extractor_into_chunks(recording, chunk_size) assert len(chunks) == recording.get_num_segments() for chunk in chunks: segment_index, start_frame, end_frame = chunk @@ -96,13 +96,13 @@ def init_func(arg1, arg2, arg3): return worker_dict -def test_ChunkRecordingExecutor(): +def test_ChunkExecutor(): recording = generate_recording(num_channels=2) init_args = "a", 120, "yep" # no chunk - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, verbose=True, progress_bar=False, n_jobs=1, chunk_size=None ) processor.run() @@ -113,7 +113,7 @@ def gathering_result(res): pass # chunk + loop + gather_func - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -139,7 +139,7 @@ def __call__(self, res): gathering_func2 = GatherClass() # process + gather_func - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -153,12 +153,12 @@ def __call__(self, res): job_name="job_name", ) processor.run() - num_chunks = len(divide_recording_into_chunks(recording, processor.chunk_size)) + num_chunks = len(divide_extractor_into_chunks(recording, processor.chunk_size)) assert gathering_func2.pos == num_chunks # process spawn - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -174,7 +174,7 @@ def __call__(self, res): processor.run() # thread - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -256,7 +256,7 @@ def test_worker_index(): for i in range(2): # making this 2 times ensure to test that global variables are correctly reset for pool_engine in ("process", "thread"): - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, 
func2, init_func2, @@ -320,7 +320,7 @@ def test_get_best_job_kwargs(): # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() - # test_ChunkRecordingExecutor() + # test_ChunkExecutor() # test_fix_job_kwargs() # test_split_job_kwargs() # test_worker_index() diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 028eaecf12..303f6c4080 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -4,7 +4,7 @@ import shutil from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording -from spikeinterface.core.job_tools import divide_recording_into_chunks +from spikeinterface.core.job_tools import divide_extractor_into_chunks # from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( @@ -191,7 +191,7 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) -def test_skip_after_n_peaks_and_recording_slices(): +def test_skip_after_n_peaks_and_slices(): recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205) # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) @@ -220,11 +220,9 @@ def test_skip_after_n_peaks_and_recording_slices(): assert some_amplitudes.size < spikes.size # slices : 1 every 4 - recording_slices = divide_recording_into_chunks(recording, 10_000) - recording_slices = recording_slices[::4] - some_amplitudes = run_node_pipeline( - recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices - ) + slices = divide_extractor_into_chunks(recording, 10_000) + slices = slices[::4] + some_amplitudes = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory", slices=slices) tolerance = 1.2 assert some_amplitudes.size < (spikes.size // 4) 
* tolerance @@ -234,4 +232,4 @@ def test_skip_after_n_peaks_and_recording_slices(): # folder = Path("./cache_folder/core") # test_run_node_pipeline(folder) - test_skip_after_n_peaks_and_recording_slices() + test_skip_after_n_peaks_and_slices() diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 02798099ec..8f66125575 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -1,17 +1,10 @@ -from pathlib import Path -import platform import numpy as np from spikeinterface.core import NumpyRecording, generate_recording - -from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor from spikeinterface.core.generate import NoiseGeneratorRecording from spikeinterface.core.recording_tools import ( - write_binary_recording, - write_memory_recording, - get_random_recording_slices, get_random_data_chunks, get_chunk_with_margin, get_closest_channels, @@ -23,162 +16,6 @@ ) -def test_write_binary_recording(tmp_path): - # Test write_binary_recording() with loop (n_jobs=1) - # Setup - sampling_frequency = 30_000 - num_channels = 2 - dtype = "float32" - - durations = [10.0] - recording = NoiseGeneratorRecording( - durations=durations, - num_channels=num_channels, - sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", - ) - file_paths = [tmp_path / "binary01.raw"] - - # Write binary recording - job_kwargs = dict(n_jobs=1) - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) - - # Check if written data matches original data - recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype - ) - assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) - - -def test_write_binary_recording_offset(tmp_path): - # Test write_binary_recording() with loop 
(n_jobs=1) - # Setup - sampling_frequency = 30_000 - num_channels = 2 - dtype = "float32" - - durations = [10.0] - recording = NoiseGeneratorRecording( - durations=durations, - num_channels=num_channels, - sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", - ) - file_paths = [tmp_path / "binary01.raw"] - - # Write binary recording - job_kwargs = dict(n_jobs=1) - byte_offset = 125 - write_binary_recording( - recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs - ) - - # Check if written data matches original data - recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, - sampling_frequency=sampling_frequency, - num_channels=num_channels, - dtype=dtype, - file_offset=byte_offset, - ) - assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) - - -def test_write_binary_recording_parallel(tmp_path): - # Test write_binary_recording() with parallel processing (n_jobs=2) - - # Setup - sampling_frequency = 30_000 - num_channels = 2 - dtype = "float32" - durations = [10.30, 3.5] - recording = NoiseGeneratorRecording( - durations=durations, - num_channels=num_channels, - sampling_frequency=sampling_frequency, - dtype=dtype, - strategy="tile_pregenerated", - ) - file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] - - # Write binary recording - job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) - - # Check if written data matches original data - recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype - ) - for segment_index in range(recording.get_num_segments()): - binary_traces = recorder_binary.get_traces(segment_index=segment_index) - recording_traces = recording.get_traces(segment_index=segment_index) - assert np.allclose(binary_traces, recording_traces) - 
- -def test_write_binary_recording_multiple_segment(tmp_path): - # Test write_binary_recording() with multiple segments (n_jobs=2) - # Setup - sampling_frequency = 30_000 - num_channels = 10 - dtype = "float32" - - durations = [10.30, 3.5] - recording = NoiseGeneratorRecording( - durations=durations, - num_channels=num_channels, - sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", - ) - file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] - - # Write binary recording - job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) - - # Check if written data matches original data - recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype - ) - - for segment_index in range(recording.get_num_segments()): - binary_traces = recorder_binary.get_traces(segment_index=segment_index) - recording_traces = recording.get_traces(segment_index=segment_index) - assert np.allclose(binary_traces, recording_traces) - - -def test_write_memory_recording(): - # 2 segments - recording = NoiseGeneratorRecording( - num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" - ) - recording = recording.save() - - # write with loop - traces_list, shms = write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1) - - traces_list, shms = write_memory_recording( - recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True - ) - - # write parallel - traces_list, shms = write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") - # need to clean the buffer - del traces_list - for shm in shms: - shm.unlink() - - -def test_get_random_recording_slices(): - rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) - rec_slices = 
get_random_recording_slices( - rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 - ) - assert len(rec_slices) == 40 - for seg_ind, start, stop in rec_slices: - assert stop - start == 500 - assert seg_ind in (0, 1) - - def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -251,7 +88,7 @@ def test_get_noise_levels_output(): def test_get_chunk_with_margin(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0]) - rec_seg = rec._recording_segments[0] + rec_seg = rec.segments[0] length = rec_seg.get_num_samples() #  rec_segment, start_frame, end_frame, channel_indices, sample_margin @@ -358,17 +195,10 @@ def test_do_recording_attributes_match(): if __name__ == "__main__": - # Create a temporary folder using the standard library - # import tempfile - - # with tempfile.TemporaryDirectory() as tmpdirname: - # tmp_path = Path(tmpdirname) - # test_write_binary_recording(tmp_path) - # test_write_memory_recording() - - test_get_random_recording_slices() - # test_get_random_data_chunks() - # test_get_closest_channels() - # test_get_noise_levels() - # test_get_noise_levels_output() - # test_order_channels_by_depth() + test_get_random_data_chunks() + test_get_closest_channels() + test_get_noise_levels() + test_get_noise_levels_output() + test_get_chunk_with_margin() + test_order_channels_by_depth() + test_do_recording_attributes_match() diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index f22939c33c..e03096ce14 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -64,7 +64,7 @@ def _get_time_vector_recording(self, raw_recording): times_recording.set_times(times=time_vector, 
segment_index=segment_index) assert np.array_equal( - times_recording._recording_segments[segment_index].time_vector, + times_recording.segments[segment_index].time_vector, time_vector, ), "time_vector was not properly set during test setup" @@ -84,7 +84,7 @@ def _get_t_start_recording(self, raw_recording): t_start = (segment_index + 1) * 100 all_t_starts.append(t_start + t_start_recording.get_times(segment_index)) - t_start_recording._recording_segments[segment_index].t_start = t_start + t_start_recording.segments[segment_index].t_start = t_start return (raw_recording, t_start_recording, all_t_starts) @@ -442,6 +442,6 @@ def test_shift_times_with_None_as_t_start(): """Ensures we can shift times even when t_stat is None which is interpeted as zero""" recording = generate_recording(num_channels=4, durations=[10]) - assert recording._recording_segments[0].t_start is None + assert recording.segments[0].t_start is None recording.shift_times(shift=1.0) # Shift by one seconds should not generate an error assert recording.get_start_time() == 1.0 diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 404bae5924..20affd4f51 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -136,9 +136,9 @@ def __init__(self, sorting_list, renamed_unit_ids=None, sampling_frequency_max_d # add segments for i_seg in range(num_segments): - parent_segments = [sort._sorting_segments[i_seg] for sort in sorting_list] + parent_segments = [sort.segments[i_seg] for sort in sorting_list] sub_segment = UnitsAggregationSortingSegment(unit_map, parent_segments) - self.add_sorting_segment(sub_segment) + self.add_segment(sub_segment) self._sortings = sorting_list self._kwargs = {"sorting_list": sorting_list, "renamed_unit_ids": renamed_unit_ids} diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 
14454b2e20..9cbfdb3475 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -35,9 +35,9 @@ def __init__(self, parent_sorting, unit_ids=None, renamed_unit_ids=None): BaseSorting.__init__(self, sampling_frequency, self._renamed_unit_ids) - for parent_segment in self._parent_sorting._sorting_segments: + for parent_segment in self._parent_sorting.segments: sub_segment = UnitsSelectionSortingSegment(parent_segment, ids_conversion) - self.add_sorting_segment(sub_segment) + self.add_segment(sub_segment) parent_sorting.copy_metadata(self, only_main=False, ids=self._unit_ids) self._parent = parent_sorting diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 920de33c2f..81d79b3906 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -18,7 +18,7 @@ from spikeinterface.core.baserecording import BaseRecording from .baserecording import BaseRecording -from .job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc +from .job_tools import ChunkExecutor, _shared_job_kwargs_doc from .core_tools import make_shared_array from .job_tools import fix_job_kwargs @@ -295,16 +295,14 @@ def distribute_waveforms_to_buffers( ) if job_name is None: job_name = f"extract waveforms {mode} multi buffer" - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs - ) + processor = ChunkExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs) processor.run() distribute_waveforms_to_buffers.__doc__ = distribute_waveforms_to_buffers.__doc__.format(_shared_job_kwargs_doc) -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_in_uV, inds_by_unit, mode, sparsity_mask ): @@ -351,7 +349,7 @@ def 
_init_worker_distribute_buffers( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] @@ -555,7 +553,7 @@ def extract_waveforms_to_single_buffer( if job_name is None: job_name = f"extract waveforms {mode} mono buffer" - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs ) processor.run() @@ -609,7 +607,7 @@ def _init_worker_distribute_single_buffer( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] @@ -933,7 +931,7 @@ def estimate_templates_with_accumulator( if job_name is None: job_name = "estimate_templates_with_accumulator" - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, job_name=job_name, verbose=verbose, need_worker_index=True, **job_kwargs ) processor.run() @@ -1019,7 +1017,7 @@ def _init_worker_estimate_templates( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 162d67a458..5cb1281e5d 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -164,7 +164,7 @@ def __init__( time_kwargs["sampling_frequency"] = sampling_frequency rec_segment = ZarrRecordingSegment(self._root, trace_name, **time_kwargs) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) if load_compression_ratio: 
nbytes_segment = self._root[trace_name].nbytes @@ -297,7 +297,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, for segment_index in range(num_segments): soring_segment = SpikeVectorSortingSegment(spikes, segment_index, unit_ids) - self.add_sorting_segment(soring_segment) + self.add_segment(soring_segment) # load properties if "properties" in self._root: @@ -499,7 +499,7 @@ def add_recording_to_zarr_group( # save time vector if any t_starts = np.zeros(recording.get_num_segments(), dtype="float64") * np.nan - for segment_index, rs in enumerate(recording._recording_segments): + for segment_index, rs in enumerate(recording.segments): d = rs.get_times_kwargs() time_vector = d["time_vector"] @@ -559,7 +559,7 @@ def add_traces_to_zarr( from .job_tools import ( ensure_chunk_size, fix_job_kwargs, - ChunkRecordingExecutor, + ChunkExecutor, ) assert dataset_paths is not None, "Provide 'file_path'" @@ -596,13 +596,13 @@ def add_traces_to_zarr( func = _write_zarr_chunk init_func = _init_zarr_worker init_args = (recording, zarr_datasets, dtype) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, init_args, verbose=verbose, job_name="write_zarr_recording", **job_kwargs ) executor.run() -# used by write_zarr_recording + ChunkRecordingExecutor +# used by write_zarr_recording + ChunkExecutor def _init_zarr_worker(recording, zarr_datasets, dtype): import zarr @@ -615,7 +615,7 @@ def _init_zarr_worker(recording, zarr_datasets, dtype): return worker_ctx -# used by write_zarr_recording + ChunkRecordingExecutor +# used by write_zarr_recording + ChunkExecutor def _write_zarr_chunk(segment_index, start_frame, end_frame, worker_ctx): import gc diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index df5bb7446c..43d300138b 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -70,9 +70,9 
@@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy rm_dup_delta = None else: rm_dup_delta = int(delta_time_ms / 1000 * sampling_frequency) - for parent_segment in self._parent_sorting._sorting_segments: + for parent_segment in self._parent_sorting.segments: sub_segment = MergeUnitsSortingSegment(parent_segment, units_to_merge, new_unit_ids, rm_dup_delta) - self.add_sorting_segment(sub_segment) + self.add_segment(sub_segment) ann_keys = sorting._annotations.keys() self._annotations = deepcopy({k: sorting._annotations[k] for k in ann_keys}) diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index 508ac8c8cc..0b83e93c4a 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -39,8 +39,8 @@ def __init__(self, sorting: BaseSorting, censored_period_ms: float = 0.3, method censored_period = int(round(censored_period_ms * 1e-3 * sorting.get_sampling_frequency())) seed = np.random.randint(low=0, high=np.iinfo(np.int32).max) - for segment in sorting._sorting_segments: - self.add_sorting_segment( + for segment in sorting.segments: + self.add_segment( RemoveDuplicatedSpikesSortingSegment(segment, censored_period, sorting.unit_ids, method, seed) ) diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 8663b0fdbd..8656583f0c 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -34,11 +34,9 @@ def __init__(self, sorting: BaseSorting, recording: BaseRecording) -> None: self._parent_sorting = sorting self._num_samples = np.empty(sorting.get_num_segments(), dtype=np.int64) for segment_index in range(sorting.get_num_segments()): - sorting_segment = sorting._sorting_segments[segment_index] + sorting_segment = sorting.segments[segment_index] 
self._num_samples[segment_index] = recording.get_num_samples(segment_index=segment_index) - self.add_sorting_segment( - RemoveExcessSpikesSortingSegment(sorting_segment, self._num_samples[segment_index]) - ) + self.add_segment(RemoveExcessSpikesSortingSegment(sorting_segment, self._num_samples[segment_index])) sorting.copy_metadata(self, only_main=False) self._parent = sorting diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index a05bab1e2b..dacad855fc 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -81,9 +81,9 @@ def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, prop np.isin(unchanged_units, self.unit_ids) ), "new_unit_ids should have a compatible format with the parent ids" - for si, parent_segment in enumerate(self._parent_sorting._sorting_segments): + for si, parent_segment in enumerate(self._parent_sorting.segments): sub_segment = SplitSortingUnitSegment(parent_segment, split_unit_id, indices_zero_based[si], new_unit_ids) - self.add_sorting_segment(sub_segment) + self.add_segment(sub_segment) # copy properties ann_keys = sorting._annotations.keys() diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 6559e89d52..a7da917c2e 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -9,7 +9,7 @@ import numpy as np from spikeinterface.core import SortingAnalyzer, BaseRecording, get_random_data_chunks -from spikeinterface.core.job_tools import fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc +from spikeinterface.core.job_tools import fix_job_kwargs, ChunkExecutor, _shared_job_kwargs_doc from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.exporters import export_to_phy @@ -260,7 +260,7 @@ def compute_rms( func = _compute_rms_chunk init_func = _init_rms_worker 
init_args = (recording,) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index d3a823ce3f..30bae691f6 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -8,7 +8,7 @@ import warnings from spikeinterface.core import ( - write_binary_recording, + write_binary, BinaryRecordingExtractor, BinaryFolderRecording, ChannelSparsity, @@ -135,7 +135,7 @@ def export_to_phy( if sorting_analyzer.has_recording(): if copy_binary: rec_path = output_folder / "recording.dat" - write_binary_recording(sorting_analyzer.recording, file_paths=rec_path, dtype=dtype, **job_kwargs) + write_binary(sorting_analyzer.recording, file_paths=rec_path, dtype=dtype, **job_kwargs) elif isinstance(sorting_analyzer.recording, BinaryRecordingExtractor): if isinstance(sorting_analyzer.recording, BinaryFolderRecording): bin_kwargs = sorting_analyzer.recording._bin_kwargs diff --git a/src/spikeinterface/extractors/alfsortingextractor.py b/src/spikeinterface/extractors/alfsortingextractor.py index f7b5401182..da4037cf4d 100644 --- a/src/spikeinterface/extractors/alfsortingextractor.py +++ b/src/spikeinterface/extractors/alfsortingextractor.py @@ -40,7 +40,7 @@ def __init__(self, folder_path, sampling_frequency=30000): unit_ids = np.arange(total_units) # in alf format, spikes.clusters index directly into clusters BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency) sorting_segment = ALFSortingSegment(spikes["clusters"], spikes["samples"]) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) self.extra_requirements.append("ONE-api") self._kwargs = {"folder_path": str(Path(folder_path).resolve()), "sampling_frequency": sampling_frequency} diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index c70c49e8f8..2eebb81769 100644 
--- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -101,7 +101,7 @@ def __init__( self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=cbuffer.dtype ) recording_segment = CBinIblRecordingSegment(cbuffer, sampling_frequency, load_sync_channel) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) self.extra_requirements.append("mtscomp") diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 490ea61547..2690ab61a0 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -102,7 +102,7 @@ def __init__( BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency) sorting_segment = CellExplorerSortingSegment(spiketrains_dict, unit_ids) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) self.extra_requirements.append(["pymatreader"]) diff --git a/src/spikeinterface/extractors/combinatoextractors.py b/src/spikeinterface/extractors/combinatoextractors.py index 35fce3a8e3..ee9fa1ff2f 100644 --- a/src/spikeinterface/extractors/combinatoextractors.py +++ b/src/spikeinterface/extractors/combinatoextractors.py @@ -79,7 +79,7 @@ def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign unit_counter = unit_counter + 1 unit_ids = np.arange(unit_counter, dtype="int64") BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(CombinatoSortingSegment(spiketrains)) + self.add_segment(CombinatoSortingSegment(spiketrains)) self.set_property("unsorted", np.array([metadata[u]["group_type"] == 0 for u in range(unit_counter)])) self.set_property("artifact", np.array([metadata[u]["group_type"] == -1 for u in range(unit_counter)])) self._kwargs = {"folder_path": str(Path(folder_path).absolute()), "user": user, 
"det_sign": det_sign} diff --git a/src/spikeinterface/extractors/hdsortextractors.py b/src/spikeinterface/extractors/hdsortextractors.py index 86346957f8..e23b01f09f 100644 --- a/src/spikeinterface/extractors/hdsortextractors.py +++ b/src/spikeinterface/extractors/hdsortextractors.py @@ -91,7 +91,7 @@ def __init__(self, file_path, keep_good_only=True): BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(HDSortSortingSegment(unit_ids, spiketrains)) + self.add_segment(HDSortSortingSegment(unit_ids, spiketrains)) # property templates = [] diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index de4929218b..30b0b06791 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -48,7 +48,7 @@ def __init__(self, file_path, load_unit_info=True): self.load_unit_info() BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids)) + self.add_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids)) self._kwargs = {"file_path": str(Path(file_path).absolute()), "load_unit_info": load_unit_info} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 93ab0ce417..cc22729642 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -263,7 +263,7 @@ def __init__( # init recording segment recording_segment = IblRecordingSegment(file_streamer=self._file_streamer, load_sync_channel=load_sync_channel) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) self._kwargs = { "eid": eid, @@ -352,7 +352,7 @@ def __init__( unit_ids = clusters["cluster_id"][good_cluster_slice] BaseSorting.__init__(self, unit_ids=unit_ids, 
sampling_frequency=sr.fs) sorting_segment = ALFSortingSegment(spikes["clusters"], spikes["samples"]) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) if load_unit_properties: for key, val in clusters.items(): diff --git a/src/spikeinterface/extractors/klustaextractors.py b/src/spikeinterface/extractors/klustaextractors.py index 1e55072b54..2d1cfe3e09 100644 --- a/src/spikeinterface/extractors/klustaextractors.py +++ b/src/spikeinterface/extractors/klustaextractors.py @@ -125,7 +125,7 @@ def __init__(self, file_or_folder_path, exclude_cluster_groups=None): self.extra_requirements.append("h5py") - self.add_sorting_segment(KlustSortingSegment(unit_ids, spiketrains)) + self.add_segment(KlustSortingSegment(unit_ids, spiketrains)) self.set_property("group", groups) quality = [e.lower() for e in cluster_groups_name] diff --git a/src/spikeinterface/extractors/mclustextractors.py b/src/spikeinterface/extractors/mclustextractors.py index d611a1576a..83560728b2 100644 --- a/src/spikeinterface/extractors/mclustextractors.py +++ b/src/spikeinterface/extractors/mclustextractors.py @@ -71,7 +71,7 @@ def __init__(self, folder_path, sampling_frequency, sampling_frequency_raw=None) BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(MClustSortingSegment(unit_ids, spiketrains)) + self.add_segment(MClustSortingSegment(unit_ids, spiketrains)) self._kwargs = { "folder_path": str(Path(folder_path).absolute()), "sampling_frequency": sampling_frequency, diff --git a/src/spikeinterface/extractors/mcsh5extractors.py b/src/spikeinterface/extractors/mcsh5extractors.py index a325b6aabd..85b6055ea5 100644 --- a/src/spikeinterface/extractors/mcsh5extractors.py +++ b/src/spikeinterface/extractors/mcsh5extractors.py @@ -54,7 +54,7 @@ def __init__(self, file_path, stream_id=0): recording_segment = MCSH5RecordingSegment( self._rf, stream_id, mcs_info["num_frames"], sampling_frequency=mcs_info["sampling_frequency"] ) - 
self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) # set gain self.set_channel_gains(mcs_info["gain"]) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index c1fc1144de..d1d3c87e61 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -12,7 +12,7 @@ from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -from spikeinterface.core import write_binary_recording +from spikeinterface.core import write_binary from spikeinterface.core.job_tools import fix_job_kwargs @@ -54,7 +54,7 @@ def __init__(self, folder_path, raw_fname="raw.mda", params_fname="params.json", self, sampling_frequency=sampling_frequency, channel_ids=np.arange(num_channels), dtype=dtype ) rec_segment = MdaRecordingSegment(self._diskreadmda, sampling_frequency) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self.set_dummy_probe_from_locations(geom) self._kwargs = { "folder_path": str(Path(folder_path).absolute()), @@ -127,7 +127,7 @@ def write_recording( header = MdaHeader(dt0=dtype, dims0=(num_channels, num_frames)) header_size = header.header_size - write_binary_recording( + write_binary( recording, file_paths=save_file_path, dtype=dtype, @@ -207,13 +207,13 @@ def __init__(self, file_path, sampling_frequency): BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency) sorting_segment = MdaSortingSegment(firings) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) # Store the max channel for each unit # Every spike assigned to a unit (label) has the same max channel # ref: https://github.com/SpikeInterface/spikeinterface/issues/3695#issuecomment-2663329006 max_channels = [] - segment = self._sorting_segments[0] + segment = self.segments[0] for unit_id in 
self.unit_ids: label_mask = segment._labels == unit_id # since all max channels are the same, we can just grab the first occurrence for the unit diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index d66ce79aa3..df97947af6 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -329,7 +329,7 @@ def __init__( rec_segment = NeoRecordingSegment( self.neo_reader, self.block_index, segment_index, self.stream_index, self.inverted_gain ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs.update(kwargs) @@ -434,7 +434,7 @@ def __init__( neo_returns_frames=self.neo_returns_frames, ) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) def _infer_sampling_frequency_from_analog_signal(self, stream_id: Optional[str] = None) -> float: """ diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 298a2d6109..1bf42e5a08 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -301,7 +301,7 @@ def __init__( all_unit_shank_ids += [shank_id] * len(new_unit_ids) BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=all_unit_ids) - self.add_sorting_segment(NeuroScopeSortingSegment(all_unit_ids, all_spiketrains)) + self.add_segment(NeuroScopeSortingSegment(all_unit_ids, all_spiketrains)) self.extra_requirements.append("lxml") diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 976e752a62..7057e85dc0 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -567,7 +567,7 @@ def __init__( electrical_series_data=segment_data, 
times_kwargs=times_kwargs, ) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) # fetch and add main recording properties if use_pynwb: @@ -1072,7 +1072,7 @@ def __init__( sampling_frequency=self.sampling_frequency, t_start=self.t_start, ) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) # fetch and add sorting properties if load_unit_properties: @@ -1504,7 +1504,7 @@ def __init__( timeseries_data=segment_data, times_kwargs=times_kwargs, ) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) if storage_options is not None and stream_mode == "zarr": warnings.warn( diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 46a8e4cecb..4024613e26 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -212,7 +212,7 @@ def __init__( self.annotate(phy_folder=str(phy_folder.resolve())) - self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) + self.add_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) class PhySortingSegment(BaseSortingSegment): diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index 9d74b70fdf..e3b043cf36 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -11,7 +11,7 @@ BinaryRecordingExtractor, BaseSorting, BaseSortingSegment, - write_binary_recording, + write_binary, ) from spikeinterface.core.core_tools import define_function_from_class @@ -123,7 +123,7 @@ def write_recording(recording, save_path, initial_sorting_fn, dtype="float32", * # write recording recording_fn = (save_path / recording_name).absolute() - write_binary_recording(recording, file_paths=recording_fn, dtype=dtype, **job_kwargs) + write_binary(recording, 
file_paths=recording_fn, dtype=dtype, **job_kwargs) # write probe file probe_fn = (save_path / probe_name).absolute() @@ -179,7 +179,7 @@ def __init__(self, file_path, sampling_frequency, delimiter=","): BaseSorting.__init__(self, unit_ids=spike_clusters.keys(), sampling_frequency=sampling_frequency) sorting_segment = SHYBRIDSortingSegment(spike_clusters) - self.add_sorting_segment(sorting_segment) + self.add_segment(sorting_segment) self._kwargs = { "file_path": str(Path(file_path).absolute()), diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index 31a2a81f82..2130a5b724 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -108,7 +108,7 @@ def __init__(self, file_path: str | Path): sampling_frequency=sinaps_info["sampling_frequency"], num_bits=sinaps_info["num_bits"], ) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) # set gain self.set_channel_gains(sinaps_info["gain"]) diff --git a/src/spikeinterface/extractors/spykingcircusextractors.py b/src/spikeinterface/extractors/spykingcircusextractors.py index b8a1e5635e..463dca4d8e 100644 --- a/src/spikeinterface/extractors/spykingcircusextractors.py +++ b/src/spikeinterface/extractors/spykingcircusextractors.py @@ -77,7 +77,7 @@ def __init__(self, folder_path): unit_ids.append(int(temp.split("_")[-1])) BaseSorting.__init__(self, sample_rate, unit_ids) - self.add_sorting_segment(SpykingcircustSortingSegment(unit_ids, spiketrains)) + self.add_segment(SpykingcircustSortingSegment(unit_ids, spiketrains)) self._kwargs = {"folder_path": str(Path(folder_path).absolute())} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/tridesclousextractors.py b/src/spikeinterface/extractors/tridesclousextractors.py index ed66bb4a31..225a2ce6d7 100644 --- 
a/src/spikeinterface/extractors/tridesclousextractors.py +++ b/src/spikeinterface/extractors/tridesclousextractors.py @@ -51,7 +51,7 @@ def __init__(self, folder_path, chan_grp=None): for seg_num in range(dataio.nb_segment): # load all spike in memory (this avoid to lock the folder with memmap throug dataio all_spikes = dataio.get_spikes(seg_num=seg_num, chan_grp=chan_grp, i_start=None, i_stop=None).copy() - self.add_sorting_segment(TridesclousSortingSegment(all_spikes)) + self.add_segment(TridesclousSortingSegment(all_spikes)) self._kwargs = {"folder_path": str(Path(folder_path).absolute()), "chan_grp": chan_grp} self.extra_requirements.append("tridesclous") diff --git a/src/spikeinterface/extractors/waveclustextractors.py b/src/spikeinterface/extractors/waveclustextractors.py index 3d024910fa..9f279bec4d 100644 --- a/src/spikeinterface/extractors/waveclustextractors.py +++ b/src/spikeinterface/extractors/waveclustextractors.py @@ -42,7 +42,7 @@ def __init__(self, file_path, keep_good_only=True): BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(WaveClustSortingSegment(unit_ids, spiketrains)) + self.add_segment(WaveClustSortingSegment(unit_ids, spiketrains)) self.set_property("unsorted", np.array([c == 0 for c in unit_ids])) self._kwargs = {"file_path": str(Path(file_path).absolute()), "keep_good_only": keep_good_only} diff --git a/src/spikeinterface/extractors/yassextractors.py b/src/spikeinterface/extractors/yassextractors.py index 7a76906acc..acb95010d6 100644 --- a/src/spikeinterface/extractors/yassextractors.py +++ b/src/spikeinterface/extractors/yassextractors.py @@ -46,7 +46,7 @@ def __init__(self, folder_path): # initialize sampling_frequency = self.config["recordings"]["sampling_rate"] BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(YassSortingSegment(spiketrains)) + self.add_segment(YassSortingSegment(spiketrains)) self._kwargs = {"folder_path": str(Path(folder_path).absolute())} 
self.extra_requirements.append("pyyaml") diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 0989afe126..232eb6161a 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -466,9 +466,7 @@ def __init__( amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None # upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None - parent_recording_segment = ( - None if parent_recording is None else parent_recording._recording_segments[segment_index] - ) + parent_recording_segment = None if parent_recording is None else parent_recording.segments[segment_index] recording_segment = InjectDriftingTemplatesRecordingSegment( self.dtype, self.spike_vector[start:end], @@ -480,7 +478,7 @@ def __init__( displacement_indices[start:end], drifting_templates.templates_array_moved, ) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) self.set_probe(drifting_templates.probe, in_place=True) diff --git a/src/spikeinterface/postprocessing/alignsorting.py b/src/spikeinterface/postprocessing/alignsorting.py index c24aa3e41b..de90503194 100644 --- a/src/spikeinterface/postprocessing/alignsorting.py +++ b/src/spikeinterface/postprocessing/alignsorting.py @@ -27,8 +27,8 @@ class AlignSortingExtractor(BaseSorting): def __init__(self, sorting, unit_peak_shifts): super().__init__(sorting.get_sampling_frequency(), sorting.unit_ids) - for segment in sorting._sorting_segments: - self.add_sorting_segment(AlignSortingSegment(segment, unit_peak_shifts)) + for segment in sorting.segments: + self.add_segment(AlignSortingSegment(segment, unit_peak_shifts)) sorting.copy_metadata(self, only_main=False) self._parent = sorting diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index ce8194f530..e766644313 100644 --- 
a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -3,7 +3,7 @@ import numpy as np from spikeinterface.core import ChannelSparsity -from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs +from spikeinterface.core.job_tools import ChunkExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs from spikeinterface.core.template_tools import get_template_extremum_channel diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index f5c1a74848..54bff0dbd3 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -13,7 +13,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs +from spikeinterface.core.job_tools import ChunkExecutor, _shared_job_kwargs_doc, fix_job_kwargs from spikeinterface.core.analyzer_extension_core import _inplace_sparse_realign_waveforms @@ -419,7 +419,7 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): unit_channels, pca_model, ) - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs ) processor.run() diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index 2527993db8..b2231095b8 100644 --- a/src/spikeinterface/preprocessing/astype.py +++ b/src/spikeinterface/preprocessing/astype.py @@ -44,13 +44,13 @@ def __init__( if round is None: round = np.issubdtype(dtype, np.integer) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = AstypeRecordingSegment( parent_segment, dtype, 
round, ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index ce23cb3f49..b3ef6bab8f 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -77,7 +77,7 @@ def __init__( self.parent_recording = parent_recording self.num_channels = n_pos_unique - for segment in parent_recording._recording_segments: + for segment in parent_recording.segments: recording_segment = AverageAcrossDirectionRecordingSegment( segment, self.num_channels, @@ -85,7 +85,7 @@ def __init__( n_chans_each_pos, dtype_, ) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) self._kwargs = dict( parent_recording=parent_recording, diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 6692a47758..00d61dd1a7 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -35,9 +35,9 @@ def __init__(self, recording, a_min=None, a_max=None): value_max = a_max BasePreprocessor.__init__(self, recording) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict(recording=recording, a_min=a_min, a_max=a_max) @@ -132,9 +132,9 @@ def __init__( value_max = fill_value BasePreprocessor.__init__(self, recording) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( 
recording=recording, diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 3ab93fdf14..f450e8f647 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -135,11 +135,11 @@ def __init__( else: ref_channel_indices = None - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = CommonReferenceRecordingSegment( parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, dtype_ ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 8a2ec1f839..051c3d706e 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -67,8 +67,8 @@ def __init__( BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) - for parent_segment in recording._recording_segments: - self.add_recording_segment( + for parent_segment in recording.segments: + self.add_segment( DecimateRecordingSegment( parent_segment, decimated_sampling_frequency, diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index aa3f9a34cc..56ccb37cc1 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -49,6 +49,8 @@ class DeepInterpolatedRecording(BasePreprocessor): The deepinterpolated recording extractor object """ + _preferred_mp_context = "spawn" + def __init__( self, recording, @@ -92,7 +94,7 @@ def __init__( self.model = model # add segment - for segment in recording._recording_segments: + for segment in 
recording.segments: recording_segment = DeepInterpolatedRecordingSegment( segment, self.model, @@ -103,9 +105,8 @@ def __init__( batch_size, predict_workers, ) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) - self._preferred_mp_context = "spawn" self._kwargs = dict( recording=recording, model_path=str(model_path), diff --git a/src/spikeinterface/preprocessing/directional_derivative.py b/src/spikeinterface/preprocessing/directional_derivative.py index c945e4e6d4..34b67ecd2b 100644 --- a/src/spikeinterface/preprocessing/directional_derivative.py +++ b/src/spikeinterface/preprocessing/directional_derivative.py @@ -52,7 +52,7 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype_) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = DirectionalDerivativeRecordingSegment( parent_segment, parent_channel_locations, @@ -61,7 +61,7 @@ def __init__( edge_order, dtype_, ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 78542e1f37..03224be351 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -111,8 +111,8 @@ def __init__( self.set_channel_offsets(0) margin = int(margin_ms * fs / 1000.0) - for parent_segment in recording._recording_segments: - self.add_recording_segment( + for parent_segment in recording.segments: + self.add_segment( FilterRecordingSegment( parent_segment, filter_coeff, @@ -315,8 +315,8 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): sf = recording.get_sampling_frequency() margin = int(margin_ms * sf / 1000.0) - for parent_segment in recording._recording_segments: - self.add_recording_segment(FilterRecordingSegment(parent_segment, coeff, "ba", margin, dtype)) + for parent_segment in 
recording.segments: + self.add_segment(FilterRecordingSegment(parent_segment, coeff, "ba", margin, dtype)) self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str) diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py index b053ef6533..c5554f9316 100644 --- a/src/spikeinterface/preprocessing/filter_gaussian.py +++ b/src/spikeinterface/preprocessing/filter_gaussian.py @@ -49,8 +49,8 @@ def __init__( if freq_min is None and freq_max is None: raise ValueError("At least one of `freq_min`,`freq_max` should be specified.") - for parent_segment in recording._recording_segments: - self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max, margin_sd)) + for parent_segment in recording.segments: + self.add_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max, margin_sd)) self._kwargs = {"recording": recording, "freq_min": freq_min, "freq_max": freq_max} diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 1f4e18663b..60166d90a6 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -74,8 +74,8 @@ def __init__( dtype = "float32" executor = OpenCLFilterExecutor(coefficients, num_channels, dtype, margin) - for parent_segment in recording._recording_segments: - self.add_recording_segment(FilterOpenCLRecordingSegment(parent_segment, executor, margin)) + for parent_segment in recording.segments: + self.add_segment(FilterOpenCLRecordingSegment(parent_segment, executor, margin)) self._kwargs = dict( recording=recording, diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index 9228f5de12..5aec6ea735 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ 
b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -121,7 +121,7 @@ def __init__( dtype = fix_dtype(recording, dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = HighPassSpatialFilterSegment( parent_segment, n_channel_pad, @@ -134,7 +134,7 @@ def __init__( order_r, dtype=dtype, ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 427aaa7437..cfc911fcce 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -67,11 +67,11 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non locations_bad = locations[self._bad_channel_idxs] weights = preprocessing_tools.get_kriging_channel_weights(locations_good, locations_bad, sigma_um, p) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = InterpolateBadChannelsSegment( parent_segment, self._good_channel_idxs, self._bad_channel_idxs, weights ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, bad_channel_ids=bad_channel_ids, p=p, sigma_um=sigma_um, weights=weights diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 6ad91baa8f..64e4381231 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -105,9 +105,9 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype) - 
self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, @@ -168,9 +168,9 @@ def __init__(self, recording, gain=1.0, offset=0.0, dtype="float32"): BasePreprocessor.__init__(self, recording, dtype=dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, self._dtype) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, @@ -213,9 +213,9 @@ def __init__(self, recording, mode="median", dtype="float32", **random_chunk_kwa BasePreprocessor.__init__(self, recording, dtype=dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, @@ -315,9 +315,9 @@ def __init__( self.set_property(key="gain_to_uV", values=np.ones(num_chans, dtype="float32")) self.set_property(key="offset_to_uV", values=np.zeros(num_chans, dtype="float32")) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, dtype=np.dtype(self._dtype).str, mode=mode, gain=gain.tolist(), offset=offset.tolist() diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 872793a30e..dbd3ba1df1 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -63,9 +63,9 @@ def __init__(self, recording, margin_ms=40.0, inter_sample_shift=None, dtype=Non tmp_dtype = None BasePreprocessor.__init__(self, 
recording, dtype=dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = PhaseShiftRecordingSegment(parent_segment, sample_shifts, margin, dtype, tmp_dtype) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) # for dumpability if inter_sample_shift is not None: diff --git a/src/spikeinterface/preprocessing/pipeline.py b/src/spikeinterface/preprocessing/pipeline.py index b04e60dc86..a5e54801f8 100644 --- a/src/spikeinterface/preprocessing/pipeline.py +++ b/src/spikeinterface/preprocessing/pipeline.py @@ -187,7 +187,10 @@ def get_preprocessing_dict_from_analyzer(analyzer_folder, format="auto", backend analyzer_folder = Path(analyzer_folder) if format == "auto": - if str(analyzer_folder).endswith(".zarr"): + analyzer_folder_str = str(analyzer_folder) + if analyzer_folder_str.endswith("/"): + analyzer_folder_str = analyzer_folder_str[:-1] + if analyzer_folder_str.endswith(".zarr"): format = "zarr" else: format = "binary_folder" diff --git a/src/spikeinterface/preprocessing/rectify.py b/src/spikeinterface/preprocessing/rectify.py index 96d68dda90..c7a3cfc09e 100644 --- a/src/spikeinterface/preprocessing/rectify.py +++ b/src/spikeinterface/preprocessing/rectify.py @@ -11,9 +11,9 @@ class RectifyRecording(BasePreprocessor): def __init__(self, recording): BasePreprocessor.__init__(self, recording) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = RectifyRecordingSegment(parent_segment) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict(recording=recording) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py b/src/spikeinterface/preprocessing/remove_artifacts.py index 60780856b0..ca66656b5f 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -199,13 +199,13 @@ def __init__( time_pad = None 
BasePreprocessor.__init__(self, recording) - for seg_index, parent_segment in enumerate(recording._recording_segments): + for seg_index, parent_segment in enumerate(recording.segments): triggers = list_triggers[seg_index] labels = list_labels[seg_index] rec_segment = RemoveArtifactsRecordingSegment( parent_segment, triggers, pad, mode, fit_samples, artifacts, labels, scale_amplitude, time_pad, sparsity ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) list_triggers_ = [[int(trig) for trig in trig_seg] for trig_seg in list_triggers] if list_labels is not None: diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 0fbf8e54e0..b6d963a35c 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -68,9 +68,9 @@ def __init__( BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate, dtype=dtype) # in case there was a time_vector, it will be dropped for sanity. 
- for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: parent_segment.time_vector = None - self.add_recording_segment( + self.add_segment( ResampleRecordingSegment( parent_segment, resample_rate, diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index dca57b0c8b..e364f8491d 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -90,12 +90,12 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see noise_generator = None BasePreprocessor.__init__(self, recording) - for seg_index, parent_segment in enumerate(recording._recording_segments): + for seg_index, parent_segment in enumerate(recording.segments): periods = list_periods[seg_index] periods = np.asarray(periods, dtype="int64") periods = np.sort(periods, axis=0) rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) self._kwargs.update(random_chunk_kwargs) diff --git a/src/spikeinterface/preprocessing/tests/test_clip.py b/src/spikeinterface/preprocessing/tests/test_clip.py index cea15722a0..96020692a1 100644 --- a/src/spikeinterface/preprocessing/tests/test_clip.py +++ b/src/spikeinterface/preprocessing/tests/test_clip.py @@ -41,7 +41,7 @@ def test_blank_saturation(): traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"]) assert traces1.shape[1] == 1 # use a smaller value to be sure - a_min = rec1._recording_segments[0].a_min + a_min = rec1.segments[0].a_min assert np.all(traces1 >= a_min) diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index 8b37e7f4b9..d176163a64 100644 --- 
a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -83,7 +83,7 @@ def test_common_reference_channel_slicing(recording): start_frame = 0 end_frame = 10 - recording_segment_cmr = recording_cmr._recording_segments[0] + recording_segment_cmr = recording_cmr.segments[0] traces_cmr_all = recording_segment_cmr.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices ) @@ -93,7 +93,7 @@ def test_common_reference_channel_slicing(recording): assert np.all(traces_cmr_all[:, indices] == traces_cmr_sub) - recording_segment_car = recording_car._recording_segments[0] + recording_segment_car = recording_car.segments[0] traces_car_all = recording_segment_car.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices ) @@ -103,7 +103,7 @@ def test_common_reference_channel_slicing(recording): assert np.all(traces_car_all[:, indices] == traces_car_sub) - recording_segment_local = recording_local_car._recording_segments[0] + recording_segment_local = recording_local_car.segments[0] traces_local_all = recording_segment_local.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices ) diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index e9493145a6..141345ca46 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -66,7 +66,7 @@ def test_decimate_with_times(): # test with t_start rec = generate_recording(durations=[5, 10]) t_starts = [10, 20] - for t_start, rec_segment in zip(t_starts, rec._recording_segments): + for t_start, rec_segment in zip(t_starts, rec.segments): rec_segment.t_start = t_start decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) for segment_index in range(rec.get_num_segments()): diff --git 
a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 31ea5f5523..29565bd336 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -262,7 +262,7 @@ def reduce_high_freq_power_in_non_noisy_channels(recording, is_noisy, not_noisy) """ from scipy.signal import welch - for iseg, __ in enumerate(recording._recording_segments): + for iseg, __ in enumerate(recording.segments): data = recording.get_traces(iseg).T num_samples = recording.get_num_samples(iseg) @@ -291,7 +291,7 @@ def add_dead_channels(recording, is_dead): data[:, is_dead] = np.random.normal( mean, std * 0.1, size=(is_dead.size, recording.get_num_samples(segment_index)) ).T - recording._recording_segments[segment_index]._traces = data + recording.segments[segment_index]._traces = data if __name__ == "__main__": diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 7c414df738..f4c0e4d166 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -366,8 +366,8 @@ def test_passed_W_and_M(self): whitened_recording = whiten(recording, W=test_W, M=test_M) for seg_idx in [0, 1]: - assert np.array_equal(whitened_recording._recording_segments[seg_idx].W, test_W) - assert np.array_equal(whitened_recording._recording_segments[seg_idx].M, test_M) + assert np.array_equal(whitened_recording.segments[seg_idx].W, test_W) + assert np.array_equal(whitened_recording.segments[seg_idx].M, test_M) assert whitened_recording._kwargs["W"] == test_W.tolist() assert whitened_recording._kwargs["M"] == test_M.tolist() diff --git a/src/spikeinterface/preprocessing/unsigned_to_signed.py b/src/spikeinterface/preprocessing/unsigned_to_signed.py index fdb5a663b1..312035663b 100644 --- 
a/src/spikeinterface/preprocessing/unsigned_to_signed.py +++ b/src/spikeinterface/preprocessing/unsigned_to_signed.py @@ -33,9 +33,9 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype_signed) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = UnsignedToSignedRecordingSegment(parent_segment, dtype_signed, bit_depth) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index ea73425d34..8af5e7e121 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -103,9 +103,9 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype_) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = WhitenRecordingSegment(parent_segment, W, M, dtype_, int_scale) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) self._kwargs = dict( recording=recording, diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index c06baf525a..4bae1040be 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -36,7 +36,7 @@ def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end self.padding_start = padding_start self.padding_end = padding_end self.fill_value = fill_value - for segment in recording._recording_segments: + for segment in recording.segments: recording_segment = TracePaddedRecordingSegment( segment, recording.get_num_channels(), @@ -45,7 +45,7 @@ def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end self.padding_end, self.fill_value, ) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) 
self._kwargs = dict( parent_recording=recording, @@ -168,9 +168,9 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: self.parent_recording = recording self.num_channels = num_channels - for segment in recording._recording_segments: + for segment in recording.segments: recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping) - self.add_recording_segment(recording_segment) + self.add_segment(recording_segment) # only copy relevant metadata and properties recording.copy_metadata(self, only_main=True) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index ecedf92efb..c3f1051bf3 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -6,7 +6,7 @@ from packaging import version -from spikeinterface.core import write_binary_recording +from spikeinterface.core import write_binary from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs from .kilosortbase import KilosortBase from spikeinterface.sorters.basesorter import get_job_kwargs @@ -144,8 +144,8 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): if not recording.binary_compatible_with(time_axis=0, file_paths_length=1): # local copy needed binary_file_path = sorter_output_folder / "recording.dat" - write_binary_recording( - recording=recording, + write_binary( + recording, file_paths=[binary_file_path], **get_job_kwargs(params, verbose), ) diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 97886868c1..6ec6e11f1f 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -11,7 +11,7 @@ from spikeinterface.sorters.utils import ShellScript, get_matlab_shell_name, get_bash_path from spikeinterface.sorters.basesorter import get_job_kwargs 
from spikeinterface.extractors.extractor_classes import KiloSortSortingExtractor -from spikeinterface.core import write_binary_recording +from spikeinterface.core import write_binary from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording @@ -152,7 +152,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): ) else: padded_recording = recording - write_binary_recording( + write_binary( recording=padded_recording, file_paths=[binary_file_path], dtype="int16", diff --git a/src/spikeinterface/sorters/external/klusta.py b/src/spikeinterface/sorters/external/klusta.py index b6602ba6e2..8d31c51cd7 100644 --- a/src/spikeinterface/sorters/external/klusta.py +++ b/src/spikeinterface/sorters/external/klusta.py @@ -11,7 +11,7 @@ from probeinterface import write_prb -from spikeinterface.core import write_binary_recording +from spikeinterface.core import write_binary from spikeinterface.extractors.extractor_classes import KlustaSortingExtractor @@ -103,7 +103,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): # save binary file (chunk by chunk) into a new file raw_filename = sorter_output_folder / "recording.dat" dtype = "int16" - write_binary_recording(recording, file_paths=[raw_filename], dtype=dtype, **get_job_kwargs(params, verbose)) + write_binary(recording, file_paths=[raw_filename], dtype=dtype, **get_job_kwargs(params, verbose)) if p["detect_sign"] < 0: detect_sign = "negative" diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index 4b539cc7f7..9c25d821d9 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -9,7 +9,7 @@ import numpy as np from spikeinterface.extractors.extractor_classes import KiloSortSortingExtractor -from spikeinterface.core import write_binary_recording +from spikeinterface.core import write_binary from spikeinterface.sorters.basesorter 
import BaseSorter, get_job_kwargs @@ -144,7 +144,7 @@ def _check_params(cls, recording, sorter_output_folder, params): def _setup_recording(cls, recording, sorter_output_folder, params, verbose): if not recording.binary_compatible_with(time_axis=0, file_paths_length=1): # local copy needed - write_binary_recording( + write_binary( recording, file_paths=sorter_output_folder / "recording.dat", **get_job_kwargs(params, verbose), diff --git a/src/spikeinterface/sorters/external/tridesclous.py b/src/spikeinterface/sorters/external/tridesclous.py index 65aadd1d55..1fc1ee4755 100644 --- a/src/spikeinterface/sorters/external/tridesclous.py +++ b/src/spikeinterface/sorters/external/tridesclous.py @@ -8,7 +8,7 @@ from spikeinterface.extractors.extractor_classes import TridesclousSortingExtractor from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs -from spikeinterface.core import write_binary_recording +from spikeinterface.core import write_binary from probeinterface import write_prb @@ -96,7 +96,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): num_channels = recording.get_num_channels() dtype = recording.get_dtype().str file_paths = [str(sorter_output_folder / f"raw_signals_{i}.raw") for i in range(num_seg)] - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **get_job_kwargs(params, verbose)) + write_binary(recording, file_paths=file_paths, dtype=dtype, **get_job_kwargs(params, verbose)) file_offset = 0 # initialize source and probe file diff --git a/src/spikeinterface/sorters/external/yass.py b/src/spikeinterface/sorters/external/yass.py index aabc8f77b4..ea2f4c62f5 100644 --- a/src/spikeinterface/sorters/external/yass.py +++ b/src/spikeinterface/sorters/external/yass.py @@ -10,7 +10,7 @@ from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs from spikeinterface.sorters.utils import ShellScript -from spikeinterface.core import write_binary_recording +from spikeinterface.core import 
write_binary from spikeinterface.extractors.extractor_classes import YassSortingExtractor @@ -152,7 +152,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): dtype = "int16" # HARD CODE THIS FOR YASS input_file_path = sorter_output_folder / "data.bin" - write_binary_recording(recording, file_paths=[input_file_path], dtype=dtype, **get_job_kwargs(params, verbose)) + write_binary(recording, file_paths=[input_file_path], dtype=dtype, **get_job_kwargs(params, verbose)) retrain = False if params["neural_nets_path"] is None: diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4ed6548ca1..1c75ef8cb9 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -11,7 +11,7 @@ from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import ( cache_preprocessing, - get_shuffled_recording_slices, + get_shuffled_slices, _set_optimal_chunk_size, ) from spikeinterface.core.basesorting import minimum_spike_dtype @@ -266,7 +266,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(clustering_folder / "prototype.npy", prototype) if skip_peaks: - detect_pipeline_kwargs["recording_slices"] = get_shuffled_recording_slices( + detect_pipeline_kwargs["slices"] = get_shuffled_slices( recording_w, job_kwargs=job_kwargs, seed=params["seed"], diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 2b57b8e431..e8950fcddd 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -28,7 +28,7 @@ def find_clusters_from_peaks( verbose : Bool, default: False If True, output is verbose job_kwargs : dict - Parameters for ChunkRecordingExecutor + Parameters for 
ChunkExecutor {method_doc} diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 0b4c1dfaa8..68a9239a25 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -37,11 +37,11 @@ def find_spikes_from_templates( If True then a dict is also returned is also returned pipeline_kwargs : dict Dict transmited to run_node_pipelines to handle fine details - like : gather_mode/folder/skip_after_n_peaks/recording_slices + like : gather_mode/folder/skip_after_n_peaks/slices verbose : Bool, default: False If True, output is verbose job_kwargs : dict - Parameters for ChunkRecordingExecutor + Parameters for ChunkExecutor {method_doc} diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 6377fea566..6b8c5371b0 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -204,7 +204,7 @@ def __init__( # interpolation bins edges self.interpolation_time_bins_s = [] self.interpolation_time_bin_edges_s = [] - for segment_index, parent_segment in enumerate(recording._recording_segments): + for segment_index, parent_segment in enumerate(recording.segments): # in this case, interpolation_time_bin_size_s is set. s_end = parent_segment.get_num_samples() t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 4ddbc5eaee..4d8c51d0ad 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -225,7 +225,7 @@ def interpolate_motion_on_traces( # here we use a simple np.matmul even if dirft_kernel can be super sparse. 
# because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing - # in ChunkRecordingExecutor) + # in ChunkExecutor) np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) current_start_index = next_start_index @@ -424,7 +424,7 @@ def __init__( interpolation_time_bin_centers_s, interpolation_time_bin_edges_s ) - for segment_index, parent_segment in enumerate(recording._recording_segments): + for segment_index, parent_segment in enumerate(recording.segments): # finish the per-segment part of the time bin logic if interpolation_time_bin_centers_s is None: # in this case, interpolation_time_bin_size_s is set. @@ -452,7 +452,7 @@ def __init__( segment_interpolation_time_bin_edges_s, dtype=dtype_, ) - self.add_recording_segment(rec_segment) + self.add_segment(rec_segment) # this object is currently not JSON-serializable because the Motion obejct cannot be reloaded properly # see issue #3313 diff --git a/src/spikeinterface/sortingcomponents/peak_detection/main.py b/src/spikeinterface/sortingcomponents/peak_detection/main.py index 3941109e90..32aef955c4 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/main.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/main.py @@ -38,7 +38,7 @@ def detect_peaks( Important note, for flexibility, if method=None, then the method can be given inside the method_kwargs dict. 
pipeline_kwargs : dict Dict transmited to run_node_pipelines to handle fine details - like : gather_mode/folder/skip_after_n_peaks/recording_slices + like : gather_mode/folder/skip_after_n_peaks/slices verbose : Bool, default: False If True, output is verbose job_kwargs : dict | None, default None diff --git a/src/spikeinterface/sortingcomponents/peak_localization/main.py b/src/spikeinterface/sortingcomponents/peak_localization/main.py index 3e869db8ac..4c1ba3f812 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/main.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/main.py @@ -99,7 +99,7 @@ def localize_peaks( The number of milliseconds to include after the peak of the spike pipeline_kwargs : dict Dict transmited to run_node_pipelines to handle fine details - like : gather_mode/folder/skip_after_n_peaks/recording_slices + like : gather_mode/folder/skip_after_n_peaks/slices verbose : Bool, default: False If True, output is verbose job_kwargs : dict | None, default None diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 7680321722..5fa476771c 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -187,12 +187,12 @@ def get_prototype_and_waveforms_from_recording( nodes = [node0, node1] - recording_slices = get_shuffled_recording_slices(recording, job_kwargs=job_kwargs, seed=seed) + slices = get_shuffled_slices(recording, job_kwargs=job_kwargs, seed=seed) # res = detect_peaks( # recording, # pipeline_nodes=pipeline_nodes, # skip_after_n_peaks=n_peaks, - # recording_slices=recording_slices, + # slices=slices, # method="locally_exclusive", # method_kwargs=detection_kwargs, # job_kwargs=job_kwargs, @@ -203,7 +203,7 @@ def get_prototype_and_waveforms_from_recording( job_kwargs, job_name="get protoype waveforms", skip_after_n_peaks=n_peaks, - recording_slices=recording_slices, + slices=slices, ) rng = 
np.random.default_rng(seed) @@ -475,23 +475,23 @@ def create_sorting_analyzer_with_existing_templates( return sa -def get_shuffled_recording_slices(recording, job_kwargs=None, seed=None): +def get_shuffled_slices(recording, job_kwargs=None, seed=None): from spikeinterface.core.job_tools import ensure_chunk_size from spikeinterface.core.job_tools import divide_segment_into_chunks job_kwargs = fix_job_kwargs(job_kwargs) chunk_size = ensure_chunk_size(recording, **job_kwargs) - recording_slices = [] + slices = [] for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) chunks = divide_segment_into_chunks(num_frames, chunk_size) - recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) rng = np.random.default_rng(seed) - recording_slices = rng.permutation(recording_slices) + slices = rng.permutation(slices) - return recording_slices + return slices def clean_templates(