diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 78542e1f37..38485aff4a 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -43,7 +43,7 @@ class FilterRecording(BasePreprocessor): btype : "bandpass" | "highpass", default: "bandpass" Type of the filter margin_ms : float, default: 5.0 - Margin in ms on border to avoid border effect + Margin in ms on border to avoid border effect. coeff : array | None, default: None Filter coefficients in the filter_mode form. dtype : dtype or None, default: None @@ -78,7 +78,8 @@ def __init__( filter_order=5, ftype="butter", filter_mode="sos", - margin_ms=5.0, + margin_ms=5, + max_margin_s=5, add_reflect_padding=False, coeff=None, dtype=None, @@ -138,6 +139,22 @@ def __init__( direction=direction, ) + def adjust_margin_ms_for_highpass(self, freq_min, max_margin_s): + # compute margin as 3 times the period of the highpass cutoff + margin_ms = 3 * (1000.0 / freq_min) + # limit max margin + max_margin_ms = max_margin_s * 1000.0 + if margin_ms > max_margin_ms: + margin_ms = max_margin_ms + return margin_ms + + def adjust_margin_ms_for_notch(self, max_margin_s, q, f0): + margin_ms = (3 / np.pi) * (q / f0) * 1000.0 + max_margin_ms = max_margin_s * 1000.0 + if margin_ms < max_margin_ms: + margin_ms = max_margin_ms + return margin_ms + class FilterRecordingSegment(BasePreprocessorSegment): def __init__( @@ -217,8 +234,11 @@ class BandpassFilterRecording(FilterRecording): The highpass cutoff frequency in Hz freq_max : float The lowpass cutoff frequency in Hz - margin_ms : float - Margin in ms on border to avoid border effect + margin_ms : float | str, default: "auto" + Margin in ms on border to avoid border effect. + If "auto", margin is computed as 3 times the filter highpass cutoff period. + max_margin_s : float, default: 5 + Maximum margin in seconds when margin_ms is set to "auto". dtype : dtype or None The dtype of the returned traces. If None, the dtype of the parent recording is used {} @@ -229,7 +249,11 @@ class BandpassFilterRecording(FilterRecording): The bandpass-filtered recording extractor object """ - def __init__(self, recording, freq_min=300.0, freq_max=6000.0, margin_ms=5.0, dtype=None, **filter_kwargs): + def __init__( + self, recording, freq_min=300.0, freq_max=6000.0, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs + ): + if margin_ms == "auto": + margin_ms = self.adjust_margin_ms_for_highpass(freq_min, max_margin_s) FilterRecording.__init__( self, recording, band=[freq_min, freq_max], margin_ms=margin_ms, dtype=dtype, **filter_kwargs ) @@ -250,8 +274,11 @@ class HighpassFilterRecording(FilterRecording): The recording extractor to be re-referenced freq_min : float The highpass cutoff frequency in Hz - margin_ms : float - Margin in ms on border to avoid border effect + margin_ms : float | str, default: "auto" + Margin in ms on border to avoid border effect. + If "auto", margin is computed as 3 times the filter highpass cutoff period. + max_margin_s : float, default: 5 + Maximum margin in seconds when margin_ms is set to "auto". dtype : dtype or None The dtype of the returned traces. If None, the dtype of the parent recording is used {} @@ -262,7 +289,9 @@ class HighpassFilterRecording(FilterRecording): The highpass-filtered recording extractor object """ - def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filter_kwargs): + def __init__(self, recording, freq_min=300.0, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs): + if margin_ms == "auto": + margin_ms = self.adjust_margin_ms_for_highpass(freq_min, max_margin_s) FilterRecording.__init__( self, recording, band=freq_min, margin_ms=margin_ms, dtype=dtype, btype="highpass", **filter_kwargs ) @@ -271,7 +300,7 @@ def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filte self._kwargs.update(filter_kwargs) -class NotchFilterRecording(BasePreprocessor): +class NotchFilterRecording(FilterRecording): """ Parameters ---------- @@ -283,8 +312,10 @@ class NotchFilterRecording(BasePreprocessor): The quality factor of the notch filter dtype : None | dtype, default: None dtype of recording. If None, will take from `recording` - margin_ms : float, default: 5.0 + margin_ms : float | str, default: "auto" Margin in ms on border to avoid border effect + max_margin_s : float, default: 5 + Maximum margin in seconds when margin_ms is set to "auto". Returns ------- @@ -292,16 +323,16 @@ class NotchFilterRecording(BasePreprocessor): The notch-filtered recording extractor object """ - def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): - # coeef is 'ba' type - fn = 0.5 * float(recording.get_sampling_frequency()) + def __init__(self, recording, freq=3000, q=30, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs): import scipy.signal + if margin_ms == "auto": + margin_ms = self.adjust_margin_ms_for_notch(max_margin_s, q, freq) + + fn = 0.5 * float(recording.get_sampling_frequency()) coeff = scipy.signal.iirnotch(freq / fn, q) - if dtype is None: - dtype = recording.get_dtype() - dtype = np.dtype(dtype) + dtype = fix_dtype(recording, dtype) # if uint --> unsupported if dtype.kind == "u": @@ -310,15 +341,12 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): "to specify a signed type (e.g. 'int16', 'float32')" ) - BasePreprocessor.__init__(self, recording, dtype=dtype) + FilterRecording.__init__( + self, recording, coeff=coeff, filter_mode="ba", margin_ms=margin_ms, dtype=dtype, **filter_kwargs + ) self.annotate(is_filtered=True) - - sf = recording.get_sampling_frequency() - margin = int(margin_ms * sf / 1000.0) - for parent_segment in recording._recording_segments: - self.add_recording_segment(FilterRecordingSegment(parent_segment, coeff, "ba", margin, dtype)) - self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str) + self._kwargs.update(filter_kwargs) # functions for API