Skip to content

Commit 6a5695e

Browse files
committed
save_properties of a filt
1 parent 17593b7 commit 6a5695e

File tree

7 files changed

+187
-65
lines changed

7 files changed

+187
-65
lines changed

neurodsp/filt/checks.py

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Checker functions for filtering."""
22

33
from warnings import warn
4+
import os
5+
import json
46

57
import numpy as np
68

@@ -89,8 +91,8 @@ def check_filter_definition(pass_type, f_range):
8991
return f_lo, f_hi
9092

9193

92-
def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
93-
transitions=(-20, -3), filt_type=None, verbose=True):
94+
def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range, transitions=(-20, -3),
95+
filt_type=None, verbose=True, save_properties=None):
9496
"""Check a filters properties, including pass band and transition band.
9597
9698
Parameters
@@ -117,8 +119,12 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
117119
a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'.
118120
transitions : tuple of (float, float), optional, default: (-20, -3)
119121
Cutoffs, in dB, that define the transition band.
122+
filt_type : str, optional, {'FIR', 'IIR'}
123+
The type of filter being applied.
120124
verbose : bool, optional, default: True
121-
Whether to print out transition and pass bands.
125+
Whether to print out filter properties.
126+
save_properties : str
127+
Path, including file name, to save filter properites to as a json.
122128
123129
Returns
124130
-------
@@ -138,8 +144,8 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
138144
"""
139145

140146
# Import utility functions inside function to avoid circular imports
141-
from neurodsp.filt.utils import (compute_frequency_response,
142-
compute_pass_band, compute_transition_band)
147+
from neurodsp.filt.utils import (compute_frequency_response, compute_pass_band,
148+
compute_transition_band, gen_filt_report, save_filt_report)
143149

144150
# Initialize variable to keep track if all checks pass
145151
passes = True
@@ -173,52 +179,16 @@ def check_filter_properties(filter_coefs, a_vals, fs, pass_type, f_range,
173179
warn('Transition bandwidth is {:.1f} Hz. This is greater than the desired'\
174180
'pass/stop bandwidth of {:.1f} Hz'.format(transition_bw, pass_bw))
175181

176-
# Print out transition bandwidth and pass bandwidth to the user
182+
# Report filter properties
183+
if verbose or save_properties:
184+
filt_report = gen_filt_report(pass_type, filt_type, fs, f_db, db, pass_bw,
185+
transition_bw, f_range, f_range_trans)
186+
177187
if verbose:
188+
print('\n'.join('{} : {}'.format(key, value) for key, value in filt_report.items()))
178189

179-
# Filter type (high-pass, low-pass, band-pass, band-stop, FIR, IIR)
180-
print('Pass Type: {pass_type}'.format(pass_type=pass_type))
181-
182-
# Cutoff frequency (including definition)
183-
cutoff = round(np.min(f_range) + (0.5 * transition_bw), 3)
184-
print('Cutoff (half-amplitude): {cutoff} Hz'.format(cutoff=cutoff))
185-
186-
# Filter order (or length)
187-
print('Filter order: {order}'.format(order=len(f_db)-1))
188-
189-
# Roll-off or transition bandwidth
190-
print('Transition bandwidth: {:.1f} Hz'.format(transition_bw))
191-
print('Pass/stop bandwidth: {:.1f} Hz'.format(pass_bw))
192-
193-
# Passband ripple and stopband attenuation
194-
pb_ripple = np.max(db[:np.where(f_db < f_range_trans[0])[0][-1]])
195-
sb_atten = np.max(db[np.where(f_db > f_range_trans[1])[0][0]:])
196-
print('Passband Ripple: {pb_ripple} db'.format(pb_ripple=pb_ripple))
197-
print('Stopband Attenuation: {sb_atten} db'.format(sb_atten=sb_atten))
198-
199-
# Filter delay (zero-phase, linear-phase, non-linear phase)
200-
if filt_type == 'FIR' and pass_type in ['bandstop', 'lowpass']:
201-
filt_class = 'linear-phase'
202-
elif filt_type == 'FIR' and pass_type in ['bandpass', 'highpass']:
203-
filt_class = 'zero-phase'
204-
elif filt_type == 'IIR':
205-
filt_class = 'non-linear-phase'
206-
else:
207-
filt_class = None
208-
209-
if filt_type is not None:
210-
print('Filter Class: {filt_class}'.format(filt_class=filt_class))
211-
212-
if filt_class == 'linear-phase':
213-
print('Group Delay: {delay}s'.format(delay=(len(f_db)-1) / 2 * fs))
214-
elif filt_class == 'zero-phase':
215-
print('Group Delay: 0s')
216-
217-
# Direction of computation (one-pass forward/reverse, or two-pass forward and reverse)
218-
if filt_type == 'FIR':
219-
print('Direction: one-pass reverse.')
220-
else:
221-
print('Direction: two-pass forward and reverse')
190+
if save_properties is not None:
191+
save_filt_report(save_properties, filt_report)
222192

223193
return passes
224194

neurodsp/filt/filter.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
###################################################################################################
99
###################################################################################################
1010

11-
def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',
12-
n_cycles=3, n_seconds=None, remove_edges=True, butterworth_order=None,
13-
print_transitions=False, plot_properties=False, return_filter=False,
14-
verbose=False):
11+
def filter_signal(sig, fs, pass_type, f_range, filter_type='fir', n_cycles=3, n_seconds=None,
12+
remove_edges=True, butterworth_order=None, print_transitions=False,
13+
plot_properties=False, save_properties=None, return_filter=False, verbose=False):
1514
"""Apply a bandpass, bandstop, highpass, or lowpass filter to a neural signal.
1615
1716
Parameters
@@ -51,6 +50,8 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',
5150
If True, print out the transition and pass bandwidths.
5251
plot_properties : bool, optional, default: False
5352
If True, plot the properties of the filter, including frequency response and/or kernel.
53+
save_properties : str, optional, default: None
54+
Path, including file name, to save filter properites to as a json.
5455
return_filter : bool, optional, default: False
5556
If True, return the filter coefficients.
5657
verbose : bool, optional, default: False
@@ -76,13 +77,13 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',
7677

