Skip to content

Commit 488fba2

Browse files
committed
added support for sounds with different sampling rates
1 parent 881d277 commit 488fba2

8 files changed

Lines changed: 397 additions & 54 deletions

File tree

smstools/models/dftModel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def dftAnal(x, w, N):
5959
Analysis of a signal using the discrete Fourier transform
6060
x: input signal, w: analysis window, N: FFT size
6161
returns mX, pX: magnitude and phase spectrum
62+
63+
The analysis window is internally normalized by sum(w), so the resulting
64+
spectra correspond to x * (w / sum(w)).
6265
"""
6366

6467
if not (UF.isPower2(N)): # raise error if N not a power of two
@@ -96,6 +99,9 @@ def dftSynth(mX, pX, M):
9699
Synthesis of a signal using the discrete Fourier transform
97100
mX: magnitude spectrum, pX: phase spectrum, M: window size
98101
returns y: output signal
102+
103+
If mX/pX come from dftAnal(), the output corresponds to the normalized
104+
windowed signal used there (x * (w / sum(w))).
99105
"""
100106

101107
hN = mX.size # size of positive spectrum, it includes sample 0

smstools/models/harmonicModel.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ def f0Detection(x, fs, w, N, H, t, minf0, maxf0, f0et):
2424
if minf0 < 0: # raise exception if minf0 is smaller than 0
2525
raise ValueError("Minumum fundamental frequency (minf0) smaller than 0")
2626

27-
if maxf0 >= 10000: # raise exception if maxf0 is bigger than fs/2
28-
raise ValueError("Maximum fundamental frequency (maxf0) bigger than 10000Hz")
27+
if maxf0 >= fs / 2.0: # raise exception if maxf0 is bigger than Nyquist
28+
raise ValueError(
29+
"Maximum fundamental frequency (maxf0) bigger than Nyquist frequency"
30+
)
2931

3032
if H <= 0: # raise error if hop size 0 or negative
3133
raise ValueError("Hop size (H) smaller or equal to 0")
@@ -44,19 +46,31 @@ def f0Detection(x, fs, w, N, H, t, minf0, maxf0, f0et):
4446
f0 = [] # initialize f0 output
4547
f0t = 0 # initialize f0 track
4648
f0stable = 0 # initialize f0 stable
49+
f0candidate = 0 # initialize one-frame candidate for stability confirmation
4750
while pin < pend:
4851
x1 = x[pin - hM1 : pin + hM2] # select frame
4952
mX, pX = DFT.dftAnal(x1, w, N) # compute dft
5053
ploc = UF.peakDetection(mX, t) # detect peak locations
5154
iploc, ipmag, ipphase = UF.peakInterp(mX, pX, ploc) # refine peak values
5255
ipfreq = fs * iploc / N # convert locations to Hez
53-
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable) # find f0
54-
if ((f0stable == 0) & (f0t > 0)) or (
55-
(f0stable > 0) & (np.abs(f0stable - f0t) < f0stable / 5.0)
56-
):
57-
f0stable = f0t # consider a stable f0 if it is close to the previous one
58-
else:
56+
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable, fs=fs) # find f0
57+
if f0t <= 0:
5958
f0stable = 0
59+
f0candidate = 0
60+
elif f0stable > 0:
61+
if np.abs(f0stable - f0t) < f0stable / 5.0:
62+
f0stable = f0t
63+
else:
64+
f0stable = 0
65+
f0candidate = f0t
66+
else:
67+
if (f0candidate > 0) and (
68+
np.abs(f0candidate - f0t) < max(f0candidate, f0t) / 5.0
69+
):
70+
f0stable = f0t
71+
f0candidate = 0
72+
else:
73+
f0candidate = f0t
6074
f0 = np.append(f0, f0t) # add f0 to output array
6175
pin += H # advance sound pointer
6276
return f0
@@ -146,7 +160,7 @@ def harmonicModel(x, fs, w, N, t, nH, minf0, maxf0, f0et):
146160
ploc = UF.peakDetection(mX, t) # detect peak locations
147161
iploc, ipmag, ipphase = UF.peakInterp(mX, pX, ploc) # refine peak values
148162
ipfreq = fs * iploc / N
149-
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable) # find f0
163+
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable, fs=fs) # find f0
150164
if ((f0stable == 0) & (f0t > 0)) or (
151165
(f0stable > 0) & (np.abs(f0stable - f0t) < f0stable / 5.0)
152166
):
@@ -208,7 +222,7 @@ def harmonicModelAnal(
208222
ploc = UF.peakDetection(mX, t) # detect peak locations
209223
iploc, ipmag, ipphase = UF.peakInterp(mX, pX, ploc) # refine peak values
210224
ipfreq = fs * iploc / N # convert locations to Hz
211-
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable) # find f0
225+
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable, fs=fs) # find f0
212226
if ((f0stable == 0) & (f0t > 0)) or (
213227
(f0stable > 0) & (np.abs(f0stable - f0t) < f0stable / 5.0)
214228
):

