Skip to content
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/spikeinterface/preprocessing/silence_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment

from spikeinterface.core import get_random_data_chunks, get_noise_levels
from spikeinterface.core import get_noise_levels
from spikeinterface.core.generate import NoiseGeneratorRecording
from spikeinterface.core.job_tools import split_job_kwargs


class SilencedPeriodsRecording(BasePreprocessor):
Expand Down Expand Up @@ -36,15 +37,23 @@ class SilencedPeriodsRecording(BasePreprocessor):
- "noise": The periods are filled with a gaussion noise that has the
same variance that the one in the recordings, on a per channel
basis
**random_chunk_kwargs : Keyword arguments for `spikeinterface.core.get_random_data_chunk()` function
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function

Returns
-------
silence_recording : SilencedPeriodsRecording
The recording extractor after silencing some periods
"""

def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, seed=None, **random_chunk_kwargs):
def __init__(
self,
recording,
list_periods,
mode="zeros",
noise_levels=None,
seed=None,
**noise_levels_kwargs,
):
available_modes = ("zeros", "noise")
num_seg = recording.get_num_segments()

Expand All @@ -71,11 +80,9 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see

if mode in ["noise"]:
if noise_levels is None:
random_slices_kwargs = random_chunk_kwargs.copy()
random_slices_kwargs["seed"] = seed
noise_levels = get_noise_levels(
recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs
)
noise_levels_kwargs["return_in_uV"] = False
noise_levels_kwargs["seed"] = seed
noise_levels = get_noise_levels(recording, **noise_levels_kwargs)
noise_generator = NoiseGeneratorRecording(
num_channels=recording.get_num_channels(),
sampling_frequency=recording.sampling_frequency,
Expand All @@ -97,8 +104,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see
rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index)
self.add_recording_segment(rec_segment)

self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed)
self._kwargs.update(random_chunk_kwargs)
self._kwargs = dict(
recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels
)
self._kwargs.update(noise_levels_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't play with the _kwargs. Will this update lead to incompatible keys between spikeinterface versions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is a good question, and we should discuss this PR with @samuelgarcia . Indeed, moving the computation of noise levels into the detect_peaks() and not in the node themselves is the only option to control finely the job_kwargs. Not sure we need to save the noise_level_kwargs because the noise_levels are cached per recording I think, and not recomputed during parallel processing. I'll double check with @samuelgarcia



class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):
Expand Down
13 changes: 12 additions & 1 deletion src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ def detect_peaks(
method_kwargs, job_kwargs = split_job_kwargs(kwargs)
job_kwargs["mp_context"] = method_class.preferred_mp_context

if method_class.need_noise_levels:
random_chunk_kwargs = method_kwargs.pop("random_chunk_kwargs", {})
method_kwargs["noise_levels"] = get_noise_levels(
recording, return_in_uV=False, **random_chunk_kwargs, **job_kwargs
)

node0 = method_class(recording, **method_kwargs)
nodes = [node0]

Expand Down Expand Up @@ -384,6 +390,7 @@ class DetectPeakByChannel(PeakDetectorWrapper):

name = "by_channel"
engine = "numpy"
need_noise_levels = True
preferred_mp_context = None
params_doc = """
peak_sign: "neg" | "pos" | "both", default: "neg"
Expand Down Expand Up @@ -466,6 +473,7 @@ class DetectPeakByChannelTorch(PeakDetectorWrapper):

name = "by_channel_torch"
engine = "torch"
need_noise_levels = True
preferred_mp_context = "spawn"
params_doc = """
peak_sign: "neg" | "pos" | "both", default: "neg"
Expand Down Expand Up @@ -509,7 +517,6 @@ def check_params(
assert peak_sign in ("both", "neg", "pos")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

if noise_levels is None:
noise_levels = get_noise_levels(recording, return_in_uV=False, **random_chunk_kwargs)
abs_thresholds = noise_levels * detect_threshold
Expand Down Expand Up @@ -538,6 +545,7 @@ class DetectPeakLocallyExclusive(PeakDetectorWrapper):

name = "locally_exclusive"
engine = "numba"
need_noise_levels = True
preferred_mp_context = None
params_doc = (
DetectPeakByChannel.params_doc
Expand Down Expand Up @@ -633,6 +641,7 @@ class DetectPeakMatchedFiltering(PeakDetector):

name = "matched_filtering"
engine = "numba"
need_noise_levels = False
preferred_mp_context = None
params_doc = (
DetectPeakByChannel.params_doc
Expand Down Expand Up @@ -780,6 +789,7 @@ class DetectPeakLocallyExclusiveTorch(PeakDetectorWrapper):

name = "locally_exclusive_torch"
engine = "torch"
need_noise_levels = True
preferred_mp_context = "spawn"
params_doc = (
DetectPeakByChannel.params_doc
Expand Down Expand Up @@ -1069,6 +1079,7 @@ def _torch_detect_peaks(traces, peak_sign, abs_thresholds, exclude_sweep_size=5,
class DetectPeakLocallyExclusiveOpenCL(PeakDetectorWrapper):
name = "locally_exclusive_cl"
engine = "opencl"
need_noise_levels = True
preferred_mp_context = None
params_doc = (
DetectPeakLocallyExclusive.params_doc
Expand Down
Loading