7778
if filter_type.lower() == 'fir':
7879
return filter_signal_fir(sig, fs, pass_type, f_range, n_cycles, n_seconds,
79-
remove_edges, print_transitions,
80-
plot_properties, return_filter, verbose=verbose)
80+
remove_edges, print_transitions, plot_properties,
81+
save_properties, return_filter, verbose)
8182
elif filter_type.lower() == 'iir':
8283
_iir_checks(n_seconds, butterworth_order, remove_edges)
8384
return filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order,
84-
print_transitions, plot_properties,
85-
return_filter, verbose=verbose)
85+
print_transitions, plot_properties, save_properties,
86+
return_filter, verbose)
8687
else:
8788
raise ValueError('Filter type not understood.')
8889

neurodsp/filt/fir.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
###################################################################################################
1515

1616
def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, remove_edges=True,
17-
print_transitions=False, plot_properties=False, return_filter=False,
18-
verbose=False):
17+
print_transitions=False, plot_properties=False, save_properties=None,
18+
return_filter=False, verbose=False, file_path=None, file_name=None):
1919
"""Apply an FIR filter to a signal.
2020
2121
Parameters
@@ -47,6 +47,8 @@ def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, r
4747
If True, print out the transition and pass bandwidths.
4848
plot_properties : bool, optional, default: False
4949
If True, plot the properties of the filter, including frequency response and/or kernel.
50+
save_properties : str
51+
Path, including file name, to save filter properites to as a json.
5052
return_filter : bool, optional, default: False
5153
If True, return the filter coefficients of the FIR filter.
5254
verbose : bool, optional, default: False
@@ -82,7 +84,8 @@ def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, r
8284

8385
# Check filter properties: compute transition bandwidth & run checks
8486
check_filter_properties(filter_coefs, 1, fs, pass_type, f_range, filt_type="FIR",
85-
verbose=np.any([print_transitions, verbose]))
87+
verbose=np.any([print_transitions, verbose]),
88+
save_properties=save_properties)
8689

8790
# Remove any NaN on the edges of 'sig'
8891
sig, sig_nans = remove_nans(sig)

neurodsp/filt/iir.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
###################################################################################################
1313

1414
def filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order, print_transitions=False,
15-
plot_properties=False, return_filter=False, verbose=False):
15+
plot_properties=False, save_properties=None, return_filter=False,
16+
verbose=False):
1617
"""Apply an IIR filter to a signal.
1718
1819
Parameters
@@ -40,6 +41,8 @@ def filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order, print_tran
4041
If True, print out the transition and pass bandwidths.
4142
plot_properties : bool, optional, default: False
4243
If True, plot the properties of the filter, including frequency response and/or kernel.
44+
save_properties : str
45+
Path, including file name, to save filter properites to as a json.
4346
return_filter : bool, optional, default: False
4447
If True, return the second order series coefficients of the IIR filter.
4548
verbose : bool, optional, default: False
@@ -69,7 +72,8 @@ def filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order, print_tran
6972