smstools/models/hprModel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def hprModel(x, fs, w, N, t, nH, minf0, maxf0, f0et):
9191
ploc = UF.peakDetection(mX, t) # find peaks
9292
iploc, ipmag, ipphase = UF.peakInterp(mX, pX, ploc) # refine peak values
9393
ipfreq = fs * iploc / N # convert locations to Hz
94-
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable) # find f0
94+
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable, fs=fs) # find f0
9595
if ((f0stable == 0) & (f0t > 0)) or (
9696
(f0stable > 0) & (np.abs(f0stable - f0t) < f0stable / 5.0)
9797
):

smstools/models/hpsModel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def hpsModel(x, fs, w, N, t, nH, minf0, maxf0, f0et, stocf):
9595
ploc = UF.peakDetection(mX, t) # find peaks
9696
iploc, ipmag, ipphase = UF.peakInterp(mX, pX, ploc) # refine peak values
9797
ipfreq = fs * iploc / N # convert peak locations to Hz
98-
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable) # find f0
98+
f0t = UF.f0Twm(ipfreq, ipmag, f0et, minf0, maxf0, f0stable, fs=fs) # find f0
9999
if ((f0stable == 0) & (f0t > 0)) or (
100100
(f0stable > 0) & (np.abs(f0stable - f0t) < f0stable / 5.0)
101101
):

smstools/models/utilFunctions.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import subprocess
44
import sys
5+
import warnings
56

