11"""Checker functions for filtering."""
22
33from warnings import warn
4+ import os
5+ import json
46
57import numpy as np
68
@@ -89,8 +91,8 @@ def check_filter_definition(pass_type, f_range):
8991 return f_lo , f_hi
9092
9193
92- def check_filter_properties (filter_coefs , a_vals , fs , pass_type , f_range ,
93- transitions = ( - 20 , - 3 ), filt_type = None , verbose = True ):
94+ def check_filter_properties (filter_coefs , a_vals , fs , pass_type , f_range , transitions = ( - 20 , - 3 ),
95+ filt_type = None , verbose = True , save_properties = None ):
9496 """Check a filters properties, including pass band and transition band.
9597
9698 Parameters
@@ -117,8 +119,12 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
117119 a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'.
118120 transitions : tuple of (float, float), optional, default: (-20, -3)
119121 Cutoffs, in dB, that define the transition band.
122+ filt_type : str, optional, {'FIR', 'IIR'}
123+ The type of filter being applied.
120124 verbose : bool, optional, default: True
121- Whether to print out transition and pass bands.
125+ Whether to print out filter properties.
126+ save_properties : str
127+ Path, including file name, to save filter properites to as a json.
122128
123129 Returns
124130 -------
@@ -138,8 +144,8 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
138144 """
139145
140146 # Import utility functions inside function to avoid circular imports
141- from neurodsp .filt .utils import (compute_frequency_response ,
142- compute_pass_band , compute_transition_band )
147+ from neurodsp .filt .utils import (compute_frequency_response , compute_pass_band ,
148+ compute_transition_band , gen_filt_report , save_filt_report )
143149
144150 # Initialize variable to keep track if all checks pass
145151 passes = True
@@ -173,52 +179,16 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
173179 warn ('Transition bandwidth is {:.1f} Hz. This is greater than the desired' \
174180 'pass/stop bandwidth of {:.1f} Hz' .format (transition_bw , pass_bw ))
175181
176- # Print out transition bandwidth and pass bandwidth to the user
182+ # Report filter properties
183+ if verbose or save_properties :
184+ filt_report = gen_filt_report (pass_type , filt_type , fs , f_db , db , pass_bw ,
185+ transition_bw , f_range , f_range_trans )
186+
177187 if verbose :
188+ print ('\n ' .join ('{} : {}' .format (key , value ) for key , value in filt_report .items ()))
178189
179- # Filter type (high-pass, low-pass, band-pass, band-stop, FIR, IIR)
180- print ('Pass Type: {pass_type}' .format (pass_type = pass_type ))
181-
182- # Cutoff frequency (including definition)
183- cutoff = round (np .min (f_range ) + (0.5 * transition_bw ), 3 )
184- print ('Cutoff (half-amplitude): {cutoff} Hz' .format (cutoff = cutoff ))
185-
186- # Filter order (or length)
187- print ('Filter order: {order}' .format (order = len (f_db )- 1 ))
188-
189- # Roll-off or transition bandwidth
190- print ('Transition bandwidth: {:.1f} Hz' .format (transition_bw ))
191- print ('Pass/stop bandwidth: {:.1f} Hz' .format (pass_bw ))
192-
193- # Passband ripple and stopband attenuation
194- pb_ripple = np .max (db [:np .where (f_db < f_range_trans [0 ])[0 ][- 1 ]])
195- sb_atten = np .max (db [np .where (f_db > f_range_trans [1 ])[0 ][0 ]:])
196- print ('Passband Ripple: {pb_ripple} db' .format (pb_ripple = pb_ripple ))
197- print ('Stopband Attenuation: {sb_atten} db' .format (sb_atten = sb_atten ))
198-
199- # Filter delay (zero-phase, linear-phase, non-linear phase)
200- if filt_type == 'FIR' and pass_type in ['bandstop' , 'lowpass' ]:
201- filt_class = 'linear-phase'
202- elif filt_type == 'FIR' and pass_type in ['bandpass' , 'highpass' ]:
203- filt_class = 'zero-phase'
204- elif filt_type == 'IIR' :
205- filt_class = 'non-linear-phase'
206- else :
207- filt_class = None
208-
209- if filt_type is not None :
210- print ('Filter Class: {filt_class}' .format (filt_class = filt_class ))
211-
212- if filt_class == 'linear-phase' :
213- print ('Group Delay: {delay}s' .format (delay = (len (f_db )- 1 ) / 2 * fs ))
214- elif filt_class == 'zero-phase' :
215- print ('Group Delay: 0s' )
216-
217- # Direction of computation (one-pass forward/reverse, or two-pass forward and reverse)
218- if filt_type == 'FIR' :
219- print ('Direction: one-pass reverse.' )
220- else :
221- print ('Direction: two-pass forward and reverse' )
190+ if save_properties is not None :
191+ save_filt_report (save_properties , filt_report )
222192
223193 return passes
224194
0 commit comments