Skip to content
2 changes: 1 addition & 1 deletion src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
from .job_tools import (
get_best_job_kwargs,
ensure_n_jobs,
ensure_chunk_size,
ensure_recording_chunk_size,
ChunkRecordingExecutor,
split_job_kwargs,
fix_job_kwargs,
Expand Down
238 changes: 155 additions & 83 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def divide_recording_into_chunks(recording, chunk_size):
return recording_slices


def ensure_n_jobs(recording, n_jobs=1):
def ensure_n_jobs(extractor, n_jobs=1):
if n_jobs == -1:
n_jobs = os.cpu_count()
elif n_jobs == 0:
Expand All @@ -244,10 +244,10 @@ def ensure_n_jobs(recording, n_jobs=1):
print(f"Python {sys.version} does not support parallel processing")
n_jobs = 1

if not recording.check_if_memory_serializable():
if not extractor.check_if_memory_serializable():
if n_jobs != 1:
raise RuntimeError(
"Recording is not serializable to memory and can't be processed in parallel. "
"Extractor is not serializable to memory and can't be processed in parallel. "
"You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1."
)

Expand All @@ -270,7 +270,7 @@ def chunk_duration_to_chunk_size(chunk_duration, recording):
return chunk_size


def ensure_chunk_size(
def ensure_recording_chunk_size(
recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
):
"""
Expand Down Expand Up @@ -330,70 +330,14 @@ def ensure_chunk_size(
return chunk_size


class ChunkRecordingExecutor:
class BaseChunkExecutor:
"""
Core class for parallel processing to run a "function" over chunks on a recording.

It supports running a function:
* in loop with chunk processing (low RAM usage)
* at once if chunk_size is None (high RAM usage)
* in parallel with ProcessPoolExecutor (higher speed)

The initializer ("init_func") allows to set a global context to avoid heavy serialization
(for examples, see implementation in `core.waveform_tools`).

Parameters
----------
recording : RecordingExtractor
The recording to be processed
func : function
Function that runs on each chunk
init_func : function
Initializer function to set the global context (accessible by "func")
init_args : tuple
Arguments for init_func
verbose : bool
If True, output is verbose
job_name : str, default: ""
Job name
progress_bar : bool, default: False
If True, a progress bar is printed to monitor the progress of the process
handle_returns : bool, default: False
If True, the function can return values
gather_func : None or callable, default: None
Optional function that is called in the main thread and retrieves the results of each worker.
This function can be used instead of `handle_returns` to implement custom storage on-the-fly.
pool_engine : "process" | "thread", default: "thread"
If n_jobs>1 then use ProcessPoolExecutor or ThreadPoolExecutor
n_jobs : int, default: 1
Number of jobs to be used. Use -1 to use as many jobs as number of cores
total_memory : str, default: None
Total memory (RAM) to use (e.g. "1G", "500M")
chunk_memory : str, default: None
Memory per chunk (RAM) to use (e.g. "1G", "500M")
chunk_size : int or None, default: None
Size of each chunk in number of samples. If "total_memory" or "chunk_memory" are used, it is ignored.
chunk_duration : str or float or None
Chunk duration in s if float or with units if str (e.g. "1s", "500ms")
mp_context : "fork" | "spawn" | None, default: None
"fork" or "spawn". If None, the context is taken by the recording.get_preferred_mp_context().
"fork" is only safely available on LINUX systems.
max_threads_per_worker : int or None, default: None
Limit the number of thread per process using threadpoolctl modules.
This used only when n_jobs>1
If None, no limits.
need_worker_index : bool, default False
If True then each worker will also have a "worker_index" injected in the local worker dict.

Returns
-------
res : list
If "handle_returns" is True, the results for each chunk process
Base class for chunk execution.
"""

def __init__(
self,
recording,
extractor: "BaseExtractor",
func,
init_func,
init_args,
Expand All @@ -412,14 +356,15 @@ def __init__(
max_threads_per_worker=1,
need_worker_index=False,
):
self.recording = recording
self.extractor = extractor
self.func = func
self.init_func = init_func
self.init_args = init_args

if pool_engine == "process":
if mp_context is None:
mp_context = recording.get_preferred_mp_context()
if hasattr(extractor, "get_preferred_mp_context"):
mp_context = extractor.get_preferred_mp_context()
if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
elif mp_context == "fork" and platform.system() == "Darwin":
Expand All @@ -433,9 +378,8 @@ def __init__(
self.handle_returns = handle_returns
self.gather_func = gather_func

self.n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs)
self.chunk_size = ensure_chunk_size(
recording,
self.n_jobs = ensure_n_jobs(self.extractor, n_jobs=n_jobs)
self.chunk_size = self.ensure_chunk_size(
total_memory=total_memory,
chunk_size=chunk_size,
chunk_memory=chunk_memory,
Expand All @@ -450,9 +394,9 @@ def __init__(
self.need_worker_index = need_worker_index

if verbose:
chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize
chunk_memory = self.get_chunk_memory()
total_memory = chunk_memory * self.n_jobs
chunk_duration = self.chunk_size / recording.get_sampling_frequency()
chunk_duration = self.chunk_size / extractor.sampling_frequency
chunk_memory_str = convert_bytes_to_str(chunk_memory)
total_memory_str = convert_bytes_to_str(total_memory)
chunk_duration_str = convert_seconds_to_str(chunk_duration)
Expand All @@ -467,13 +411,22 @@ def __init__(
f"chunk_duration={chunk_duration_str}",
)

def run(self, recording_slices=None):
def get_chunk_memory(self):
raise NotImplementedError

def ensure_chunk_size(
self, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
):
raise NotImplementedError

def run(self, slices=None):
"""
Runs the defined jobs.
"""

if recording_slices is None:
recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size)
if slices is None:
# TODO: rename
slices = divide_recording_into_chunks(self.extractor, self.chunk_size)

if self.handle_returns:
returns = []
Expand All @@ -482,23 +435,21 @@ def run(self, recording_slices=None):

if self.n_jobs == 1:
if self.progress_bar:
recording_slices = tqdm(
recording_slices, desc=f"{self.job_name} (no parallelization)", total=len(recording_slices)
)
slices = tqdm(slices, desc=f"{self.job_name} (no parallelization)", total=len(slices))

worker_dict = self.init_func(*self.init_args)
if self.need_worker_index:
worker_dict["worker_index"] = 0

for segment_index, frame_start, frame_stop in recording_slices:
for segment_index, frame_start, frame_stop in slices:
res = self.func(segment_index, frame_start, frame_stop, worker_dict)
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)

else:
n_jobs = min(self.n_jobs, len(recording_slices))
n_jobs = min(self.n_jobs, len(slices))

if self.pool_engine == "process":

Expand Down Expand Up @@ -526,11 +477,11 @@ def run(self, recording_slices=None):
array_pid,
),
) as executor:
results = executor.map(process_function_wrapper, recording_slices)
results = executor.map(process_function_wrapper, slices)

if self.progress_bar:
results = tqdm(
results, desc=f"{self.job_name} (workers: {n_jobs} processes)", total=len(recording_slices)
results, desc=f"{self.job_name} (workers: {n_jobs} processes)", total=len(slices)
)

for res in results:
Expand All @@ -549,7 +500,7 @@ def run(self, recording_slices=None):
if self.progress_bar:
# here the tqdm threading do not work (maybe collision) so we need to create a pbar
# before thread spawning
pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(recording_slices))
pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(slices))

if self.need_worker_index:
lock = threading.Lock()
Expand All @@ -570,8 +521,8 @@ def run(self, recording_slices=None):
),
) as executor:

recording_slices2 = [(thread_local_data,) + tuple(args) for args in recording_slices]
results = executor.map(thread_function_wrapper, recording_slices2)
slices2 = [(thread_local_data,) + tuple(args) for args in slices]
results = executor.map(thread_function_wrapper, slices2)

for res in results:
if self.progress_bar:
Expand All @@ -590,6 +541,127 @@ def run(self, recording_slices=None):
return returns


class ChunkRecordingExecutor(BaseChunkExecutor):
    """
    Core class for parallel processing to run a "function" over chunks on a recording.

    It supports running a function:
        * in loop with chunk processing (low RAM usage)
        * at once if chunk_size is None (high RAM usage)
        * in parallel with ProcessPoolExecutor or ThreadPoolExecutor (higher speed)

    The initializer ("init_func") allows to set a global context to avoid heavy serialization
    (for examples, see implementation in `core.waveform_tools`).

    Parameters
    ----------
    recording : RecordingExtractor
        The recording to be processed
    func : function
        Function that runs on each chunk
    init_func : function
        Initializer function to set the global context (accessible by "func")
    init_args : tuple
        Arguments for init_func
    verbose : bool, default: False
        If True, output is verbose
    job_name : str, default: ""
        Job name
    progress_bar : bool, default: False
        If True, a progress bar is printed to monitor the progress of the process
    handle_returns : bool, default: False
        If True, the function can return values
    gather_func : None or callable, default: None
        Optional function that is called in the main thread and retrieves the results of each worker.
        This function can be used instead of `handle_returns` to implement custom storage on-the-fly.
    pool_engine : "process" | "thread", default: "thread"
        If n_jobs>1 then use ProcessPoolExecutor or ThreadPoolExecutor
    n_jobs : int, default: 1
        Number of jobs to be used. Use -1 to use as many jobs as number of cores
    total_memory : str, default: None
        Total memory (RAM) to use (e.g. "1G", "500M")
    chunk_memory : str, default: None
        Memory per chunk (RAM) to use (e.g. "1G", "500M")
    chunk_size : int or None, default: None
        Size of each chunk in number of samples. If "total_memory" or "chunk_memory" are used, it is ignored.
    chunk_duration : str or float or None, default: None
        Chunk duration in s if float or with units if str (e.g. "1s", "500ms")
    mp_context : "fork" | "spawn" | None, default: None
        "fork" or "spawn". If None, the context is taken by the recording.get_preferred_mp_context().
        "fork" is only safely available on LINUX systems.
    max_threads_per_worker : int or None, default: 1
        Limit the number of threads per process using the threadpoolctl module.
        This is used only when n_jobs>1.
        If None, no limits.
    need_worker_index : bool, default: False
        If True then each worker will also have a "worker_index" injected in the local worker dict.

    Returns
    -------
    res : list
        If "handle_returns" is True, the results for each chunk process (returned by `run()`)
    """

    def __init__(
        self,
        recording,
        func,
        init_func,
        init_args,
        verbose=False,
        progress_bar=False,
        handle_returns=False,
        gather_func=None,
        pool_engine="thread",
        n_jobs=1,
        total_memory=None,
        chunk_size=None,
        chunk_memory=None,
        chunk_duration=None,
        mp_context=None,
        job_name="",
        max_threads_per_worker=1,
        need_worker_index=False,
    ):
        # `self.recording` must be set BEFORE calling the base initializer: the base
        # `__init__` calls `self.ensure_chunk_size(...)` (and `self.get_chunk_memory()`
        # when verbose), both of which read `self.recording`.
        self.recording = recording
        super().__init__(
            recording,
            func,
            init_func,
            init_args,
            verbose=verbose,
            progress_bar=progress_bar,
            handle_returns=handle_returns,
            gather_func=gather_func,
            pool_engine=pool_engine,
            n_jobs=n_jobs,
            total_memory=total_memory,
            chunk_size=chunk_size,
            chunk_memory=chunk_memory,
            chunk_duration=chunk_duration,
            mp_context=mp_context,
            job_name=job_name,
            max_threads_per_worker=max_threads_per_worker,
            need_worker_index=need_worker_index,
        )

    def run(self, recording_slices=None):
        """
        Runs the defined jobs.

        Parameters
        ----------
        recording_slices : iterable or None, default: None
            Forwarded as `slices` to the base class `run()`. The `recording_slices`
            keyword is kept so that existing callers of this class keep working.

        Returns
        -------
        The value returned by the base class `run()`.
        """
        return super().run(slices=recording_slices)

    def get_chunk_memory(self):
        """
        Return the memory footprint in bytes of one chunk:
        chunk_size (samples) * number of channels * dtype item size.
        """
        # np.dtype(...) normalizes whatever get_dtype() returns (dtype instance, string
        # or scalar type) so `.itemsize` is always an int, as the previous implementation did.
        itemsize = np.dtype(self.recording.get_dtype()).itemsize
        return self.chunk_size * itemsize * self.recording.get_num_channels()

    def ensure_chunk_size(
        self, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
    ):
        """
        Resolve the chunk size for `self.recording` by delegating to
        `ensure_recording_chunk_size` with the given job arguments.
        """
        # Arguments are forwarded by keyword so the mapping does not depend on the
        # positional order of `ensure_recording_chunk_size`.
        return ensure_recording_chunk_size(
            self.recording,
            total_memory=total_memory,
            chunk_size=chunk_size,
            chunk_memory=chunk_memory,
            chunk_duration=chunk_duration,
            n_jobs=n_jobs,
            **other_kwargs,
        )


class WorkerFuncWrapper:
"""
small wrapper that handles:
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .core_tools import add_suffix, make_shared_array
from .job_tools import (
ensure_chunk_size,
ensure_recording_chunk_size,
ensure_n_jobs,
divide_segment_into_chunks,
fix_job_kwargs,
Expand Down Expand Up @@ -183,7 +183,7 @@ def write_binary_recording_file_handle(
dtype = recording.get_dtype()

job_kwargs = fix_job_kwargs(job_kwargs)
chunk_size = ensure_chunk_size(recording, **job_kwargs)
chunk_size = ensure_recording_chunk_size(recording, **job_kwargs)

if chunk_size is not None and time_axis == 1:
print("Chunking disabled due to 'time_axis' == 1")
Expand Down Expand Up @@ -410,7 +410,7 @@ def write_to_h5_dataset_format(

dset = file_handle.create_dataset(dataset_path, shape=shape, dtype=dtype_file)

chunk_size = ensure_chunk_size(recording, chunk_size=chunk_size, chunk_memory=chunk_memory, n_jobs=1)
chunk_size = ensure_recording_chunk_size(recording, chunk_size=chunk_size, chunk_memory=chunk_memory, n_jobs=1)

if chunk_size is None:
# Handle deprecated return_scaled parameter
Expand Down
Loading
Loading