67
import numpy as np
78
from scipy.fft import fft, fftshift, ifft
@@ -12,20 +13,12 @@
1213
try:
1314
from smstools.models.utilFunctions_C import utilFunctions_C as UF_C
1415
except ImportError:
15-
print("\n")
16-
print(
17-
"-------------------------------------------------------------------------------"
16+
UF_C = None
17+
warnings.warn(
18+
"Cython core functions were not imported. Falling back to pure-Python "
19+
"implementations; this may be slower. See README.md for build instructions.",
20+
RuntimeWarning,
1821
)
19-
print("Warning:")
20-
print("Cython modules for some of the core functions were not imported.")
21-
print("Please refer to the README.md file in the 'sms-tools' directory,")
22-
print("for the instructions to compile the cython modules.")
23-
print("Exiting the code!!")
24-
print(
25-
"-------------------------------------------------------------------------------"
26-
)
27-
print("\n")
28-
sys.exit(0)
2922

3023
winsound_imported = False
3124
if sys.platform == "win32":
@@ -61,6 +54,8 @@ def wavread(filename):
6154
Read a sound file and convert it to a normalized floating point array
6255
filename: name of file to read
6356
returns fs: sampling rate of file, x: floating point array
57+
58+
Note: this function accepts any sampling rate and returns it as fs.
6459
"""
6560

6661
if os.path.isfile(filename) == False: # raise error if wrong input file
@@ -71,9 +66,6 @@ def wavread(filename):
7166
if len(x.shape) != 1: # raise error if more than one channel
7267
raise ValueError("Audio file should be mono")
7368

74-
if fs != 44100: # raise error if more than one channel
75-
raise ValueError("Sampling rate of input sound should be 44100")
76-
7769
# scale down and convert audio into floating point number in range of -1 to 1
7870
x = np.float32(x) / norm_fact[x.dtype.name]
7971
return fs, x
@@ -195,7 +187,10 @@ def genSpecSines(ipfreq, ipmag, ipphase, N, fs):
195187
returns Y: generated complex spectrum of sines
196188
"""
197189

198-
Y = UF_C.genSpecSines(N * ipfreq / float(fs), ipmag, ipphase, N)
190+
if UF_C is not None:
191+
Y = UF_C.genSpecSines(N * ipfreq / float(fs), ipmag, ipphase, N)
192+
else:
193+
Y = genSpecSines_p(ipfreq, ipmag, ipphase, N, fs)
199194
return Y
200195

201196

@@ -225,16 +220,16 @@ def genSpecSines_p(ipfreq, ipmag, ipphase, N, fs):
225220
if b[m] < 0: # peak lobe crosses DC bin
226221
Y[-b[m]] += lmag[m] * np.exp(-1j * ipphase[i])
227222
elif b[m] > hN: # peak lobe croses Nyquist bin
228-
Y[b[m]] += lmag[m] * np.exp(-1j * ipphase[i])
223+
Y[2 * hN - b[m]] += lmag[m] * np.exp(-1j * ipphase[i])
229224
elif b[m] == 0 or b[m] == hN: # peak lobe in the limits of the spectrum
230225
Y[b[m]] += lmag[m] * np.exp(1j * ipphase[i]) + lmag[m] * np.exp(
231226
-1j * ipphase[i]
232227
)
233228
else: # peak lobe in positive freq. range
234229
Y[b[m]] += lmag[m] * np.exp(1j * ipphase[i])
235-
Y[hN + 1 :] = Y[
236-
hN - 1 : 0 : -1
237-
].conjugate() # fill the negative part of the spectrum
230+
Y[hN + 1 :] = Y[
231+
hN - 1 : 0 : -1
232+
].conjugate() # fill the negative part of the spectrum
238233
return Y
239234

240235

@@ -251,19 +246,19 @@ def sinewaveSynth(freqs, amp, H, fs):
251246
lastfreq = freqs[0] # initialize synthesis frequency
252247
y = np.array([]) # initialize output array
253248
for l in range(freqs.size): # iterate over all frames
254-
if (lastfreq == 0) & (freqs[l] == 0): # if 0 freq add zeros
249+
if (lastfreq == 0) and (freqs[l] == 0): # if 0 freq add zeros
255250
A = np.zeros(H)
256251
freq = np.zeros(H)
257-
elif (lastfreq == 0) & (freqs[l] > 0): # if starting freq ramp up the amplitude
252+
elif (lastfreq == 0) and (freqs[l] > 0): # if starting freq ramp up the amplitude
258253
A = np.arange(0, amp, amp / H)
259254
freq = np.ones(H) * freqs[l]
260-
elif (lastfreq > 0) & (freqs[l] > 0): # if freqs in boundaries use both
255+
elif (lastfreq > 0) and (freqs[l] > 0): # if freqs in boundaries use both
261256
A = np.ones(H) * amp
262257
if lastfreq == freqs[l]:
263258
freq = np.ones(H) * lastfreq
264259
else:
265260
freq = np.arange(lastfreq, freqs[l], (freqs[l] - lastfreq) / H)
266-
elif (lastfreq > 0) & (freqs[l] == 0): # if ending freq ramp down the amplitude
261+
elif (lastfreq > 0) and (freqs[l] == 0): # if ending freq ramp down the amplitude
267262
A = np.arange(amp, 0, -amp / H)
268263
freq = np.ones(H) * lastfreq
269264
phase = 2 * np.pi * freq * t + lastphase # generate phase values
@@ -303,22 +298,30 @@ def cleaningTrack(track, minTrackLength=3):
303298
return cleanTrack
304299

305300

306-
def f0Twm(pfreq, pmag, ef0max, minf0, maxf0, f0t=0):
301+
def f0Twm(pfreq, pmag, ef0max, minf0, maxf0, f0t=0, fs=None):
307302
"""
308303
Function that wraps the f0 detection function TWM, selecting the possible f0 candidates
309304
and calling the function TWM with them
310305
pfreq, pmag: peak frequencies and magnitudes,
311306
ef0max: maximum error allowed, minf0, maxf0: minimum and maximum f0
312307
f0t: f0 of previous frame if stable
308+
fs: optional sampling rate in Hz. If provided, maxf0 must be below fs/2.
313309
returns f0: fundamental frequency in Hz
314310
"""
315311
if minf0 < 0: # raise exception if minf0 is smaller than 0
316312
raise ValueError("Minimum fundamental frequency (minf0) smaller than 0")
317313

318-
if maxf0 >= 10000: # raise exception if maxf0 is bigger than 10000Hz
319-
raise ValueError("Maximum fundamental frequency (maxf0) bigger than 10000Hz")
314+
if maxf0 <= minf0:
315+
raise ValueError(
316+
"Maximum fundamental frequency (maxf0) must be bigger than minf0"
317+
)
318+
319+
if (fs is not None) and (maxf0 >= fs / 2.0):
320+
raise ValueError(
321+
"Maximum fundamental frequency (maxf0) bigger than Nyquist frequency"
322+
)
320323

321-
if (pfreq.size < 3) & (
324+
if (pfreq.size < 3) and (
322325
f0t == 0
323326
): # return 0 if less than 3 peaks and not previous f0
324327
return 0
@@ -348,10 +351,14 @@ def f0Twm(pfreq, pmag, ef0max, minf0, maxf0, f0t=0):
348351
if f0cf.size == 0: # return 0 if no peak candidates
349352
return 0
350353

351-
f0, f0error = UF_C.twm(
352-
pfreq, pmag, f0cf
353-
) # call the TWM function with peak candidates, cython version
354-
# f0, f0error = TWM_p(pfreq, pmag, f0cf) # call the TWM function with peak candidates, python version
354+
if UF_C is not None:
355+
f0, f0error = UF_C.twm(
356+
pfreq, pmag, f0cf
357+
) # call the TWM function with peak candidates, cython version
358+
else:
359+
f0, f0error = TWM_p(
360+
pfreq, pmag, f0cf
361+
) # call the TWM function with peak candidates, python version
355362

356363
if (f0 > 0) and (
357364
f0error < ef0max
@@ -376,18 +383,17 @@ def TWM_p(pfreq, pmag, f0c):
376383
rho = 0.33 # weighting of MP error
377384
Amax = max(pmag) # maximum peak magnitude
378385
maxnpeaks = 10 # maximum number of peaks used
379-
harmonic = np.matrix(f0c)
386+
harmonic = np.asarray(f0c, dtype=float)
380387
ErrorPM = np.zeros(harmonic.size) # initialize PM errors
381388
MaxNPM = min(maxnpeaks, pfreq.size)
382389
for i in range(0, MaxNPM): # predicted to measured mismatch error
383-
difmatrixPM = harmonic.T * np.ones(pfreq.size)
384-
difmatrixPM = abs(difmatrixPM - np.ones((harmonic.size, 1)) * pfreq)
390+
difmatrixPM = abs(harmonic[:, None] - pfreq[None, :])
385391
FreqDistance = np.amin(difmatrixPM, axis=1) # minimum along rows
386392
peakloc = np.argmin(difmatrixPM, axis=1)
387-
Ponddif = np.array(FreqDistance) * (np.array(harmonic.T) ** (-p))
393+
Ponddif = FreqDistance * (harmonic ** (-p))
388394
PeakMag = pmag[peakloc]
389395
MagFactor = 10 ** ((PeakMag - Amax) / 20)
390-
ErrorPM = ErrorPM + (Ponddif + MagFactor * (q * Ponddif - r)).T
396+
ErrorPM = ErrorPM + (Ponddif + MagFactor * (q * Ponddif - r))
391397
harmonic = harmonic + f0c
392398

393399
ErrorMP = np.zeros(harmonic.size) # initialize MP errors
@@ -401,7 +407,7 @@ def TWM_p(pfreq, pmag, f0c):
401407
MagFactor = 10 ** ((PeakMag - Amax) / 20)
402408
ErrorMP[i] = sum(MagFactor * (Ponddif + MagFactor * (q * Ponddif - r)))
403409

404-
Error = (ErrorPM[0] / MaxNPM) + (rho * ErrorMP / MaxNMP) # total error
410+
Error = (ErrorPM / MaxNPM) + (rho * ErrorMP / MaxNMP) # total error
405411
f0index = np.argmin(Error) # get the smallest error
406412
f0 = f0c[f0index] # f0 with the smallest error
407413

@@ -476,7 +482,7 @@ def stochasticResidualAnal(x, N, H, sfreq, smag, sphase, fs, stocf):
476482
Xr = X - Yh # subtract sines from original spectrum
477483
mXr = 20 * np.log10(abs(Xr[:hN])) # magnitude spectrum of residual
478484
mXrenv = resample(
479-
np.maximum(-200, mXr), mXr.size * stocf
485+
np.maximum(-200, mXr), int(mXr.size * stocf)
480486
) # decimate the mag spectrum
481487
if l == 0: # if first frame
482488
stocEnv = np.array([mXrenv])

tests/test_errors.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pytest
3+
from scipy.io.wavfile import write
34

45
from smstools.models import dftModel, stft, utilFunctions
56

@@ -38,11 +39,47 @@ def test_f0twm_rejects_negative_minf0():
3839
def test_f0twm_rejects_too_large_maxf0():
3940
with pytest.raises(
4041
ValueError,
41-
match=r"Maximum fundamental frequency \(maxf0\) bigger than 10000Hz",
42+
match=r"Maximum fundamental frequency \(maxf0\) bigger than Nyquist frequency",
4243
):
43-
utilFunctions.f0Twm(np.array([100.0]), np.array([0.0]), 1.0, 50.0, 10000.0)
44+
utilFunctions.f0Twm(
45+
np.array([100.0, 200.0, 300.0]),
46+
np.array([0.0, -3.0, -6.0]),
47+
1.0,
48+
50.0,
49+
22050.0,
50+
fs=44100,
51+
)
52+
53+
54+
def test_f0twm_rejects_maxf0_above_nyquist_at_48k():
55+
with pytest.raises(
56+
ValueError,
57+
match=r"Maximum fundamental frequency \(maxf0\) bigger than Nyquist frequency",
58+
):
59+
utilFunctions.f0Twm(
60+
np.array([100.0, 200.0, 300.0]),
61+
np.array([0.0, -3.0, -6.0]),
62+
1.0,
63+
50.0,
64+
24000.0,
65+
fs=48000,
66+
)
4467

4568

4669
def test_wavread_rejects_missing_file():
4770
with pytest.raises(ValueError, match="Input file is wrong"):
4871
utilFunctions.wavread("does_not_exist.wav")
72+
73+
74+
def test_wavread_accepts_non_44100_sampling_rate(tmp_path):
75+
fs_in = 48000
76+
n = np.arange(1024)
77+
x = (0.2 * np.sin(2 * np.pi * 440.0 * n / fs_in) * 32767).astype(np.int16)
78+
wav_path = tmp_path / "tone_48k.wav"
79+
write(wav_path, fs_in, x)
80+
81+
fs_out, y = utilFunctions.wavread(str(wav_path))
82+
83+
assert fs_out == fs_in
84+
assert y.ndim == 1
85+
assert y.dtype == np.float32

0 commit comments

Comments
 (0)