Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class FilterRecording(BasePreprocessor):
btype : "bandpass" | "highpass", default: "bandpass"
Type of the filter
margin_ms : float, default: 5.0
Margin in ms on border to avoid border effect
Margin in ms on border to avoid border effect.
coeff : array | None, default: None
Filter coefficients in the filter_mode form.
dtype : dtype or None, default: None
Expand Down Expand Up @@ -78,7 +78,8 @@ def __init__(
filter_order=5,
ftype="butter",
filter_mode="sos",
margin_ms=5.0,
margin_ms=5,
max_margin_s=5,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I like this concept of max margin...
We really should prevent users to use spikeinterface for lfp low frequency purpose.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to use it! For wideband signals, there is no way around. You have to extract LFP on chunks. Extending the margin will minimize the error and the max margin is a tradeoff between accuracy and performance.

add_reflect_padding=False,
coeff=None,
dtype=None,
Expand Down Expand Up @@ -138,6 +139,22 @@ def __init__(
direction=direction,
)

def adjust_margin_ms_for_highpass(self, freq_min, max_margin_s):
# compute margin as 3 times the period of the highpass cutoff
margin_ms = 3 * (1000.0 / freq_min)
# limit max margin
max_margin_ms = max_margin_s * 1000.0
if margin_ms > max_margin_ms:
margin_ms = max_margin_ms
return margin_ms

def adjust_margin_ms_for_notch(self, max_margin_s, q, f0):
margin_ms = (3 / np.pi) * (q / f0) * 1000.0
max_margin_ms = max_margin_s * 1000.0
if margin_ms < max_margin_ms:
margin_ms = max_margin_ms
return margin_ms


class FilterRecordingSegment(BasePreprocessorSegment):
def __init__(
Expand Down Expand Up @@ -217,8 +234,11 @@ class BandpassFilterRecording(FilterRecording):
The highpass cutoff frequency in Hz
freq_max : float
The lowpass cutoff frequency in Hz
margin_ms : float
Margin in ms on border to avoid border effect
margin_ms : float | str, default: "auto"
Margin in ms on border to avoid border effect.
If "auto", margin is computed as 3 times the filter highpass cutoff period.
max_margin_s : float, default: 5
Maximum margin in seconds when margin_ms is set to "auto".
dtype : dtype or None
The dtype of the returned traces. If None, the dtype of the parent recording is used
{}
Expand All @@ -229,7 +249,11 @@ class BandpassFilterRecording(FilterRecording):
The bandpass-filtered recording extractor object
"""

def __init__(self, recording, freq_min=300.0, freq_max=6000.0, margin_ms=5.0, dtype=None, **filter_kwargs):
def __init__(
self, recording, freq_min=300.0, freq_max=6000.0, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs
):
if margin_ms == "auto":
margin_ms = self.adjust_margin_ms_for_highpass(freq_min, max_margin_s)
FilterRecording.__init__(
self, recording, band=[freq_min, freq_max], margin_ms=margin_ms, dtype=dtype, **filter_kwargs
)
Expand All @@ -250,8 +274,11 @@ class HighpassFilterRecording(FilterRecording):
The recording extractor to be re-referenced
freq_min : float
The highpass cutoff frequency in Hz
margin_ms : float
Margin in ms on border to avoid border effect
margin_ms : float | str, default: "auto"
Margin in ms on border to avoid border effect.
If "auto", margin is computed as 3 times the filter highpass cutoff period.
max_margin_s : float, default: 5
Maximum margin in seconds when margin_ms is set to "auto".
dtype : dtype or None
The dtype of the returned traces. If None, the dtype of the parent recording is used
{}
Expand All @@ -262,7 +289,9 @@ class HighpassFilterRecording(FilterRecording):
The highpass-filtered recording extractor object
"""

def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filter_kwargs):
def __init__(self, recording, freq_min=300.0, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs):
if margin_ms == "auto":
margin_ms = self.adjust_margin_ms_for_highpass(freq_min, max_margin_s)
FilterRecording.__init__(
self, recording, band=freq_min, margin_ms=margin_ms, dtype=dtype, btype="highpass", **filter_kwargs
)
Expand All @@ -271,7 +300,7 @@ def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filte
self._kwargs.update(filter_kwargs)


class NotchFilterRecording(BasePreprocessor):
class NotchFilterRecording(FilterRecording):
"""
Parameters
----------
Expand All @@ -283,25 +312,27 @@ class NotchFilterRecording(BasePreprocessor):
The quality factor of the notch filter
dtype : None | dtype, default: None
dtype of recording. If None, will take from `recording`
margin_ms : float, default: 5.0
margin_ms : float | str, default: "auto"
Margin in ms on border to avoid border effect
max_margin_s : float, default: 5
Maximum margin in seconds when margin_ms is set to "auto".

Returns
-------
filter_recording : NotchFilterRecording
The notch-filtered recording extractor object
"""

def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
# coeef is 'ba' type
fn = 0.5 * float(recording.get_sampling_frequency())
def __init__(self, recording, freq=3000, q=30, margin_ms="auto", max_margin_s=5, dtype=None, **filter_kwargs):
import scipy.signal

if margin_ms == "auto":
margin_ms = self.adjust_margin_ms_for_notch(max_margin_s, q, freq)

fn = 0.5 * float(recording.get_sampling_frequency())
coeff = scipy.signal.iirnotch(freq / fn, q)

if dtype is None:
dtype = recording.get_dtype()
dtype = np.dtype(dtype)
dtype = fix_dtype(recording, dtype)

# if uint --> unsupported
if dtype.kind == "u":
Expand All @@ -310,15 +341,12 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
"to specify a signed type (e.g. 'int16', 'float32')"
)

BasePreprocessor.__init__(self, recording, dtype=dtype)
FilterRecording.__init__(
self, recording, coeff=coeff, filter_mode="ba", margin_ms=margin_ms, dtype=dtype, **filter_kwargs
)
self.annotate(is_filtered=True)

sf = recording.get_sampling_frequency()
margin = int(margin_ms * sf / 1000.0)
for parent_segment in recording._recording_segments:
self.add_recording_segment(FilterRecordingSegment(parent_segment, coeff, "ba", margin, dtype))

self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)
self._kwargs.update(filter_kwargs)


# functions for API
Expand Down