7073
# Check filter properties: compute transition bandwidth & run checks
7174
check_filter_properties(sos, None, fs, pass_type, f_range, filt_type="IIR",
72-
verbose=np.any([print_transitions, verbose]))
75+
verbose=np.any([print_transitions, verbose]),
76+
save_properties=save_properties)
7377

7478
# Remove any NaN on the edges of 'sig'
7579
sig, sig_nans = remove_nans(sig)

neurodsp/filt/utils.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Utility functions for filtering."""
22

3+
import os
4+
import json
5+
36
import numpy as np
47
from scipy.signal import freqz, sosfreqz
58

@@ -253,3 +256,105 @@ def remove_filter_edges(sig, filt_len):
253256
sig[-n_rmv:] = np.nan
254257

255258
return sig
259+
260+
261+
def gen_filt_report(pass_type, filt_type, fs, f_db, db, pass_bw,
262+
transition_bw, f_range, f_range_trans):
263+
"""Create a filter report.
264+
265+
Parameters
266+
----------
267+
pass_type : {'bandpass', 'bandstop', 'lowpass', 'highpass'}
268+
Which type of filter was applied.
269+
filt_type : str, {'FIR', 'IIR'}
270+
The type of filter being applied.
271+
fs : float
272+
Sampling rate, in Hz.
273+
f_db : 1d array
274+
Frequency vector corresponding to attenuation decibels, in Hz.
275+
db : 1d array
276+
Degree of attenuation for each frequency specified in `f_db`, in dB.
277+
pass_bw : float
278+
The pass bandwidth of the filter.
279+
transition_band : float
280+
The transition bandwidth of the filter.
281+
f_range : tuple of (float, float) or float
282+
Cutoff frequency(ies) used for filter, specified as f_lo & f_hi.
283+
f_range_trans : tuple of (float, float)
284+
The lower and upper frequencies of the transition band.
285+
286+
Returns
287+
-------
288+
filt_report : dict
289+
A dicionary of filter parameter keys and corresponding values.
290+
"""
291+
filt_report = {}
292+
293+
# Filter type (high-pass, low-pass, band-pass, band-stop, FIR, IIR)
294+
filt_report['Pass Type'] = '{pass_type}'.format(pass_type=pass_type)
295+
296+
# Cutoff frequenc(ies) (including definition)
297+
filt_report['Cutoff (half-amplitude)'] = '{cutoff} Hz'.format(cutoff=f_range)
298+
299+
# Filter order (or length)
300+
filt_report['Filter order'] = '{order}'.format(order=len(f_db)-1)
301+
302+
# Roll-off or transition bandwidth
303+
filt_report['Transition bandwidth'] = '{:.1f} Hz'.format(transition_bw)
304+
filt_report['Pass/stop bandwidth'] = '{:.1f} Hz'.format(pass_bw)
305+
306+
# Passband ripple and stopband attenuation
307+
pb_ripple = np.max(db[:np.where(f_db < f_range_trans[0])[0][-1]])
308+
sb_atten = np.max(db[np.where(f_db > f_range_trans[1])[0][0]:])
309+
filt_report['Passband Ripple'] = '{pb_ripple} db'.format(pb_ripple=pb_ripple)
310+
filt_report['Stopband Attenuation'] = '{sb_atten} db'.format(sb_atten=sb_atten)
311+
312+
# Filter delay (zero-phase, linear-phase, non-linear phase)
313+
filt_report['Filter Type'] = filt_type
314+
315+
if filt_type == 'FIR' and pass_type in ['bandstop', 'lowpass']:
316+
317+
filt_report['Filter Class'] = '{filt_class}'.format(filt_class='linear-phase')
318+
filt_report['Group Delay'] = '{delay}s'.format(delay=(len(f_db)-1) / 2 * fs)
319+
320+
elif filt_type == 'FIR' and pass_type in ['bandpass', 'highpass']:
321+
322+
filt_report['Filter Class'] = '{filt_class}'.format(filt_class='zero-phase')
323+
filt_report['Group Delay'] = '0s'
324+
325+
elif filt_type == 'IIR':
326+
327+
# Group delay isn't reported for IIR since it varies from sample to sample
328+
filt_report['Filter Class'] = '{filt_class}'.format(filt_class='non-linear-phase')
329+
330+
# Direction of computation (one-pass forward/reverse, or two-pass forward and reverse)
331+
if filt_type == 'FIR':
332+
filt_report['Direction'] = 'one-pass reverse'
333+
else:
334+
filt_report['Direction'] = 'two-pass forward and reverse'
335+
336+
return filt_report
337+
338+
339+
def save_filt_report(save_properties, filt_report):
340+
"""Save filter properties as a json file.
341+
342+
Parameters
343+
----------
344+
save_properties : str
345+
Path, including file name, to save filter properites to as a json.
346+
filt_report : dict
347+
Contains filter report info.
348+
"""
349+
350+
# Ensure parents exists
351+
if not os.path.isdir(os.path.dirname(save_properties)):
352+
raise ValueError("Unable to save properties. Parent directory does not exist.")
353+
354+
# Enforce file extension
355+
if not save_properties.endswith('.json'):
356+
save_properties = save_properties + '.json'
357+
358+
# Save
359+
with open(save_properties, 'w') as file_path:
360+
json.dump(filt_report, file_path)

neurodsp/tests/filt/test_checks.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for filter check functions."""
22

