Skip to content

Commit 6663397

Browse files
committed
added ground truth tests for the models
1 parent b48b3b3 commit 6663397

1 file changed

Lines changed: 219 additions & 0 deletions

File tree

tests/test_models_ground_truth.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import numpy as np
2+
3+
from smstools.models import harmonicModel, hprModel, hpsModel, sprModel, spsModel, stochasticModel, utilFunctions
4+
from smstools.models import dftModel as DFT
5+
from smstools.models import sineModel as SM
6+
7+
8+
FS = 44100
9+
10+
11+
def _sine(freq, length, amp=0.8, phase=0.0, fs=FS):
12+
n = np.arange(length)
13+
return amp * np.sin(2 * np.pi * freq * n / fs + phase)
14+
15+
16+
def _harmonic_stack(f0, length, harmonics=6, fs=FS):
17+
n = np.arange(length)
18+
x = np.zeros(length)
19+
for k in range(1, harmonics + 1):
20+
x += (1.0 / k) * np.sin(2 * np.pi * (k * f0) * n / fs)
21+
return x
22+
23+
24+
def _snr_db(reference, estimate):
25+
error = reference - estimate
26+
num = np.sum(reference**2)
27+
den = np.sum(error**2) + np.finfo(float).eps
28+
return 10.0 * np.log10(num / den)
29+
30+
31+
def test_single_sine_peak_frequency_accuracy():
32+
freq_true = 445.3
33+
M = 2047
34+
N = 8192
35+
x = _sine(freq_true, M)
36+
w = np.hanning(M)
37+
38+
mX, pX = DFT.dftAnal(x, w, N)
39+
ploc = utilFunctions.peakDetection(mX, -120)
40+
iploc, ipmag, ipphase = utilFunctions.peakInterp(mX, pX, ploc)
41+
ipfreq = FS * iploc / float(N)
42+
43+
freq_est = ipfreq[np.argmax(ipmag)]
44+
assert abs(freq_est - freq_true) < 3.0
45+
46+
47+
def test_f0twm_recovers_ground_truth_from_harmonic_candidates():
48+
f0_true = 220.0
49+
pfreq = np.array([220.0, 440.0, 660.0, 880.0, 1100.0])
50+
pmag = np.array([0.0, -6.0, -9.5, -12.0, -14.0])
51+
52+
f0_est = utilFunctions.f0Twm(
53+
pfreq=pfreq,
54+
pmag=pmag,
55+
ef0max=5,
56+
minf0=80,
57+
maxf0=500,
58+
f0t=0,
59+
)
60+
61+
assert abs(f0_est - f0_true) < 1.0
62+
63+
64+
def test_chirp_tracking_has_increasing_frequency_trend():
65+
length = 8192
66+
n = np.arange(length)
67+
f_start = 300.0
68+
f_end = 1200.0
69+
k = (f_end - f_start) / (length - 1)
70+
phase = 2 * np.pi * (f_start * n / FS + 0.5 * k * (n**2) / FS)
71+
x = 0.8 * np.sin(phase)
72+
73+
tfreq, tmag, tphase = SM.sineModelAnal(
74+
x,
75+
fs=FS,
76+
w=np.hanning(1025),
77+
N=2048,
78+
H=128,
79+
t=-80,
80+
maxnSines=25,
81+
minSineDur=0.01,
82+
)
83+
84+
frame_main_freq = []
85+
for frame in range(tfreq.shape[0]):
86+
valid = np.where(tfreq[frame] > 0)[0]
87+
if valid.size == 0:
88+
continue
89+
main_idx = valid[np.argmax(tmag[frame, valid])]
90+
frame_main_freq.append(tfreq[frame, main_idx])
91+
92+
frame_main_freq = np.array(frame_main_freq)
93+
assert frame_main_freq.size > 5
94+
assert frame_main_freq[-1] > frame_main_freq[0]
95+
96+
97+
def test_spr_component_additivity():
98+
x = _harmonic_stack(220.0, length=4096, harmonics=6) + 0.02 * np.random.default_rng(0).standard_normal(4096)
99+
w = np.hanning(513)
100+
101+
y, ys, xr = sprModel.sprModel(x, fs=FS, w=w, N=1024, t=-80)
102+
103+
assert y.shape == x.shape
104+
assert ys.shape == x.shape
105+
assert xr.shape == x.shape
106+
assert np.allclose(y, ys + xr, atol=1e-10)
107+
108+
109+
def test_dft_roundtrip_meets_snr_threshold():
110+
M = 2047
111+
N = 8192
112+
x = _sine(440.0, length=M, amp=0.9) + 0.25 * _sine(880.0, length=M, amp=0.7)
113+
w = np.hanning(M)
114+
115+
mX, pX = DFT.dftAnal(x, w, N)
116+
y = DFT.dftSynth(mX, pX, M)
117+
x_reference = x * (w / np.sum(w))
118+
119+
assert y.shape == x.shape
120+
assert np.isfinite(y).all()
121+
assert _snr_db(x_reference, y) > 60.0
122+
123+
124+
def test_harmonic_detection_recovers_expected_harmonics():
125+
f0 = 220.0
126+
pfreq = np.array([220.0, 440.0, 660.0, 880.0, 1000.0])
127+
pmag = np.array([-3.0, -6.0, -9.0, -12.0, -20.0])
128+
pphase = np.zeros_like(pfreq)
129+
130+
hfreq, hmag, hphase = harmonicModel.harmonicDetection(
131+
pfreq=pfreq,
132+
pmag=pmag,
133+
pphase=pphase,
134+
f0=f0,
135+
nH=4,
136+
hfreqp=np.array([]),
137+
fs=FS,
138+
)
139+
140+
assert np.allclose(hfreq, np.array([220.0, 440.0, 660.0, 880.0]), atol=1.0)
141+
assert hmag.shape == hfreq.shape
142+
assert hphase.shape == hfreq.shape
143+
144+
145+
def test_stochastic_mel_hz_conversion_roundtrip():
146+
freqs = np.array([50.0, 220.0, 440.0, 1000.0, 5000.0])
147+
mels = stochasticModel.hertz_to_mel(freqs)
148+
recon = stochasticModel.mel_to_hetz(mels)
149+
150+
assert np.allclose(freqs, recon, rtol=1e-8, atol=1e-8)
151+
152+
153+
def test_stochastic_analysis_synthesis_produces_valid_signal():
154+
x = np.random.default_rng(42).standard_normal(4096)
155+
stoc_env = stochasticModel.stochasticModelAnal(x, H=128, N=512, stocf=0.5)
156+
y = stochasticModel.stochasticModelSynth(stoc_env, H=128, N=512)
157+
158+
assert stoc_env.ndim == 2
159+
assert y.ndim == 1
160+
assert np.isfinite(stoc_env).all()
161+
assert np.isfinite(y).all()
162+
assert np.std(y) > 0
163+
164+
165+
def test_hpr_component_additivity():
166+
x = _harmonic_stack(220.0, length=4096, harmonics=6)
167+
w = np.hanning(513)
168+
169+
y, yh, xr = hprModel.hprModel(
170+
x,
171+
fs=FS,
172+
w=w,
173+
N=1024,
174+
t=-80,
175+
nH=20,
176+
minf0=50,
177+
maxf0=500,
178+
f0et=5,
179+
)
180+
181+
assert y.shape == x.shape
182+
assert yh.shape == x.shape
183+
assert xr.shape == x.shape
184+
assert np.allclose(y, yh + xr, atol=1e-10)
185+
186+
187+
def test_sps_component_additivity():
188+
x = _harmonic_stack(220.0, length=4096, harmonics=6)
189+
w = np.hanning(513)
190+
191+
y, ys, yst = spsModel.spsModel(x, fs=FS, w=w, N=1024, t=-80, stocf=1)
192+
193+
assert y.shape == x.shape
194+
assert ys.shape == x.shape
195+
assert yst.shape == x.shape
196+
assert np.allclose(y, ys + yst, atol=1e-10)
197+
198+
199+
def test_hps_component_additivity():
200+
x = _harmonic_stack(220.0, length=4096, harmonics=6)
201+
w = np.hanning(513)
202+
203+
y, yh, yst = hpsModel.hpsModel(
204+
x,
205+
fs=FS,
206+
w=w,
207+
N=1024,
208+
t=-80,
209+
nH=20,
210+
minf0=50,
211+
maxf0=500,
212+
f0et=5,
213+
stocf=1,
214+
)
215+
216+
assert y.shape == x.shape
217+
assert yh.shape == x.shape
218+
assert yst.shape == x.shape
219+
assert np.allclose(y, yh + yst, atol=1e-10)

0 commit comments

Comments
 (0)