diff --git a/elephant/signal_processing.py b/elephant/signal_processing.py index 9a901b41a..4153d97e1 100644 --- a/elephant/signal_processing.py +++ b/elephant/signal_processing.py @@ -24,6 +24,7 @@ import numpy as np import quantities as pq import scipy.signal +from typing import Union from elephant.utils import check_same_units @@ -38,7 +39,9 @@ ] -def zscore(signal, inplace=True): +def zscore(signal: Union[neo.AnalogSignal, list[neo.AnalogSignal]], + inplace: bool = True + ) -> Union[neo.AnalogSignal, list[neo.AnalogSignal]]: r""" Apply a z-score operation to one or several `neo.AnalogSignal` objects. @@ -72,7 +75,7 @@ def zscore(signal, inplace=True): Returns ------- - signal_ztransformed : neo.AnalogSignal or list of neo.AnalogSignal + out : neo.AnalogSignal or list of neo.AnalogSignal The output format matches the input format: for each input `neo.AnalogSignal`, a corresponding `neo.AnalogSignal` is returned, containing the z-transformed signal with dimensionless unit. @@ -195,8 +198,11 @@ def zscore(signal, inplace=True): return signal_ztransformed -def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False, - n_lags=None, scaleopt='unbiased'): +def cross_correlation_function(signal: neo.AnalogSignal, + channel_pairs: Union[list, np.ndarray], + hilbert_envelope: bool = False, + n_lags: Union[int, None] = None, + scaleopt: str = 'unbiased') -> neo.AnalogSignal: r""" Computes an estimator of the cross-correlation function :cite:`signal-Stoica2005`. @@ -268,7 +274,7 @@ def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False, Returns ------- - cross_corr : neo.AnalogSignal + out : neo.AnalogSignal Shape: `[2*n_lags+1, n]` Pairwise cross-correlation functions for channel pairs given by `channel_pairs`. If `hilbert_envelope` is True, the output is the @@ -387,8 +393,13 @@ def cross_correlation_function(signal, channel_pairs, hilbert_envelope=False, return cross_corr -def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4, - filter_function='filtfilt', sampling_frequency=1.0, axis=-1): +def butter(signal: Union[neo.AnalogSignal, pq.Quantity, np.ndarray], + highpass_frequency: Union[pq.Quantity, float, None] = None, + lowpass_frequency: Union[pq.Quantity, float, None] = None, + order: int = 4, + filter_function: str = 'filtfilt', + sampling_frequency: Union[pq.Quantity, float] = 1.0, + axis: int = -1) -> Union[neo.AnalogSignal, pq.Quantity, np.ndarray]: """ Butterworth filtering function for `neo.AnalogSignal`. @@ -446,7 +457,7 @@ def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4, Returns ------- - filtered_signal : neo.AnalogSignal or pq.Quantity or np.ndarray + out : neo.AnalogSignal or pq.Quantity or np.ndarray Filtered input data. The shape and type is identical to those of the input `signal`. @@ -558,8 +569,11 @@ def butter(signal, highpass_frequency=None, lowpass_frequency=None, order=4, return filtered_data -def wavelet_transform(signal, frequency, n_cycles=6.0, sampling_frequency=1.0, - zero_padding=True): +def wavelet_transform(signal: Union[neo.AnalogSignal, np.ndarray, list], + frequency: Union[float, list[float]], + n_cycles: float = 6.0, + sampling_frequency: Union[float, pq.Quantity] = 1.0, + zero_padding: bool = True) -> np.ndarray: r""" Compute the wavelet transform of a given signal with Morlet mother wavelet. The parametrization of the wavelet is based on @@ -600,7 +614,7 @@ def wavelet_transform(signal, frequency, n_cycles=6.0, sampling_frequency=1.0, Returns ------- - signal_wt : np.ndarray + out : np.ndarray Wavelet transform of the input data. When `frequency` was given as a list, the way how the wavelet transforms for different frequencies are returned depends on the input type: @@ -729,7 +743,8 @@ def _morlet_wavelet_ft(freq, n_cycles, fs, n): return signal_wt -def hilbert(signal, padding='nextpow'): +def hilbert(signal: neo.AnalogSignal, + padding: Union[str, int, None] = 'nextpow') -> neo.AnalogSignal: """ Apply a Hilbert transform to a `neo.AnalogSignal` object in order to obtain its (complex) analytic signal. @@ -760,7 +775,7 @@ def hilbert(signal, padding='nextpow'): Returns ------- - neo.AnalogSignal + out : neo.AnalogSignal Contains the complex analytic signal(s) corresponding to the input `signal`. The unit of the returned `neo.AnalogSignal` is dimensionless. @@ -837,7 +852,11 @@ def hilbert(signal, padding='nextpow'): return output / output.units -def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None): +def rauc(signal: neo.AnalogSignal, + baseline: Union[pq.Quantity, str, None] = None, + bin_duration: Union[pq.Quantity, None] = None, + t_start: Union[pq.Quantity, None] = None, + t_stop: Union[pq.Quantity, None] = None) -> Union[pq.Quantity, neo.AnalogSignal]: """ Calculate the rectified area under the curve (RAUC) for a `neo.AnalogSignal`. @@ -883,7 +902,7 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None): Returns ------- - pq.Quantity or neo.AnalogSignal + out : pq.Quantity or neo.AnalogSignal If the number of bins is 1, the returned object is a scalar or vector `pq.Quantity` containing a single RAUC value for each channel. Otherwise, the returned object is a `neo.AnalogSignal` containing the @@ -977,7 +996,7 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None): return rauc_sig -def derivative(signal): +def derivative(signal: neo.AnalogSignal) -> neo.AnalogSignal: """ Calculate the derivative of a `neo.AnalogSignal`. @@ -989,7 +1008,7 @@ def derivative(signal): Returns ------- - derivative_sig : neo.AnalogSignal + out : neo.AnalogSignal The returned object is a `neo.AnalogSignal` containing the differences between each successive sample value of the input signal divided by the sampling period. Times are centered between the successive samples