3+
import tempfile
34
from pytest import raises
45

56
from neurodsp.tests.settings import FS
@@ -46,7 +47,7 @@ def test_check_filter_definition():
4647
def test_check_filter_properties():
4748

4849
filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12))
49-
50+
5051
passes = check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12))
5152
assert passes is True
5253

@@ -58,6 +59,10 @@ def test_check_filter_properties():
5859
passes = check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12))
5960
assert passes is False
6061

62+
temp_path = tempfile.NamedTemporaryFile()
63+
check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12),
64+
verbose=True, save_properties=temp_path.name)
65+
temp_path.close()
6166

6267
def test_check_filter_length():
6368

neurodsp/tests/filt/test_utils.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Tests for filter utilities."""
22

3-
from pytest import raises
4-
from neurodsp.tests.settings import FS
3+
import tempfile
4+
from pytest import raises, mark, param
5+
6+
import numpy as np
57

8+
from neurodsp.tests.settings import FS
69
from neurodsp.filt.utils import *
710
from neurodsp.filt.fir import design_fir_filter, compute_filter_length
811

@@ -53,3 +56,34 @@ def test_remove_filter_edges():
5356
assert np.all(np.isnan(dropped_sig[:n_rmv]))
5457
assert np.all(np.isnan(dropped_sig[-n_rmv:]))
5558
assert np.all(~np.isnan(dropped_sig[n_rmv:-n_rmv]))
59+
60+
61+
@mark.parametrize("pass_type", ['bandpass', 'bandstop', 'lowpass', 'highpass'])
62+
@mark.parametrize("filt_type", ['IIR', 'FIR'])
63+
def test_gen_filt_report(pass_type, filt_type):
64+
65+
fs = 1000
66+
f_db = np.arange(0, 50)
67+
db = np.random.rand(50)
68+
pass_bw = 10
69+
transition_bw = 4
70+
f_range = (10, 40)
71+
f_range_trans = (40, 44)
72+
73+
report = gen_filt_report(pass_type, filt_type, fs, f_db, db, pass_bw,
74+
transition_bw, f_range, f_range_trans)
75+
76+
assert pass_type in report.values()
77+
assert filt_type in report.values()
78+
79+
80+
@mark.parametrize("dir_exists", [True, param(False, marks=mark.xfail)])
81+
def test_save_filt_report(dir_exists):
82+
83+
filt_report = {'Pass Type': 'bandpass', 'Cutoff (half-amplitude)': 50}
84+
temp_path = tempfile.NamedTemporaryFile()
85+
if not dir_exists:
86+
save_filt_report('/bad/path/', filt_report)
87+
else:
88+
save_filt_report(temp_path.name, filt_report)
89+
temp_path.close()

0 commit comments

Comments
 (0)