Skip to content
53 changes: 36 additions & 17 deletions elephant/signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np
import quantities as pq
import scipy.signal
from typing import Union

from elephant.utils import check_same_units

Expand All @@ -38,7 +39,9 @@
]


def zscore(signal, inplace=True):
def zscore(signal: Union[neo.AnalogSignal, list[neo.AnalogSignal]],
inplace: bool = True
) -> Union[neo.AnalogSignal, list[neo.AnalogSignal]]:
r"""
Apply a z-score operation to one or several `neo.AnalogSignal` objects.

Expand Down Expand Up @@ -72,7 +75,7 @@ def zscore(signal, inplace=True):

Returns
-------
signal_ztransformed : neo.AnalogSignal or list of neo.AnalogSignal
out : neo.AnalogSignal or list of neo.AnalogSignal
The output format matches the input format: for each input
`neo.AnalogSignal`, a corresponding `neo.AnalogSignal` is returned,
containing the z-transformed signal with dimensionless unit.
Expand Down Expand Up @@ -195,8 +198,11 @@ def zscore(signal, inplace=True):
return signal_ztransformed


def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False,
n_lags=None, scaleopt='unbiased'):
def cross_correlation_function(signal: neo.AnalogSignal,
channel_pairs: Union[list, np.ndarray],
hilbert_envelope: bool = False,
n_lags: Union[int, None] = None,
scaleopt: str = 'unbiased') -> neo.AnalogSignal:
r"""
Computes an estimator of the cross-correlation function
:cite:`signal-Stoica2005`.
Expand Down Expand Up @@ -268,7 +274,7 @@ def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False,

Returns
-------
cross_corr : neo.AnalogSignal
out : neo.AnalogSignal
Shape: `[2*n_lags+1, n]`
Pairwise cross-correlation functions for channel pairs given by
`channel_pairs`. If `hilbert_envelope` is True, the output is the
Expand Down Expand Up @@ -387,8 +393,13 @@ def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False,
return cross_corr


def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4,
filter_function='filtfilt', sampling_frequency=1.0, axis=-1):
def butter(signal: Union[neo.AnalogSignal, pq.Quantity, np.ndarray],
highpass_frequency: Union[pq.Quantity, float, None] = None,
lowpass_frequency: Union[pq.Quantity, float, None] = None,
order: int = 4,
filter_function: str = 'filtfilt',
sampling_frequency: Union[pq.Quantity, float] = 1.0,
axis: int = -1) -> Union[neo.AnalogSignal, pq.Quantity, np.ndarray]:
"""
Butterworth filtering function for `neo.AnalogSignal`.

Expand Down Expand Up @@ -446,7 +457,7 @@ def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4,

Returns
-------
filtered_signal : neo.AnalogSignal or pq.Quantity or np.ndarray
out : neo.AnalogSignal or pq.Quantity or np.ndarray
Filtered input data. The shape and type is identical to those of the
input `signal`.

Expand Down Expand Up @@ -558,8 +569,11 @@ def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4,
return filtered_data


def wavelet_transform(signal, frequency, n_cycles=6.0, sampling_frequency=1.0,
zero_padding=True):
def wavelet_transform(signal: Union[neo.AnalogSignal, np.ndarray, list],
frequency: Union[float, list[float]],
n_cycles: float = 6.0,
sampling_frequency: Union[float, pq.Quantity] = 1.0,
zero_padding: bool = True) -> np.ndarray:
r"""
Compute the wavelet transform of a given signal with Morlet mother
wavelet. The parametrization of the wavelet is based on
Expand Down Expand Up @@ -600,7 +614,7 @@ def wavelet_transform(signal, frequency, n_cycles=6.0, sampling_frequency=1.0,

Returns
-------
signal_wt : np.ndarray
out : np.ndarray
Wavelet transform of the input data. When `frequency` was given as a
list, the way how the wavelet transforms for different frequencies are
returned depends on the input type:
Expand Down Expand Up @@ -729,7 +743,8 @@ def _morlet_wavelet_ft(freq, n_cycles, fs, n):
return signal_wt


def hilbert(signal, padding='nextpow'):
def hilbert(signal: neo.AnalogSignal,
padding: Union[str, int, None] = 'nextpow') -> neo.AnalogSignal:
"""
Apply a Hilbert transform to a `neo.AnalogSignal` object in order to
obtain its (complex) analytic signal.
Expand Down Expand Up @@ -760,7 +775,7 @@ def hilbert(signal, padding='nextpow'):

Returns
-------
neo.AnalogSignal
out : neo.AnalogSignal
Contains the complex analytic signal(s) corresponding to the input
`signal`. The unit of the returned `neo.AnalogSignal` is
dimensionless.
Expand Down Expand Up @@ -837,7 +852,11 @@ def hilbert(signal, padding='nextpow'):
return output / output.units


def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
def rauc(signal: neo.AnalogSignal,
baseline: Union[pq.Quantity, str, None] = None,
bin_duration: Union[pq.Quantity, None] = None,
t_start: Union[pq.Quantity, None] = None,
t_stop: Union[pq.Quantity, None] = None) -> Union[pq.Quantity, neo.AnalogSignal]:
"""
Calculate the rectified area under the curve (RAUC) for a
`neo.AnalogSignal`.
Expand Down Expand Up @@ -883,7 +902,7 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):

Returns
-------
pq.Quantity or neo.AnalogSignal
out : pq.Quantity or neo.AnalogSignal
If the number of bins is 1, the returned object is a scalar or
vector `pq.Quantity` containing a single RAUC value for each channel.
Otherwise, the returned object is a `neo.AnalogSignal` containing the
Expand Down Expand Up @@ -977,7 +996,7 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
return rauc_sig


def derivative(signal):
def derivative(signal: neo.AnalogSignal) -> neo.AnalogSignal:
"""
Calculate the derivative of a `neo.AnalogSignal`.

Expand All @@ -989,7 +1008,7 @@ def derivative(signal):

Returns
-------
derivative_sig : neo.AnalogSignal
out : neo.AnalogSignal
The returned object is a `neo.AnalogSignal` containing the differences
between each successive sample value of the input signal divided by
the sampling period. Times are centered between the successive samples
Expand Down
Loading